python深度學(xué)習(xí)之多標(biāo)簽分類器及pytorch實(shí)現(xiàn)源碼
多標(biāo)簽分類器
多標(biāo)簽分類任務(wù)與多分類任務(wù)有所不同,多分類任務(wù)是將一個(gè)實(shí)例分到某個(gè)類別中,多標(biāo)簽分類任務(wù)是將某個(gè)實(shí)例分到多個(gè)類別中。多標(biāo)簽分類任務(wù)有有兩大特點(diǎn):
- 類標(biāo)數(shù)量不確定,有些樣本可能只有一個(gè)類標(biāo),有些樣本的類標(biāo)可能高達(dá)幾十甚至上百個(gè)
- 類標(biāo)之間相互依賴,例如包含藍(lán)天類標(biāo)的樣本很大概率上包含白云
如下圖所示,即為一個(gè)多標(biāo)簽分類學(xué)習(xí)的一個(gè)例子,一張圖片里有多個(gè)類別,房子,樹(shù),云等,深度學(xué)習(xí)模型需要將其一一分類識(shí)別出來(lái)。

多標(biāo)簽分類器損失函數(shù)

代碼實(shí)現(xiàn)
針對(duì)圖像的多標(biāo)簽分類器pytorch的簡(jiǎn)化代碼實(shí)現(xiàn)如下所示。因?yàn)閳D像的多標(biāo)簽分類器的數(shù)據(jù)集比較難獲取,所以可以通過(guò)對(duì)mnist數(shù)據(jù)集中的每個(gè)圖片打上特定的多標(biāo)簽,例如類別1的多標(biāo)簽可以為[1,1,0,1,0,1,0,0,1],然后再利用重新打標(biāo)后的數(shù)據(jù)集訓(xùn)練出一個(gè)mnist的多標(biāo)簽分類器。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.Sq1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16, 28, 28) # output: (16, 28, 28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), # (16, 14, 14)
)
self.Sq2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), # (32, 14, 14)
nn.ReLU(),
nn.MaxPool2d(2), # (32, 7, 7)
)
self.out = nn.Linear(32 * 7 * 7, 100)
def forward(self, x):
x = self.Sq1(x)
x = self.Sq2(x)
x = x.view(x.size(0), -1)
x = self.out(x)
## Sigmoid activation
output = F.sigmoid(x) # 1/(1+e**(-x))
return output
def loss_fn(pred, target):
return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()
def multilabel_generate(label):
Y1 = F.one_hot(label, num_classes = 100)
Y2 = F.one_hot(label+10, num_classes = 100)
Y3 = F.one_hot(label+50, num_classes = 100)
multilabel = Y1+Y2+Y3
return multilabel
# def multilabel_generate(label):
# multilabel_dict = {}
# multi_list = []
# for i in range(label.shape[0]):
# multi_list.append(multilabel_dict[label[i].item()])
# multilabel_tensor = torch.tensor(multi_list)
# return multilabel
def train():
epoches = 10
mnist_net = CNN()
mnist_net.train()
opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)
for epoch in range(epoches):
loss = 0
for batch_X, batch_Y in train_loader:
opitimizer.zero_grad()
outputs = mnist_net(batch_X)
loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]
loss.backward()
opitimizer.step()
print(loss)
if __name__ == '__main__':
train()
以上就是python深度學(xué)習(xí)之多標(biāo)簽分類器及pytorch源碼的詳細(xì)內(nèi)容,更多關(guān)于多標(biāo)簽分類器pytorch源碼的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python實(shí)現(xiàn)通過(guò)flask和前端進(jìn)行數(shù)據(jù)收發(fā)
今天小編就為大家分享一篇python實(shí)現(xiàn)通過(guò)flask和前端進(jìn)行數(shù)據(jù)收發(fā),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08
Python利用lxml庫(kù)實(shí)現(xiàn)XML處理
lxml庫(kù)是Python中處理XML和HTML文檔的強(qiáng)大庫(kù),提供了豐富的API以進(jìn)行各種操作,本文將討論如何使用lxml庫(kù),包括如何創(chuàng)建XML文檔,如何使用XPath查詢,以及如何解析大型XML文檔,需要的可以參考下2023-08-08
Python學(xué)習(xí)之路安裝pycharm的教程詳解
pycharm 是一款功能強(qiáng)大的 Python 編輯器,具有跨平臺(tái)性。這篇文章主要介紹了Python學(xué)習(xí)之路安裝pycharm的教程,本文分步驟通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-06-06
Python3二分查找?guī)旌瘮?shù)bisect(),bisect_left()和bisect_right()的區(qū)別
這篇文章主要介紹了Python3二分查找?guī)旌瘮?shù)bisect(),bisect_left()和bisect_right()的區(qū)別,本文通過(guò)示例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2023-03-03
FP-growth算法發(fā)現(xiàn)頻繁項(xiàng)集——構(gòu)建FP樹(shù)
常見(jiàn)的挖掘頻繁項(xiàng)集算法有兩類,一類是Apriori算法,另一類是FP-growth。Apriori通過(guò)不斷的構(gòu)造候選集、篩選候選集挖掘出頻繁項(xiàng)集,需要多次掃描原始數(shù)據(jù),當(dāng)原始數(shù)據(jù)較大時(shí),磁盤I/O次數(shù)太多,效率比較低下2021-06-06
通過(guò)Py2exe將自己的python程序打包成.exe/.app的方法
這篇文章主要介紹了通過(guò)Py2exe將自己的python程序打包成.exe/.app的方法,需要的朋友可以參考下2018-05-05
Python如何向SQLServer存儲(chǔ)二進(jìn)制圖片
這篇文章主要介紹了Python如何向SQLServer存儲(chǔ)二進(jìn)制圖片,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06

