Pytorch損失函數(shù)nn.NLLLoss2d()用法說明
最近做顯著星檢測用到了NLL損失函數(shù)
對于NLL函數(shù),需要自己計算log和softmax的概率值,然后從才能作為輸入
輸入 [batch_size, channel , h, w]

目標(biāo) [batch_size, h, w]
輸入的目標(biāo)矩陣,每個像素必須是類型.舉個例子。第一個像素是0,代表著類別屬于輸入的第1個通道;第二個像素是0,代表著類別屬于輸入的第0個通道,以此類推。
x = Variable(torch.Tensor([[[1, 2, 1],
[2, 2, 1],
[0, 1, 1]],
[[0, 1, 3],
[2, 3, 1],
[0, 0, 1]]]))
x = x.view([1, 2, 3, 3])
print("x輸入", x)
這里輸入x,并改成[batch_size, channel , h, w]的格式。
soft = nn.Softmax(dim=1)
log_soft = nn.LogSoftmax(dim=1)
然后使用softmax函數(shù)計算每個類別的概率,這里dim=1表示從在1維度
上計算,也就是channel維度。logsoftmax是計算完softmax后在計算log值

手動計算舉個栗子:第一個元素

y = Variable(torch.LongTensor([[1, 0, 1],
[0, 0, 1],
[1, 1, 1]]))
y = y.view([1, 3, 3])
輸入label y,改變成[batch_size, h, w]格式
loss = nn.NLLLoss2d() out = loss(x, y) print(out)
輸入函數(shù),得到loss=0.7947
來手動計算
第一個label=1,則 loss=-1.3133
第二個label=0, 則loss=-0.3133
. … … loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223
是一致的
注意:這個函數(shù)會對每個像素做平均,每個batch也會做平均,這里有9個像素,1個batch_size。
補充知識:PyTorch:NLLLoss2d
我就廢話不多說了,大家還是直接看代碼吧~
import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
inputs_tensor = torch.FloatTensor([
[[2, 4],
[1, 2]],
[[5, 3],
[3, 0]],
[[5, 3],
[5, 2]],
[[4, 2],
[3, 2]],
])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
targets_tensor = torch.LongTensor([
[0, 2],
[2, 3]
])
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)
以上這篇Pytorch損失函數(shù)nn.NLLLoss2d()用法說明就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python3實現(xiàn)發(fā)送QQ郵件功能(附件)
這篇文章主要為大家詳細介紹了Python3實現(xiàn)發(fā)送QQ郵件功能,附件方面,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下2017-12-12
python+selenium定時爬取丁香園的新型冠狀病毒數(shù)據(jù)并制作出類似的地圖(部署到云服務(wù)器)
這篇文章主要介紹了python+selenium定時爬取丁香園的新冠病毒每天的數(shù)據(jù)并制作出類似的地圖(部署到云服務(wù)器),本文給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下2020-02-02
Python selenium爬取微博數(shù)據(jù)代碼實例
這篇文章主要介紹了Python selenium爬取微博數(shù)據(jù)代碼實例,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-05-05
Django框架設(shè)置cookies與獲取cookies操作詳解
這篇文章主要介紹了Django框架設(shè)置cookies與獲取cookies操作,結(jié)合實例形式詳細分析了Django框架針對cookie操作的各種常見技巧與操作注意事項,需要的朋友可以參考下2019-05-05
Django ForeignKey與數(shù)據(jù)庫的FOREIGN KEY約束詳解
這篇文章主要介紹了Django ForeignKey與數(shù)據(jù)庫的FOREIGN KEY約束詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05
Python中JSON轉(zhuǎn)換的全面指南與最佳實踐
JSON是現(xiàn)代應(yīng)用程序中最流行的數(shù)據(jù)交換格式之一,Python通過內(nèi)置的json模塊提供了強大的JSON處理能力,本文將深入探討Python中的JSON轉(zhuǎn)換,包括基本用法、高級特性以及最佳實踐,需要的朋友可以參考下2025-03-03

