使用Pytorch如何完成多分類問題
Pytorch如何完成多分類
多分類問題在最后的輸出層采用的Softmax Layer,其具有兩個特點(diǎn):1.每個輸出的值都是在(0,1);2.所有值加起來和為1.
假設(shè)
是最后線性層的輸出,則對應(yīng)的Softmax function為:


輸出經(jīng)過sigmoid運(yùn)算即可是西安輸出的分類概率都大于0且總和為1。




上圖的交叉熵?fù)p失就包含了softmax計(jì)算和右邊的標(biāo)簽輸入計(jì)算(即框起來的部分)
所以在使用交叉熵?fù)p失的時候,神經(jīng)網(wǎng)絡(luò)的最后一層是不要做激活的,因?yàn)榘阉龀煞植嫉募せ钍前诮徊骒負(fù)p失里面的,最后一層不要做非線性變換,直接交給交叉熵?fù)p失。

如上圖,做交叉熵?fù)p失時要求y是一個長整型的張量,構(gòu)造時直接用
criterion = torch.nn.CrossEntropyLoss()

3個類別,分別是2,0,1
Y_pred1 ,Y_pred2還是線性輸出,沒經(jīng)過softmax,還不是概率分布,比如Y_pred1,0.9最大,表示對應(yīng)為第3個的概率最大,和2吻合,1.1最大,表示對應(yīng)為第1個的概率最大,和0吻合,2.1最大,表示對應(yīng)為第2個的概率最大,和1吻合,那么Y_pred1 的損失會比較小
對于Y_pred2,0.8最大,表示對應(yīng)為第1個的概率最大,和0不吻合,0.5最大,表示對應(yīng)為第3個的概率最大,和2不吻合,0.5最大,表示對應(yīng)為第3個的概率最大,和2不吻合,那么Y_pred2 的損失會比較大
Exercise 9-1: CrossEntropyLoss vs NLLLoss
What are the differences?
• Reading the document:
• https://pytorch.org/docs/stable/nn.html#crossentropyloss
• https://pytorch.org/docs/stable/nn.html#nllloss
• Try to know why:
• CrossEntropyLoss <==> LogSoftmax + NLLLoss


為什么要用transform
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])
PyTorch讀圖像用的是python的imageLibrary,就是PIL,現(xiàn)在用的都是pillow,pillow讀進(jìn)來的圖像用神經(jīng)網(wǎng)絡(luò)處理的時候,神經(jīng)網(wǎng)絡(luò)有一個特點(diǎn)就是希望輸入的數(shù)值比較小,最好是在-1到+1之間,最好是輸入遵從正態(tài)分布,這樣的輸入對神經(jīng)網(wǎng)絡(luò)訓(xùn)練是最有幫助的

原始圖像是28*28的像素值在0到255之間,我們把它轉(zhuǎn)變成圖像張量,像素值是0到1
在視覺里面,灰度圖就是一個矩陣,但實(shí)際上并不是一個矩陣,我們把它叫做單通道圖像,彩色圖像是3通道,通道有寬度和高度,一般我們讀進(jìn)來的圖像張量是WHC(寬高通道)
在PyTorch里面我們需要轉(zhuǎn)化成CWH,把通道放在前面是為了在PyTorch里面進(jìn)行更高效的圖像處理,卷積運(yùn)算。所以拿到圖像之后,我們就把它先轉(zhuǎn)化成pytorch里面的一個Tensor,把0到255的值變成0到1的浮點(diǎn)數(shù),然后把維度由2828變成128*28的張量,由單通道變成多通道,
這個過程可以用transforms的ToTensor這個函數(shù)實(shí)現(xiàn)


歸一化
transforms.Normalize((0.1307, ), (0.3081, ))

這里的0.1307,0.3081是對Mnist數(shù)據(jù)集所有的像素求均值方差得到的
也就是說,將來拿到了圖像,先變成張量,然后Normalize,切換到0,1分布,然后供神經(jīng)網(wǎng)絡(luò)訓(xùn)練
如上圖,定義好transform變換之后,直接把它放到數(shù)據(jù)集里面,為什么要放在數(shù)據(jù)集里面呢,是為了在讀取第i個數(shù)據(jù)的時候,直接用transform處理

模型
輸入是一組圖像,激活層改用Relu
全連接神經(jīng)網(wǎng)絡(luò)要求輸入是一個矩陣
所以需要把輸入的張量變成一階的,這里的N表示有N個圖片

view函數(shù)可以改變張量的形狀,-1表示將來自動去算它的值是多少,比如輸入是n128*28
將來會自動把n算出來,輸入了張量就知道形狀,就知道有多少個數(shù)值


最后輸出是(N,10)因?yàn)槭怯?-9這10個標(biāo)簽嘛,10表示該圖像屬于某一個標(biāo)簽的概率,現(xiàn)在還是線性值,我們再用softmax把它變成概率

#沿著第一個維度找最大值的下標(biāo),返回值有兩個,因?yàn)槭?0列嘛,返回值一個是每一行的最大值,另一個是最大值的下標(biāo)(每一個樣本就是一行,每一行有10個量)(行是第0個維度,列是第1個維度)

MNIST數(shù)據(jù)集訓(xùn)練代碼
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
# prepare dataset
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(), #先將圖像變換成一個張量tensor。
transforms.Normalize((0.1307,), (0.3081,))
#其中的0.1307是MNIST數(shù)據(jù)集的均值,0.3081是MNIST數(shù)據(jù)集的標(biāo)準(zhǔn)差。
]) # 歸一化,均值和方差
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True,
download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False,
download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
# design model using class
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(784, 512)
self.l2 = torch.nn.Linear(512, 256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10)
def forward(self, x):
# 28 * 28 = 784
# 784 = 28 * 28,即將N *1*28*28轉(zhuǎn)化成 N *1*784
x = x.view(-1, 784) # -1其實(shí)就是自動獲取mini_batch
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = F.relu(self.l4(x))
return self.l5(x) # 最后一層不做激活,不進(jìn)行非線性變換
model = Net()
#CrossEntropyLoss <==> LogSoftmax + NLLLoss。
#也就是說使用CrossEntropyLoss最后一層(線性層)是不需要做其他變化的;
#使用NLLLoss之前,需要對最后一層(線性層)先進(jìn)行SoftMax處理,再進(jìn)行l(wèi)og操作。
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
#momentum 是帶有優(yōu)化的一個訓(xùn)練過程參數(shù)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# training cycle forward, backward, update
def train(epoch):
running_loss = 0.0
#enumerate()函數(shù)用于將一個可遍歷的數(shù)據(jù)對象(如列表、元組或字符串)組合為一個索引序列,
#同時列出數(shù)據(jù)和數(shù)據(jù)下標(biāo),一般用在 for 循環(huán)當(dāng)中。
#enumerate(sequence, [start=0])
for batch_idx, data in enumerate(train_loader, 0):
# 獲得一個批次的數(shù)據(jù)和標(biāo)簽
inputs, target = data
optimizer.zero_grad()
#forward + backward + update
# 獲得模型預(yù)測結(jié)果(64, 10)
outputs = model(inputs)
# 交叉熵代價(jià)函數(shù)outputs(64,10),target(64)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
running_loss = 0.0
def test():
correct = 0
total = 0
with torch.no_grad():#不需要計(jì)算梯度。
for data in test_loader:
images, labels = data
outputs = model(images)
#orch.max的返回值有兩個,第一個是每一行的最大值是多少,第二個是每一行最大值的下標(biāo)(索引)是多少。
_, predicted = torch.max(outputs.data, dim=1) # dim = 1 列是第0個維度,行是第1個維度
total += labels.size(0)
correct += (predicted == labels).sum().item() # 張量之間的比較運(yùn)算
print('accuracy on test set: %d %% ' % (100 * correct / total))
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()總結(jié)
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
解決virtualenv -p python3 venv報(bào)錯的問題
這篇文章主要介紹了解決virtualenv -p python3 venv報(bào)錯的問題,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-02-02
Python正確調(diào)用 jar 包加密得到加密值的操作方法
這篇文章主要介紹了Python 正確調(diào)用 jar 包加密得到加密值的操作方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-08-08
詳解Python中的函數(shù)參數(shù)傳遞方法*args與**kwargs
本文將討論P(yáng)ython的函數(shù)參數(shù)。我們將了解args和kwargs,/和的都是什么,雖然這個問題是一個基本的python問題,但是在我們寫代碼時會經(jīng)常遇到,比如timm中就大量使用了這樣的參數(shù)傳遞方式2023-03-03
使用pyQT5顯示網(wǎng)頁的實(shí)現(xiàn)步驟
本文主要介紹了使用pyQT5顯示網(wǎng)頁的實(shí)現(xiàn)步驟,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-10-10
Selenium?4.2.0?標(biāo)簽定位8種方法詳解
這篇文章主要介紹了Selenium?4.2.0?標(biāo)簽定位8種方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-06-06
python 裝飾器功能以及函數(shù)參數(shù)使用介紹
之前學(xué)習(xí)編程語言大多也就是學(xué)的很淺很淺,基本上也是很少涉及到裝飾器這些的類似的內(nèi)容??偸怯X得是一樣很神奇的東西,舍不得學(xué)(嘿嘿)。今天看了一下書籍。發(fā)現(xiàn)道理還是很簡單的2012-01-01

