python神經(jīng)網(wǎng)絡(luò)pytorch中BN運(yùn)算操作自實(shí)現(xiàn)
BN 想必大家都很熟悉,來自論文:
《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》
也是面試??疾斓膬?nèi)容,雖然一行代碼就能搞定,但是還是很有必要用代碼自己實(shí)現(xiàn)一下,也可以加深一下對其內(nèi)部機(jī)制的理解。
通用公式:

直奔代碼:
首先是定義一個函數(shù),實(shí)現(xiàn)BN的運(yùn)算操作:
def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
# 判斷當(dāng)前模式是訓(xùn)練模式還是預(yù)測模式
if not is_training:
# 如果是在預(yù)測模式下,直接使用傳入的移動平均所得的均值和方差
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
else:
if len(x.shape) == 2:
# 使用全連接層的情況,計(jì)算特征維上的均值和方差
mean = x.mean(dim=0)
var = ((x - mean) ** 2).mean(dim=0)
else:
# 使用二維卷積層的情況,計(jì)算通道維上(axis=1)的均值和方差。這里我們需要保持
# x的形狀以便后面可以做廣播運(yùn)算
mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
# 訓(xùn)練模式下用當(dāng)前的均值和方差做標(biāo)準(zhǔn)化
x_hat = (x - mean) / torch.sqrt(var + eps)
# 更新移動平均的均值和方差
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
x = gamma * x_hat + beta # 拉伸和偏移
return Y, moving_mean, moving_var然后再定義一個類,就是常用的集成nn.Module的類了。
這里說明三點(diǎn):
- 在卷積上的BN實(shí)現(xiàn),是在 Batch,W,H上進(jìn)行歸一化操作的,也就是BWH拉成一個維度求均值和方差,均值和方差以及beta和gamma的尺寸為channel。當(dāng)然其他各種N,包括IN,LN,GN都是包含WH維度的。
- 不需要計(jì)算梯度和參與梯度更新的參數(shù),可以用self.register_buffer直接注冊就可以了;注冊的變量同樣使用;
- 被包成nn.Parameter的參數(shù),需要求梯度,但是不能加cuda(),否則會報錯。 如果想在gpu上運(yùn)算,可以將整個類的實(shí)例加.cuda()。例如 bn = BatchNorm(**param),bn=bn.cuda().
class BatchNorm(nn.Module):
def __init__(self, num_features, num_dims):
super(BatchNorm, self).__init__()
if num_dims == 2: # 同樣是判斷是全連層還是卷積層
shape = (1, num_features)
else:
shape = (1, num_features, 1, 1)
# 參與求梯度和迭代的拉伸和偏移參數(shù),分別初始化成0和1
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
# 不參與求梯度和迭代的變量,全初始化成0
self.register_buffer('moving_mean', torch.zeros(shape))
self.register_buffer('moving_var', torch.ones(shape))
def forward(self, x):
# 如果X不在內(nèi)存上,將moving_mean和moving_var復(fù)制到X所在顯存上
if self.moving_mean.device != x.device:
self.moving_mean = self.moving_mean.to(X.device)
self.moving_var = self.moving_var.to(X.device)
# 保存更新過的moving_mean和moving_var, Module實(shí)例的traning屬性默認(rèn)為true, 調(diào)用.eval()后設(shè)成false
y, self.moving_mean, self.moving_var = batch_norm(self.training,
x, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=0.9)
return x以上就是python神經(jīng)網(wǎng)絡(luò)pytorch中BN運(yùn)算操作自實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于pytorch BN運(yùn)算的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python之tkinter進(jìn)度條Progressbar用法解讀
這篇文章主要介紹了Python之tkinter進(jìn)度條Progressbar用法解讀,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-05-05
使用Python實(shí)現(xiàn)PDF與SVG互轉(zhuǎn)
SVG(可縮放矢量圖形)和PDF(便攜式文檔格式)是兩種常見且廣泛使用的文件格式,本文將詳細(xì)介紹如何使用?Python?實(shí)現(xiàn)?SVG?和?PDF?之間的相互轉(zhuǎn)換,感興趣的可以了解下2025-02-02
Keras SGD 隨機(jī)梯度下降優(yōu)化器參數(shù)設(shè)置方式
這篇文章主要介紹了Keras SGD 隨機(jī)梯度下降優(yōu)化器參數(shù)設(shè)置方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06
Python將GIF動圖轉(zhuǎn)換為Base64編碼字符串的步驟詳解
在Web開發(fā)中,有時需要將圖像文件(如GIF動圖)轉(zhuǎn)換為Base64編碼的字符串,以便在HTML或CSS中直接嵌入圖像數(shù)據(jù),本文給大家就介紹了一個簡單的教程,教你如何使用Python將GIF動圖轉(zhuǎn)換為Base64編碼的字符串,需要的朋友可以參考下2025-02-02
python實(shí)現(xiàn)一組典型數(shù)據(jù)格式轉(zhuǎn)換
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)一組典型數(shù)據(jù)格式轉(zhuǎn)換,具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-12-12
python中的?sorted()函數(shù)和sort()方法區(qū)別
這篇文章主要介紹了python中的?sorted()函數(shù)和sort()方法,首先看sort()方法,sort方法只能對列表進(jìn)行操作,而sorted可用于所有的可迭代對象。具體內(nèi)容需要的小伙伴可以參考下面章節(jié)2022-02-02
Python從Excel讀取數(shù)據(jù)并使用Matplotlib繪制成二維圖像
本課程實(shí)現(xiàn)使用 Python 從 Excel 讀取數(shù)據(jù),并使用 Matplotlib 繪制成二維圖像。這一過程中,將通過一系列操作來美化圖像,最終得到一個可以出版級別的圖像。本課程對于需要書寫實(shí)驗(yàn)報告,學(xué)位論文,發(fā)表文章,做報告的學(xué)員具有較大價值2023-02-02
matplotlib savefig 保存圖片大小的實(shí)例
今天小編就為大家分享一篇matplotlib savefig 保存圖片大小的實(shí)例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05

