Pytorch 實現(xiàn)權重初始化
在TensorFlow中,權重的初始化主要是在聲明張量的時候進行的。 而PyTorch則提供了另一種方法:首先應該聲明張量,然后修改張量的權重。通過調(diào)用torch.nn.init包中的多種方法可以將權重初始化為直接訪問張量的屬性。
1、不初始化的效果
在Pytorch中,定義一個tensor,不進行初始化,打印看看結果:
w = torch.Tensor(3,4) print (w)
可以看到這時候的初始化的數(shù)值都是隨機的,而且特別大,這對網(wǎng)絡的訓練必定不好,最后導致精度提不上,甚至損失無法收斂。
2、初始化的效果
PyTorch提供了多種參數(shù)初始化函數(shù):
torch.nn.init.constant(tensor, val) torch.nn.init.normal(tensor, mean=0, std=1) torch.nn.init.xavier_uniform(tensor, gain=1)
等等。詳細請參考:http://pytorch.org/docs/nn.html#torch-nn-init
注意上面的初始化函數(shù)的參數(shù)tensor,雖然寫的是tensor,但是也可以是Variable類型的。而神經(jīng)網(wǎng)絡的參數(shù)類型Parameter是Variable類的子類,所以初始化函數(shù)可以直接作用于神經(jīng)網(wǎng)絡參數(shù)。實際上,我們初始化也是直接去初始化神經(jīng)網(wǎng)絡的參數(shù)。
讓我們試試效果:
w = torch.Tensor(3,4) torch.nn.init.normal_(w) print (w)
3、初始化神經(jīng)網(wǎng)絡的參數(shù)
對神經(jīng)網(wǎng)絡的初始化往往放在模型的__init__()函數(shù)中,如下所示:
class Net(nn.Module):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(Net, self).__init__()
***
*** #定義自己的網(wǎng)絡層
***
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
***
*** #定義后續(xù)的函數(shù)
***
也可以采取另一種方式:
定義一個權重初始化函數(shù),如下:
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
init.xavier_normal_(m.weight.data)
init.constant_(m.bias.data, 0.0)
elif classname.find('Linear') != -1:
init.xavier_normal_(m.weight.data)
init.constant_(m.bias.data, 0.0)
在模型聲明時,調(diào)用初始化函數(shù),初始化神經(jīng)網(wǎng)絡參數(shù):
model = Net(*****) model.apply(weights_init)
以上這篇Pytorch 實現(xiàn)權重初始化就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Tensorflow tf.nn.atrous_conv2d如何實現(xiàn)空洞卷積的
這篇文章主要介紹了Tensorflow tf.nn.atrous_conv2d如何實現(xiàn)空洞卷積的,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-04-04
pytorch獲取模型某一層參數(shù)名及參數(shù)值方式
今天小編就為大家分享一篇pytorch獲取模型某一層參數(shù)名及參數(shù)值方式,具有很好的價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12
Python3.7 讀取音頻根據(jù)文件名生成腳本的代碼
這篇文章主要介紹了Python3.7 讀取音頻根據(jù)文件名生成字幕腳本的方法,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04
Python開發(fā)必備知識內(nèi)存管理與垃圾回收
Python是一種高級編程語言,因其簡潔而強大而備受歡迎,然而如其他編程語言一樣,Python也面臨著內(nèi)存管理的挑戰(zhàn),在Python中,垃圾回收是一項關鍵任務,用于自動釋放不再使用的內(nèi)存,以避免內(nèi)存泄漏,本文將介紹Python中的垃圾回收機制,以及如何通過優(yōu)化代碼來提高性能2023-11-11
Python使用pydub模塊轉換音頻格式以及對音頻進行剪輯
這篇文章主要給大家介紹了關于Python使用pydub模塊轉換音頻格式以及對音頻進行剪輯的相關資料pydub是python的高級一個音頻處理庫,可以讓你以一種不那么蠢的方法處理音頻。需要的朋友可以參考下2021-06-06
Python實現(xiàn)五子棋聯(lián)機對戰(zhàn)小游戲
本文主要介紹了通過Python實現(xiàn)簡單的支持聯(lián)機對戰(zhàn)的游戲——支持局域網(wǎng)聯(lián)機對戰(zhàn)的五子棋小游戲。廢話不多說,快來跟隨小編一起學習吧2021-12-12

