python人工智能tensorflow常見損失函數(shù)LOSS匯總
前言
損失函數(shù)在機(jī)器學(xué)習(xí)中用于表示預(yù)測(cè)值與真實(shí)值之間的差距。
一般而言,大多數(shù)機(jī)器學(xué)習(xí)模型都會(huì)通過(guò)一定的優(yōu)化器來(lái)減小損失函數(shù)從而達(dá)到優(yōu)化預(yù)測(cè)機(jī)器學(xué)習(xí)模型參數(shù)的目的。
哦豁,損失函數(shù)這么必要,那都存在什么損失函數(shù)呢?
一般常用的損失函數(shù)是均方差函數(shù)和交叉熵函數(shù)。
運(yùn)算公式
1 均方差函數(shù)
均方差函數(shù)主要用于評(píng)估回歸模型的使用效果,其概念相對(duì)簡(jiǎn)單,就是真實(shí)值與預(yù)測(cè)值差值的平方的均值,具體運(yùn)算公式可以表達(dá)如下:

其中f(xi?)是預(yù)測(cè)值,yi?是真實(shí)值。在二維圖像中,該函數(shù)代表每個(gè)散點(diǎn)到擬合曲線y軸距離的總和,非常直觀。
2 交叉熵函數(shù)
出自信息論中的一個(gè)概念,原來(lái)的含義是用來(lái)估算平均編碼長(zhǎng)度的。在機(jī)器學(xué)習(xí)領(lǐng)域中,其常常作為分類問(wèn)題的損失函數(shù)。

交叉熵函數(shù)是怎么工作的呢?假設(shè)在分類問(wèn)題中,被預(yù)測(cè)的物體只有是或者不是,預(yù)測(cè)值常常不是1或者0這樣絕對(duì)的預(yù)測(cè)結(jié)果,預(yù)測(cè)是常用的做法是將預(yù)測(cè)結(jié)果中大于0.5的當(dāng)作1,小于0.5的當(dāng)作0。
此時(shí)假設(shè)如果存在一個(gè)樣本,預(yù)測(cè)值接近于0,實(shí)際值卻是1,那么在交叉熵函數(shù)的前半部分:

其運(yùn)算結(jié)果會(huì)遠(yuǎn)遠(yuǎn)小于0,取符號(hào)后會(huì)遠(yuǎn)遠(yuǎn)大于0,導(dǎo)致該模型的損失函數(shù)巨大。通過(guò)減小交叉熵函數(shù)可以使得模型的預(yù)測(cè)精度大大提升。
tensorflow中損失函數(shù)的表達(dá)
1 均方差函數(shù)
loss = tf.reduce_mean(tf.square(logits-labels)) loss = tf.reduce_mean(tf.square(tf.sub(logits, labels))) loss = tf.losses.mean_squared_error(logits,labels)
2 交叉熵函數(shù)
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y,logits=logits) #計(jì)算方式:對(duì)輸入的logits先通過(guò)sigmoid函數(shù)計(jì)算,再計(jì)算它們的交叉熵 #但是它對(duì)交叉熵的計(jì)算方式進(jìn)行了優(yōu)化,使得結(jié)果不至于溢出。 loss = tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=logits) #計(jì)算方式:對(duì)輸入的logits先通過(guò)softmax函數(shù)計(jì)算,再計(jì)算它們的交叉熵, #但是它對(duì)交叉熵的計(jì)算方式進(jìn)行了優(yōu)化,使得結(jié)果不至于溢出。
例子
1 均方差函數(shù)
這是一個(gè)一次函數(shù)擬合的例子。三個(gè)loss的意義相同。
import numpy as np
import tensorflow as tf
x_data = np.random.rand(100).astype(np.float32) #獲取隨機(jī)X值
y_data = x_data * 0.1 + 0.3 #計(jì)算對(duì)應(yīng)y值
Weights = tf.Variable(tf.random_uniform([1],-1.0,1.0)) #random_uniform返回[m,n]大小的矩陣,產(chǎn)生于low和high之間,產(chǎn)生的值是均勻分布的。
Biaxs = tf.Variable(tf.zeros([1])) #生成0
y = Weights*x_data + Biaxs
loss = tf.losses.mean_squared_error(y_data,y) #計(jì)算平方差
#loss = tf.reduce_mean(tf.square(y_data-y))
#loss = tf.reduce_mean(tf.square(tf.sub(y_data,y)))
optimizer = tf.train.GradientDescentOptimizer(0.6) #梯度下降法
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(200):
sess.run(train)
if i % 20 == 0:
print(sess.run(Weights),sess.run(Biaxs))
輸出結(jié)果為:
[0.10045234] [0.29975605]
[0.10010818] [0.2999417]
[0.10002586] [0.29998606]
[0.10000619] [0.29999667]
[0.10000149] [0.2999992]
2 交叉熵函數(shù)
這是一個(gè)Mnist手寫體識(shí)別的例子。兩個(gè)loss函數(shù)都可以進(jìn)行交叉熵運(yùn)算,在計(jì)算loss函數(shù)的時(shí)候中間經(jīng)過(guò)的函數(shù)不同。
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data",one_hot = "true")
def add_layer(inputs,in_size,out_size,n_layer,activation_function = None):
layer_name = 'layer%s'%n_layer
with tf.name_scope(layer_name):
with tf.name_scope("Weights"):
Weights = tf.Variable(tf.random_normal([in_size,out_size]),name = "Weights")
tf.summary.histogram(layer_name+"/weights",Weights)
with tf.name_scope("biases"):
biases = tf.Variable(tf.zeros([1,out_size]) + 0.1,name = "biases")
tf.summary.histogram(layer_name+"/biases",biases)
with tf.name_scope("Wx_plus_b"):
Wx_plus_b = tf.matmul(inputs,Weights) + biases
tf.summary.histogram(layer_name+"/Wx_plus_b",Wx_plus_b)
if activation_function == None :
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
tf.summary.histogram(layer_name+"/outputs",outputs)
return outputs
def compute_accuracy(x_data,y_data):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:x_data})
correct_prediction = tf.equal(tf.arg_max(y_data,1),tf.arg_max(y_pre,1)) #判斷是否相等
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #賦予float32數(shù)據(jù)類型,求平均。
result = sess.run(accuracy,feed_dict = {xs:batch_xs,ys:batch_ys}) #執(zhí)行
return result
xs = tf.placeholder(tf.float32,[None,784])
ys = tf.placeholder(tf.float32,[None,10])
layer1 = add_layer(xs,784,150,"layer1",activation_function = tf.nn.tanh)
prediction = add_layer(layer1,150,10,"layer2")
#由于loss函數(shù)在運(yùn)算的時(shí)候會(huì)自動(dòng)進(jìn)行softmax或者sigmoid函數(shù)的運(yùn)算,所以不需要特殊激勵(lì)函數(shù)。
with tf.name_scope("loss"):
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ys,logits = prediction),name = 'loss')
#loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=ys,logits = prediction),name = 'loss')
#label是標(biāo)簽,logits是預(yù)測(cè)值,交叉熵。
tf.summary.scalar("loss",loss)
train = tf.train.AdamOptimizer(4e-3).minimize(loss)
init = tf.initialize_all_variables()
merged = tf.summary.merge_all()
with tf.Session() as sess:
sess.run(init)
write = tf.summary.FileWriter("logs/",sess.graph)
for i in range(5001):
batch_xs,batch_ys = mnist.train.next_batch(100)
sess.run(train,feed_dict = {xs:batch_xs,ys:batch_ys})
if i % 1000 == 0:
print("訓(xùn)練%d次的識(shí)別率為:%f。"%((i+1),compute_accuracy(mnist.test.images,mnist.test.labels)))
result = sess.run(merged,feed_dict={xs:batch_xs,ys:batch_ys})
write.add_summary(result,i)
輸出結(jié)果為
訓(xùn)練1次的識(shí)別率為:0.103100。
訓(xùn)練1001次的識(shí)別率為:0.900700。
訓(xùn)練2001次的識(shí)別率為:0.928100。
訓(xùn)練3001次的識(shí)別率為:0.938900。
訓(xùn)練4001次的識(shí)別率為:0.945600。
訓(xùn)練5001次的識(shí)別率為:0.952100。
以上就是python人工智能tensorflowf常見損失函數(shù)LOSS匯總的詳細(xì)內(nèi)容,更多關(guān)于tensorflowf損失函數(shù)LOSS的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python callable內(nèi)置函數(shù)原理解析
這篇文章主要介紹了Python callable內(nèi)置函數(shù)原理解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-03-03
如何解決Python中ModuleNotFoundError錯(cuò)誤
使用模塊時(shí),了解它們的工作方式以及如何將它們導(dǎo)入我們的代碼非常重要,?如果沒(méi)有這種理解或錯(cuò)誤,我們可能會(huì)遇到不同的錯(cuò)誤,本文我們就來(lái)討論一下在Python中解決?ModuleNotFoundError?的方法,希望對(duì)大家有所幫助2023-12-12
MySQL中表的復(fù)制以及大型數(shù)據(jù)表的備份教程
這篇文章主要介紹了MySQL中表的復(fù)制以及大型數(shù)據(jù)表的備份教程,其中大表備份是采用添加觸發(fā)器增量備份的方法,需要的朋友可以參考下2015-11-11
python腳本實(shí)現(xiàn)查找webshell的方法
這篇文章主要介紹了python腳本實(shí)現(xiàn)查找webshell的方法,是很實(shí)用的一個(gè)功能,需要的朋友可以參考下2014-07-07
python實(shí)現(xiàn)電子產(chǎn)品商店
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)電子產(chǎn)品商店,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-02-02
用Python實(shí)現(xiàn)職工信息管理系統(tǒng)
這篇文章主要介紹了用Python實(shí)現(xiàn)職工信息管理系統(tǒng),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-12-12

