Pytorch加載部分預(yù)訓(xùn)練模型的參數(shù)實(shí)例
前言
自從從深度學(xué)習(xí)框架caffe轉(zhuǎn)到Pytorch之后,感覺Pytorch的優(yōu)點(diǎn)妙不可言,各種設(shè)計(jì)簡潔,方便研究網(wǎng)絡(luò)結(jié)構(gòu)修改,容易上手,比TensorFlow的臃腫好多了。對于深度學(xué)習(xí)的初學(xué)者,Pytorch值得推薦。今天主要主要談?wù)凱ytorch是如何加載預(yù)訓(xùn)練模型的參數(shù)以及代碼的實(shí)現(xiàn)過程。
直接加載預(yù)選臉模型
如果我們使用的模型和預(yù)訓(xùn)練模型完全一樣,那么我們就可以直接加載別人的模型,還有一種情況,我們在訓(xùn)練自己模型的過程中,突然中斷了,但只要我們保存了之前的模型的參數(shù)也可以使用下面的代碼直接加載我們保存的模型繼續(xù)訓(xùn)練,不用從頭開始。
model=DPN(*args, **kwargs)
model.load_state_dict(torch.load("DPN.pth"))
這樣的加載方式是基于Pytorch使用的模型存儲方法:
torch.save(DPN.state_dict(), "DPN.pth")
加載部分預(yù)訓(xùn)練模型參數(shù)
其實(shí)大多數(shù)時候我們根據(jù)自己的任物所提出的模型是在一些公開模型的基礎(chǔ)上改變而來,其中公開模型的參數(shù)我們沒有必要在從頭開始訓(xùn)練,只要加載其訓(xùn)練好的模型參數(shù)即可,這樣有助于提高訓(xùn)練的準(zhǔn)確率和我們模型的泛化能力。
model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)
http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}
pretrained_dict=model_zoo.load_url(http['url'])
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model = torch.nn.DataParallel(model).cuda()
因?yàn)樾枰獎h除預(yù)訓(xùn)練模型中不匹配的的鍵,也就是層的名字。
以上這篇Pytorch加載部分預(yù)訓(xùn)練模型的參數(shù)實(shí)例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python3 面向?qū)ο骭_類的內(nèi)置屬性與方法的實(shí)例代碼
這篇文章主要介紹了python3 面向?qū)ο骭_類的內(nèi)置屬性與方法的實(shí)例代碼,非常不錯,具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2018-11-11
eclipse創(chuàng)建python項(xiàng)目步驟詳解
在本篇內(nèi)容里小編給大家分享了關(guān)于eclipse創(chuàng)建python項(xiàng)目的具體步驟和方法,需要的朋友們跟著學(xué)習(xí)下。2019-05-05
tensorflow使用tf.data.Dataset 處理大型數(shù)據(jù)集問題
這篇文章主要介紹了tensorflow使用tf.data.Dataset 處理大型數(shù)據(jù)集問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-12-12

