解決Pytorch修改預(yù)訓(xùn)練模型時(shí)遇到key不匹配的情況
一、Pytorch修改預(yù)訓(xùn)練模型時(shí)遇到key不匹配
最近想著修改網(wǎng)絡(luò)的預(yù)訓(xùn)練模型vgg.pth,但是發(fā)現(xiàn)當(dāng)我加載預(yù)訓(xùn)練模型權(quán)重到新建的模型并保存之后。
在我使用新賦值的網(wǎng)絡(luò)模型時(shí)出現(xiàn)了key不匹配的問題
#加載后保存(未修改網(wǎng)絡(luò)) base_weights = torch.load(args.save_folder + args.basenet) ssd_net.vgg.load_state_dict(base_weights) torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 將新保存的網(wǎng)絡(luò)代替之前的預(yù)訓(xùn)練模型
ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
net = ssd_net
...
if args.resume:
...
else:
base_weights = torch.load(args.save_folder + args.basenet)
#args.basenet為ssd_base.pth
print('Loading base network...')
ssd_net.vgg.load_state_dict(base_weights)
此時(shí)會(huì)如下出錯(cuò)誤:
Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)
…
RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.
說明之前的預(yù)訓(xùn)練模型 key參數(shù)為"0.weight", “0.bias”,但是經(jīng)過加載保存之后變?yōu)榱?vgg.0.weight", “vgg.0.bias”
我認(rèn)為是因?yàn)楸旧淼哪P投x文件里self.vgg = nn.ModuleList(base)這一句。
現(xiàn)在的問題是因?yàn)樽约憾x保存的模型key參數(shù)多了一個(gè)前綴。
可以通過如下語句進(jìn)行修改,并加載
from collections import OrderedDict #導(dǎo)入此模塊
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**
for k, v in base_weights.items():
name = k[4:] # remove `vgg.`,即只取vgg.0.weights的后面幾位
new_state_dict[name] = v
ssd_net.vgg.load_state_dict(new_state_dict)
此時(shí)就不會(huì)再出錯(cuò)了。
參考了這個(gè)篇。修改一下就可以應(yīng)用到自己的模型啦。
//www.dhdzp.com/article/214214.htm
二、pytorch加載預(yù)訓(xùn)練模型遇到的問題:KeyError: ‘bn1.num_batches_tracked‘
最近在使用pytorch1.0加載resnet預(yù)訓(xùn)練模型時(shí),遇到的一個(gè)問題,在此記錄一下。
KeyError: 'layer1.0.bn1.num_batches_tracked'
其實(shí)是使用的版本的問題,pytorch0.4.1之后在BN層加入了track_running_stats這個(gè)參數(shù),
這個(gè)參數(shù)的作用如下:
訓(xùn)練時(shí)用來統(tǒng)計(jì)訓(xùn)練時(shí)的forward過的min-batch數(shù)目,每經(jīng)過一個(gè)min-batch, track_running_stats+=1
如果沒有指定momentum, 則使用1/num_batches_tracked 作為因數(shù)來計(jì)算均值和方差(running mean and variance).
其實(shí),這個(gè)參數(shù)沒啥用.但因?yàn)楣俜教峁┑念A(yù)訓(xùn)練模型是pytorch0.3版本訓(xùn)練出來的,因此沒有這個(gè)參數(shù).
所以,只要過濾一下預(yù)訓(xùn)練權(quán)重字典中的關(guān)鍵字即可,‘num_batches_tracked'.代碼例子,如下.
有問題的代碼:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
for i in state_dict:
key = param_name + '.' + i
state_dict[i].copy_(param_dict[key])
del param_dict
對(duì)'num_batches_tracked進(jìn)行過濾:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
for i in state_dict:
key = param_name + '.' + i
if 'num_batches_tracked' in key:
continue
state_dict[i].copy_(param_dict[key])
del param_dict
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python實(shí)現(xiàn)將一個(gè)數(shù)組逆序輸出的方法
今天小編就為大家分享一篇python實(shí)現(xiàn)將一個(gè)數(shù)組逆序輸出的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-06-06
Python3隨機(jī)漫步生成數(shù)據(jù)并繪制
這篇文章主要為大家詳細(xì)介紹了Python3隨機(jī)漫步生成數(shù)據(jù)并繪制的方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-08-08
Python數(shù)據(jù)處理Pandas庫的使用詳解
這篇文章主要為大家詳細(xì)介紹了pandas庫的使用方法,包括數(shù)據(jù)導(dǎo)入與導(dǎo)出、數(shù)據(jù)查看和篩選、數(shù)據(jù)處理和分組操作等,感興趣的小伙伴可以了解一下2023-07-07
如何利用itertuples對(duì)DataFrame進(jìn)行遍歷
這篇文章主要介紹了如何利用itertuples對(duì)DataFrame進(jìn)行遍歷問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-06-06
celery4+django2定時(shí)任務(wù)的實(shí)現(xiàn)代碼
這篇文章主要介紹了celery4+django2定時(shí)任務(wù)的實(shí)現(xiàn)代碼,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2018-12-12
Django實(shí)現(xiàn)web端tailf日志文件功能及實(shí)例詳解
這篇文章主要介紹了Django實(shí)現(xiàn)web端tailf日志文件功能,本文通過實(shí)例給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-07-07

