在PyTorch中使用標(biāo)簽平滑正則化的問(wèn)題
什么是標(biāo)簽平滑?在PyTorch中如何去使用它?
在訓(xùn)練深度學(xué)習(xí)模型的過(guò)程中,過(guò)擬合和概率校準(zhǔn)(probability calibration)是兩個(gè)常見(jiàn)的問(wèn)題。一方面,正則化技術(shù)可以解決過(guò)擬合問(wèn)題,其中較為常見(jiàn)的方法有將權(quán)重調(diào)小,迭代提前停止以及丟棄一些權(quán)重等。另一方面,Platt標(biāo)度法和isotonic regression法能夠?qū)δP瓦M(jìn)行校準(zhǔn)。但是有沒(méi)有一種方法可以同時(shí)解決過(guò)擬合和模型過(guò)度自信呢?
標(biāo)簽平滑也許可以。它是一種去改變目標(biāo)變量的正則化技術(shù),能使模型的預(yù)測(cè)結(jié)果不再僅為一個(gè)確定值。標(biāo)簽平滑之所以被看作是一種正則化技術(shù),是因?yàn)樗梢苑乐馆斎氲絪oftmax函數(shù)的最大logits值變得特別大,從而使得分類(lèi)模型變得更加準(zhǔn)確。
在這篇文章中,我們定義了標(biāo)簽平滑化,在測(cè)試過(guò)程中我們將它應(yīng)用到交叉熵?fù)p失函數(shù)中。
標(biāo)簽平滑?
假設(shè)這里有一個(gè)多分類(lèi)問(wèn)題,在這個(gè)問(wèn)題中,目標(biāo)變量通常是一個(gè)one-hot向量,即當(dāng)處于正確分類(lèi)時(shí)結(jié)果為1,否則結(jié)果是0。
標(biāo)簽平滑改變了目標(biāo)向量的最小值,使它為ε。因此,當(dāng)模型進(jìn)行分類(lèi)時(shí),其結(jié)果不再僅是1或0,而是我們所要求的1-ε和ε,從而帶標(biāo)簽平滑的交叉熵?fù)p失函數(shù)為如下公式。

在這個(gè)公式中,ce(x)表示x的標(biāo)準(zhǔn)交叉熵?fù)p失函數(shù),例如:-log(p(x)),ε是一個(gè)非常小的正數(shù),i表示對(duì)應(yīng)的正確分類(lèi),N為所有分類(lèi)的數(shù)量。
直觀(guān)上看,標(biāo)記平滑限制了正確類(lèi)的logit值,并使得它更接近于其他類(lèi)的logit值。從而在一定程度上,它被當(dāng)作為一種正則化技術(shù)和一種對(duì)抗模型過(guò)度自信的方法。
PyTorch中的使用
在PyTorch中,帶標(biāo)簽平滑的交叉熵?fù)p失函數(shù)實(shí)現(xiàn)起來(lái)非常簡(jiǎn)單。首先,讓我們使用一個(gè)輔助函數(shù)來(lái)計(jì)算兩個(gè)值之間的線(xiàn)性組合。
deflinear_combination(x, y, epsilon):return epsilon*x + (1-epsilon)*y
下一步,我們使用PyTorch中一個(gè)全新的損失函數(shù):nn.Module.
import torch.nn.functional as F
defreduce_loss(loss, reduction='mean'):return loss.mean() if reduction=='mean'else loss.sum() if reduction=='sum'else loss
classLabelSmoothingCrossEntropy(nn.Module):def__init__(self, epsilon:float=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
defforward(self, preds, target):
n = preds.size()[-1]
log_preds = F.log_softmax(preds, dim=-1)
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
return linear_combination(loss/n, nll, self.epsilon)
我們現(xiàn)在可以在代碼中刪除這個(gè)類(lèi)。對(duì)于這個(gè)例子,我們使用標(biāo)準(zhǔn)的fast.ai pets example.
from fastai.vision import *
from fastai.metrics import error_rate
# prepare the data
path = untar_data(URLs.PETS)
path_img = path/'images'
fnames = get_image_files(path_img)
bs = 64
np.random.seed(2)
pat = r'/([^/]+)_\d+.jpg$'
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs) \
.normalize(imagenet_stats)
# train the model
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.loss_func = LabelSmoothingCrossEntropy()
learn.fit_one_cycle(4)
最后將數(shù)據(jù)轉(zhuǎn)換成模型可以使用的格式,選擇ResNet架構(gòu)并以帶標(biāo)簽平滑的交叉熵?fù)p失函數(shù)作為優(yōu)化目標(biāo)。經(jīng)過(guò)四輪循環(huán)后,其結(jié)果如下

我們所得結(jié)果的錯(cuò)誤率僅為7.5%,這對(duì)于10行左右的代碼來(lái)說(shuō)是完全可以接受的,并且在模型中大多數(shù)參數(shù)還都選擇的是默認(rèn)設(shè)置。
因此,在模型中還有許多參數(shù)可以進(jìn)行調(diào)整,從而使得模型的表現(xiàn)性能更好,例如:可以使用不同的優(yōu)化器、超參數(shù)、模型架構(gòu)等。
結(jié)論
在這篇文章中,我們了解了什么是標(biāo)簽平滑以及什么時(shí)候去使用它,并且我們還知道了如何在PyTorch中實(shí)現(xiàn)它。之后,我們訓(xùn)練了一個(gè)先進(jìn)的計(jì)算機(jī)視覺(jué)模型,僅使用十行代碼就識(shí)別出了不同品種的貓和狗。
模型正則化和模型校準(zhǔn)是兩個(gè)重要的概念。若想成為一個(gè)深度學(xué)習(xí)的資深玩家,就應(yīng)該好好地去理解這些能夠?qū)惯^(guò)擬合和模型過(guò)度自信的工具。
作者簡(jiǎn)介: Dimitris Poulopoulos,是BigDataStack的一名機(jī)器學(xué)習(xí)研究員,同時(shí)也是希臘Piraeus大學(xué)的博士。曾為歐盟委員會(huì)、歐盟統(tǒng)計(jì)局、國(guó)際貨幣基金組織、歐洲央行等客戶(hù)設(shè)計(jì)過(guò)與AI相關(guān)的軟件。
總結(jié)
到此這篇關(guān)于如何在PyTorch中使用標(biāo)簽平滑正則化的文章就介紹到這了,更多相關(guān)PyTorch正則化內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python數(shù)據(jù)可視化Seaborn繪制山脊圖
這篇文章主要介紹了利用python數(shù)據(jù)可視化Seaborn繪制山脊圖,山脊圖一般由垂直堆疊的折線(xiàn)圖組成,這些折線(xiàn)圖中的折線(xiàn)區(qū)域間彼此重疊,此外它們還共享相同的x軸.下面來(lái)看看具體的繪制過(guò)程吧,需要的小伙伴可以參考一下2022-01-01
Python使用reportlab將目錄下所有的文本文件打印成pdf的方法
這篇文章主要介紹了Python使用reportlab將目錄下所有的文本文件打印成pdf的方法,涉及reportlab模塊操作pdf文件的相關(guān)技巧,需要的朋友可以參考下2015-05-05
對(duì)python遍歷文件夾中的所有jpg文件的實(shí)例詳解
今天小編就為大家分享一篇對(duì)python遍歷文件夾中的所有jpg文件的實(shí)例詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-12-12
Python中的bytes類(lèi)型用法及實(shí)例分享
這篇文章主要介紹了Python中的bytes類(lèi)型及其用法,Python?bytes?類(lèi)型用來(lái)表示一個(gè)字節(jié)串,bytes?只負(fù)責(zé)以字節(jié)序列的形式來(lái)存儲(chǔ)數(shù)據(jù),下面對(duì)其的相關(guān)內(nèi)容介紹,需要的小伙伴可以參考一下2022-03-03
使用Python快速實(shí)現(xiàn)文件共享并通過(guò)內(nèi)網(wǎng)穿透技術(shù)公網(wǎng)訪(fǎng)問(wèn)
數(shù)據(jù)共享作為和連接作為互聯(lián)網(wǎng)的基礎(chǔ)應(yīng)用,不僅在商業(yè)和辦公場(chǎng)景有廣泛的應(yīng)用,對(duì)于個(gè)人用戶(hù)也有很強(qiáng)的實(shí)用意義,今天,筆者就為大家介紹,如何使用python這樣的簡(jiǎn)單程序語(yǔ)言,在自己的電腦上搭建一個(gè)共享文件服務(wù)器,需要的朋友可以參考下2023-10-10
python機(jī)器學(xué)習(xí)MATLAB最小二乘法的兩種解讀
這篇文章主要為大家介紹了python機(jī)器學(xué)習(xí)中MATLAB最小二乘法的兩種解讀方式,有需要的朋友可以借鑒參考下希望能夠有所幫助2022-02-02
Python線(xiàn)程池模塊ThreadPoolExecutor用法分析
這篇文章主要介紹了Python線(xiàn)程池模塊ThreadPoolExecutor用法,結(jié)合實(shí)例形式分析了Python線(xiàn)程池模塊ThreadPoolExecutor的導(dǎo)入與基本使用方法,需要的朋友可以參考下2018-12-12

