基于MSELoss()與CrossEntropyLoss()的區(qū)別詳解
基于pytorch來(lái)講
MSELoss()多用于回歸問題,也可以用于one_hotted編碼形式,
CrossEntropyLoss()名字為交叉熵?fù)p失函數(shù),不用于one_hotted編碼形式
MSELoss()要求batch_x與batch_y的tensor都是FloatTensor類型
CrossEntropyLoss()要求batch_x為Float,batch_y為L(zhǎng)ongTensor類型
(1)CrossEntropyLoss() 舉例說明:
比如二分類問題,最后一層輸出的為2個(gè)值,比如下面的代碼:
class CNN (nn.Module ) :
def __init__ ( self , hidden_size1 , output_size , dropout_p) :
super ( CNN , self ).__init__ ( )
self.hidden_size1 = hidden_size1
self.output_size = output_size
self.dropout_p = dropout_p
self.conv1 = nn.Conv1d ( 1,8,3,padding =1)
self.fc1 = nn.Linear (8*500, self.hidden_size1 )
self.out = nn.Linear (self.hidden_size1,self.output_size )
def forward ( self , encoder_outputs ) :
cnn_out = F.max_pool1d ( F.relu (self.conv1(encoder_outputs)),2)
cnn_out = F.dropout ( cnn_out ,self.dropout_p) #加一個(gè)dropout
cnn_out = cnn_out.view (-1,8*500)
output_1 = torch.tanh ( self.fc1 ( cnn_out ) )
output = self.out ( ouput_1)
return output
最后的輸出結(jié)果為:

上面一個(gè)tensor為output結(jié)果,下面為target,沒有使用one_hotted編碼。
訓(xùn)練過程如下:
cnn_optimizer = torch.optim.SGD(cnn.parameters(),learning_rate,momentum=0.9,\
weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
def train ( input_variable , target_variable , cnn , cnn_optimizer , criterion ) :
cnn_output = cnn( input_variable )
print(cnn_output)
print(target_variable)
loss = criterion ( cnn_output , target_variable)
cnn_optimizer.zero_grad ()
loss.backward( )
cnn_optimizer.step( )
#print('loss: ',loss.item())
return loss.item() #返回?fù)p失
說明CrossEntropyLoss()是output兩位為one_hotted編碼形式,但target不是one_hotted編碼形式。
(2)MSELoss() 舉例說明:
網(wǎng)絡(luò)結(jié)構(gòu)不變,但是標(biāo)簽是one_hotted編碼形式。下面的圖僅做說明,網(wǎng)絡(luò)結(jié)構(gòu)不太對(duì),出來(lái)的預(yù)測(cè)也不太對(duì)。

如果target不是one_hotted編碼形式會(huì)報(bào)錯(cuò),報(bào)的錯(cuò)誤如下。

目前自己理解的兩者的區(qū)別,就是這樣的,至于多分類問題是不是也是樣的有待考察。
以上這篇基于MSELoss()與CrossEntropyLoss()的區(qū)別詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
利用Python批量循環(huán)讀取Excel的技巧分享
這篇文章主要為大家詳細(xì)介紹了何用Python批量循環(huán)讀取Excel,文中的示例代碼講解詳細(xì),對(duì)我們的學(xué)習(xí)或工作有一定的幫助,感興趣的可以了解一下2023-07-07
matplotlib繪制餅圖的基本配置(萬(wàn)能模板案例)
餅圖是常見的一種圖表形式,本文主要介紹了matplotlib繪制餅圖的基本配置(萬(wàn)能模板案例),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2022-04-04
keras小技巧——獲取某一個(gè)網(wǎng)絡(luò)層的輸出方式
這篇文章主要介紹了keras小技巧——獲取某一個(gè)網(wǎng)絡(luò)層的輸出方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2020-05-05
Python 利用4行代碼實(shí)現(xiàn)圖片灰度化的項(xiàng)目實(shí)踐
灰度處理是將彩色圖像轉(zhuǎn)換為灰度圖像的過程,即每個(gè)像素的顏色由紅、綠、藍(lán)三個(gè)通道的值組成,轉(zhuǎn)換為一個(gè)單一的灰度值,本文主要介紹了Python 利用4行代碼實(shí)現(xiàn)圖片灰度化的項(xiàng)目實(shí)踐,感興趣的可以了解一下2024-04-04
python 計(jì)算方位角實(shí)例(根據(jù)兩點(diǎn)的坐標(biāo)計(jì)算)
今天小編就為大家分享一篇python 計(jì)算方位角實(shí)例(根據(jù)兩點(diǎn)的坐標(biāo)計(jì)算),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2020-01-01
python2.7安裝opencv-python很慢且總是失敗問題
這篇文章主要介紹了python2.7安裝opencv-python很慢且總是失敗問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-02-02
python報(bào)錯(cuò)解決之python運(yùn)行bat文件的各種問題處理
這篇文章主要介紹了python報(bào)錯(cuò)解決之python運(yùn)行bat文件的各種問題處理,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-06-06

