pytorch自定義初始化權(quán)重的方法
在常見的pytorch代碼中,我們見到的初始化方式都是調(diào)用init類對每層所有參數(shù)進(jìn)行初始化。但是,有時我們有些特殊需求,比如用某一層的權(quán)重取優(yōu)化其它層,或者手動指定某些權(quán)重的初始值。
核心思想就是構(gòu)造和該層權(quán)重同一尺寸的矩陣去對該層權(quán)重賦值。但是,值得注意的是,pytorch中各層權(quán)重的數(shù)據(jù)類型是nn.Parameter,而不是Tensor或者Variable。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# 第一一個卷積層,我們可以看到它的權(quán)值是隨機初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
# 第一種方法
print("1.使用另一個Conv層的權(quán)值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假設(shè)q代表一個訓(xùn)練好的卷積層
print(q.weight) # 可以看到q的權(quán)重和w是不同的
w.weight=q.weight # 把一個Conv層的權(quán)重賦值給另一個Conv層
print(w.weight)
# 第二種方法
print("2.使用來自Tensor的權(quán)值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先創(chuàng)建一個自定義權(quán)值的Tensor,這里為了方便將所有權(quán)值設(shè)為1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作為權(quán)值賦值給Conv層,這里需要先轉(zhuǎn)為torch.nn.Parameter類型,否則將報錯
print(w.weight)
附:Variable和Parameter的區(qū)別
Parameter 是torch.autograd.Variable的一個字類,常被用于Module的參數(shù)。例如權(quán)重和偏置。
Parameters和Modules一起使用的時候會有一些特殊的屬性。parameters賦值給Module的屬性的時候,它會被自動加到Module的參數(shù)列表中,即會出現(xiàn)在Parameter()迭代器中。將Varaible賦給Module的時候沒有這樣的屬性。這可以在nn.Module的實現(xiàn)中詳細(xì)看一下。這樣做是為了保存模型的時候只保存權(quán)重偏置參數(shù),不保存節(jié)點值。所以復(fù)寫Variable加以區(qū)分。
另外一個不同是parameter不能設(shè)置volatile,而且require_grad默認(rèn)設(shè)置為true。Varaible默認(rèn)設(shè)置為False.
參數(shù):
parameter.data 得到tensor數(shù)據(jù)
parameter.requires_grad 默認(rèn)為True, BP過程中會求導(dǎo)
Parameter一般是在Modules中作為權(quán)重和偏置,自動加入?yún)?shù)列表,可以進(jìn)行保存恢復(fù)。和Variable具有相同的運算。
我們可以這樣簡單區(qū)分,在計算圖中,數(shù)據(jù)(包括輸入數(shù)據(jù)和計算過程中產(chǎn)生的feature map等)時variable類型,該類型不會被保存到模型中。 網(wǎng)絡(luò)的權(quán)重是parameter類型,在計算過程中會被更新,將會被保存到模型中。
以上這篇pytorch自定義初始化權(quán)重的方法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python時間與Unix時間戳相互轉(zhuǎn)換方法詳解
這篇文章主要介紹了python時間與Unix時間戳相互轉(zhuǎn)換方法詳解,需要的朋友可以參考下2020-02-02
Flask SQLAlchemy一對一,一對多的使用方法實踐
Flask-SQLAlchemy一對一,一對多的使用方法實踐,需要的朋友可以參考下2013-02-02
Python開發(fā)必備知識內(nèi)存管理與垃圾回收
Python是一種高級編程語言,因其簡潔而強大而備受歡迎,然而如其他編程語言一樣,Python也面臨著內(nèi)存管理的挑戰(zhàn),在Python中,垃圾回收是一項關(guān)鍵任務(wù),用于自動釋放不再使用的內(nèi)存,以避免內(nèi)存泄漏,本文將介紹Python中的垃圾回收機制,以及如何通過優(yōu)化代碼來提高性能2023-11-11
Python實現(xiàn)PS圖像調(diào)整之對比度調(diào)整功能示例
這篇文章主要介紹了Python實現(xiàn)PS圖像調(diào)整之對比度調(diào)整功能,結(jié)合實例形式分析了Python實現(xiàn)PS圖像對比度調(diào)整的原理、實現(xiàn)方法及相關(guān)操作技巧,需要的朋友可以參考下2018-01-01

