PyTorch加載模型model.load_state_dict()問題及解決
PyTorch加載模型model.load_state_dict()問題
希望將訓(xùn)練好的模型加載到新的網(wǎng)絡(luò)上。
如上面題目所描述的,PyTorch在加載之前保存的模型參數(shù)的時候,遇到了問題。
Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不對應(yīng)。
表明了加載過程中,期望獲得的key值為feature...,而不是module.features....。
這是由模型保存過程中導(dǎo)致的,模型應(yīng)該是在DataParallel模式下面,也就是采用了多GPU訓(xùn)練模型,然后直接保存的。
You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.
解決上面的問題有三個辦法:
1. 對load的模型創(chuàng)建新的字典
去掉不需要的key值"module".
# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt') # 模型可以保存為pth文件,也可以為pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`,表面從第7個key值字符取到最后一個字符,正好去掉了module.
new_state_dict[name] = v #新字典的key值對應(yīng)的value為一一對應(yīng)的值。
# load params
model.load_state_dict(new_state_dict) # 從新加載這個模型。2. 直接用空白''代替'module.'
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
# 相當(dāng)于用''代替'module.'。
#直接使得需要的鍵名等于期望的鍵名。3. 最簡單的方法
加載模型之后,接著將模型DataParallel,此時就可以load_state_dict。
如果有多個GPU,將模型并行化,用DataParallel來操作。
這個過程會將key值加一個"module. ***"。
model = VGGNet()
params=model.state_dict() #獲得模型的原始狀態(tài)以及參數(shù)。
for k,v in params.items():
print(k) #只打印key值,不打印具體參數(shù)。4. 總結(jié)
從出錯顯示的問題就可以看出,key值不匹配,因此可以選擇多種方法,將模型參數(shù)加載進(jìn)去。
這個方法通常會在load_state_dict過程中遇到。將訓(xùn)練好的一個網(wǎng)絡(luò)參數(shù),移植到另外一個網(wǎng)絡(luò)上面,繼續(xù)訓(xùn)練。
或者將訓(xùn)練好的網(wǎng)絡(luò)checkpoint加載進(jìn)模型,再次進(jìn)行訓(xùn)練??梢源蛴〕鰉odel state_dict來看出兩者的差別。
model = VGGNet()
params=model.state_dict() #獲得模型的原始狀態(tài)以及參數(shù)。
for k,v in params.items():
print(k) #只打印key值,不打印具體參數(shù)。features.0.0.weight
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked

model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 這個方法能夠直接打印出你保存的checkpoint的鍵和值。
for k,v in checkpoint.items():
print(k)
print("*****************************************")
輸出結(jié)果為:
module.features.0.0.weight",
"module.features.0.1.weight",
"module.features.0.1.bias
可以看出不匹配,模型的參數(shù)中,key值不同,多了module。
PS: 追加
在移植參數(shù)的過程中,對于出現(xiàn) .total_ops和.total_params結(jié)尾的參數(shù),可參考以下代碼:
from collections import OrderedDict
checkpoint = torch.load(
pretrained_model_file_path,
map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
if not k.endswith('total_ops') and not k.endswith('total_params'):
name = k[7:]
new_state_dict[name] = v最后
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python matplotlib畫圖實(shí)例之繪制擁有彩條的圖表
這篇文章主要介紹了Python matplotlib畫圖實(shí)例之繪制擁有彩條的圖表,具有一定借鑒價值,需要的朋友可以參考下2017-12-12
TensorFlow MNIST手寫數(shù)據(jù)集的實(shí)現(xiàn)方法
MNIST數(shù)據(jù)集中包含了各種各樣的手寫數(shù)字圖片,這篇文章主要介紹了TensorFlow MNIST手寫數(shù)據(jù)集的實(shí)現(xiàn)方法,需要的朋友可以參考下2020-02-02
python對矩陣進(jìn)行轉(zhuǎn)置的2種處理方法
這篇文章主要介紹了python對矩陣進(jìn)行轉(zhuǎn)置的2種處理方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07
python實(shí)現(xiàn)從ftp服務(wù)器下載文件
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)從ftp服務(wù)器下載文件,文中示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2020-03-03
Python實(shí)現(xiàn)簡單網(wǎng)頁圖片抓取完整代碼實(shí)例
這篇文章主要介紹了Python實(shí)現(xiàn)簡單網(wǎng)頁圖片抓取完整代碼實(shí)例,具有一定借鑒價值,需要的朋友可以參考下。2017-12-12

