pytorch交叉熵損失函數(shù)的weight參數(shù)的使用
首先
必須將權(quán)重也轉(zhuǎn)為Tensor的cuda格式;
然后
將該class_weight作為交叉熵函數(shù)對應(yīng)參數(shù)的輸入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
補充:關(guān)于pytorch的CrossEntropyLoss的weight參數(shù)
首先這個weight參數(shù)比想象中的要考慮的多
你可以試試下面代碼
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,1,1]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.4803)
這里的手動計算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加權(quán)呢?
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.6075)
手算發(fā)現(xiàn),并不是單純的那權(quán)重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
發(fā)現(xiàn)了么,加權(quán)后,除以的是權(quán)重的和,不是數(shù)目的和。
我們再驗證一遍:
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5]) outputs = torch.LongTensor([0,1,2,2]) inputs = inputs.view((1,3,4)) outputs = outputs.view((1,4)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(weight=weight_CE) # ce = nn.CrossEntropyLoss(ignore_index=255) loss = ce(inputs,outputs) print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人對loss的CE計算過程有疑問,我這里細致寫寫交叉熵的計算過程,就拿最后一個例子的loss4的計算說明

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
關(guān)于tensorflow中tf.keras.models.Sequential()的用法
這篇文章主要介紹了關(guān)于tensorflow中tf.keras.models.Sequential()的用法,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-01-01
Pandas數(shù)據(jù)操作分析基本常用的15個代碼片段
這篇文章主要介紹了Pandas數(shù)據(jù)操作分析基本常用的15個代碼片段,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-09-09
flask使用session保存登錄狀態(tài)及攔截未登錄請求代碼
這篇文章主要介紹了flask使用session保存登錄狀態(tài)及攔截未登錄請求代碼,具有一定借鑒價值,需要的朋友可以參考下2018-01-01
Python標(biāo)準(zhǔn)庫shutil模塊使用方法解析
這篇文章主要介紹了Python標(biāo)準(zhǔn)庫shutil模塊使用方法解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-03-03
Python_查看sqlite3表結(jié)構(gòu),查詢語句的示例代碼
今天小編就為大家分享一篇Python_查看sqlite3表結(jié)構(gòu),查詢語句的示例代碼,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07
python如何實現(xiàn)wifi自動連接,解決電腦wifi經(jīng)常斷開問題
這篇文章主要介紹了python實現(xiàn)wifi自動連接,解決電腦wifi經(jīng)常斷開的問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-06-06
解決Pycharm 中遇到Unresolved reference ''sklearn''的問題
這篇文章主要介紹了解決Pycharm 中遇到Unresolved reference 'sklearn'的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-07-07

