pytorch載入預(yù)訓(xùn)練模型后,實(shí)現(xiàn)訓(xùn)練指定層
1、有了已經(jīng)訓(xùn)練好的模型參數(shù),對這個(gè)模型的某些層做了改變,如何利用這些訓(xùn)練好的模型參數(shù)繼續(xù)訓(xùn)練:
pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)
strict=False 使得預(yù)訓(xùn)練模型參數(shù)中和新模型對應(yīng)上的參數(shù)會(huì)被載入,對應(yīng)不上或沒有的參數(shù)被拋棄。
2、如果載入的這些參數(shù)中,有些參數(shù)不要求被更新,即固定不變,不參與訓(xùn)練,需要手動(dòng)設(shè)置這些參數(shù)的梯度屬性為Fasle,并且在optimizer傳參時(shí)篩選掉這些參數(shù):
# 載入預(yù)訓(xùn)練模型參數(shù)后...
for name, value in model.named_parameters():
if name 滿足某些條件:
value.requires_grad = False
# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
將滿足條件的參數(shù)的 requires_grad 屬性設(shè)置為False, 同時(shí) filter 函數(shù)將模型中屬性 requires_grad = True 的參數(shù)帥選出來,傳到優(yōu)化器(以Adam為例)中,只有這些參數(shù)會(huì)被求導(dǎo)數(shù)和更新。
3、如果載入的這些參數(shù)中,所有參數(shù)都更新,但要求一些參數(shù)和另一些參數(shù)的更新速度(學(xué)習(xí)率learning rate)不一樣,最好知道這些參數(shù)的名稱都有什么:
# 載入預(yù)訓(xùn)練模型參數(shù)后... for name, value in model.named_parameters(): print(name) # 或 print(model.state_dict().keys())
假設(shè)該模型中有encoder,viewer和decoder兩部分,參數(shù)名稱分別是:
'encoder.visual_emb.0.weight', 'encoder.visual_emb.0.bias', 'viewer.bd.Wsi', 'viewer.bd.bias', 'decoder.core.layer_0.weight_ih', 'decoder.core.layer_0.weight_hh',
假設(shè)要求encode、viewer的學(xué)習(xí)率為1e-6, decoder的學(xué)習(xí)率為1e-4,那么在將參數(shù)傳入優(yōu)化器時(shí):
ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
{'params':model.decoder.parameters()}
],
lr=1e-4, momentum=0.9)
代碼的結(jié)果是除decoder參數(shù)的learning_rate=1e-4 外,其他參數(shù)的額learning_rate=1e-6。
在傳入optimizer時(shí),和一般的傳參方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,參數(shù)部分用了一個(gè)list, list的每個(gè)元素有params和lr兩個(gè)鍵值。如果沒有 lr則應(yīng)用Adam的lr屬性。Adam的屬性除了lr, 其他都是參數(shù)所共有的(比如momentum)。
以上這篇pytorch載入預(yù)訓(xùn)練模型后,實(shí)現(xiàn)訓(xùn)練指定層就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
參考:
相關(guān)文章
python實(shí)現(xiàn)大文本文件分割成多個(gè)小文件
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)大文本文件分割成多個(gè)小文件,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-04-04
python Flask實(shí)現(xiàn)restful api service
本篇文章主要介紹了python Flask實(shí)現(xiàn)restful api service,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-12-12
python 實(shí)現(xiàn)將Numpy數(shù)組保存為圖像
今天小編就為大家分享一篇python 實(shí)現(xiàn)將Numpy數(shù)組保存為圖像,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01
Python中實(shí)現(xiàn)地圖可視化的方法小結(jié)
Python提供了多個(gè)強(qiáng)大的庫,如Folium、Matplotlib、Geopandas等,使得創(chuàng)建漂亮而具有信息量的地圖變得簡單而靈活,本文將詳細(xì)介紹如何使用這些庫繪制漂亮的地圖,感興趣的可以了解下2023-12-12
淺析python實(shí)現(xiàn)動(dòng)態(tài)規(guī)劃背包問題
這篇文章主要介紹了python實(shí)現(xiàn)動(dòng)態(tài)規(guī)劃背包問題,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-12-12
打包遷移Python?env環(huán)境的三種方法總結(jié)
平常工作中可能遇到python虛擬環(huán)境遷移的場景,總結(jié)了如下幾個(gè)方法,下面這篇文章主要給大家介紹了關(guān)于打包遷移Python?env環(huán)境的三種方法,需要的朋友可以參考下2024-08-08
在Python中使用Neo4j數(shù)據(jù)庫的教程
這篇文章主要介紹了在Python中使用Neo4j數(shù)據(jù)庫的教程,Neo4j是一個(gè)具有一定人氣的非關(guān)系型的數(shù)據(jù)庫,需要的朋友可以參考下2015-04-04
Python之自動(dòng)獲取公網(wǎng)IP的實(shí)例講解
下面小編就為大家?guī)硪黄狿ython之自動(dòng)獲取公網(wǎng)IP的實(shí)例講解。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-10-10
淺談python中的__init__、__new__和__call__方法
這篇文章主要給大家介紹了關(guān)于python中__init__、__new__和__call__方法的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考學(xué)習(xí),下面來跟著小編一起看看吧。2017-07-07

