PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解
一、模型參數(shù)的保存和加載
-
torch.save(module.state_dict(), path):使用module.state_dict()函數(shù)獲取各層已經(jīng)訓(xùn)練好的參數(shù)和緩沖區(qū),然后將參數(shù)和緩沖區(qū)保存到path所指定的文件存放路徑(常用文件格式為.pt、.pth或.pkl)。 torch.nn.Module.load_state_dict(state_dict):從state_dict中加載參數(shù)和緩沖區(qū)到Module及其子類中 。torch.nn.Module.state_dict()函數(shù)返回python中的一個(gè)OrderedDict類型字典對(duì)象,該對(duì)象將每一層與它的對(duì)應(yīng)參數(shù)和緩沖區(qū)建立映射關(guān)系,字典的鍵值是參數(shù)或緩沖區(qū)的名稱。只有那些參數(shù)可以訓(xùn)練的層才會(huì)被保存到OrderedDict中,例如:卷積層、線性層等。Python中的字典類以“鍵:值”方式存取數(shù)據(jù),OrderedDict是它的一個(gè)子類,實(shí)現(xiàn)了對(duì)字典對(duì)象中元素的排序(OrderedDict根據(jù)放入元素的先后順序進(jìn)行排序)。由于進(jìn)行了排序,所以順序不同的兩個(gè)OrderedDict字典對(duì)象會(huì)被當(dāng)做是兩個(gè)不同的對(duì)象。- 示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化網(wǎng)絡(luò)
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 獲取state_dict
state_dict = net.state_dict()
# 字典的遍歷默認(rèn)是遍歷key,所以param_tensor實(shí)際上是鍵值
for param_tensor in state_dict:
print(param_tensor,':\n',state_dict[param_tensor])
# 保存模型參數(shù)
torch.save(state_dict,"net_params.pth")
# 通過加載state_dict獲取模型參數(shù)
net.load_state_dict(state_dict)
輸出:

二、完整模型的保存和加載
-
torch.save(module, path):將訓(xùn)練完的整個(gè)網(wǎng)絡(luò)模型module保存到path所指定的文件存放路徑(常用文件格式為.pt或.pth)。 torch.load(path):加載保存到path中的整個(gè)神經(jīng)網(wǎng)絡(luò)模型。- 示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化網(wǎng)絡(luò)
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整個(gè)網(wǎng)絡(luò)
torch.save(net,"net.pth")
# 加載網(wǎng)絡(luò)
net = torch.load("net.pth")
到此這篇關(guān)于PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解的文章就介紹到這了,更多相關(guān)PyTorch 模型的保存 內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
基于Python實(shí)現(xiàn)音樂播放器的實(shí)現(xiàn)示例代碼
這篇文章主要介紹了如何利用Python編寫簡易的音樂播放器,文中的示例代碼講解詳細(xì),具有一的參考價(jià)值,需要的小伙伴可以參考一下2022-04-04
python實(shí)現(xiàn)批量獲取指定文件夾下的所有文件的廠商信息
這篇文章主要介紹了python實(shí)現(xiàn)批量獲取指定文件夾下的所有文件的廠商信息的方法,是非常實(shí)用的技巧,涉及到文件的讀寫與字典的操作等技巧,需要的朋友可以參考下2014-09-09
一文帶你學(xué)會(huì)Python?Flask框架設(shè)置響應(yīng)頭
本篇博客我們將帶大家全面了解Python中Flask框架關(guān)于請(qǐng)求的相關(guān)設(shè)置的相關(guān)知識(shí),文中的示例代碼講解詳細(xì),對(duì)我們學(xué)習(xí)Python有一定幫助,需要的可以參考一下2023-01-01
Python3.5內(nèi)置模塊之time與datetime模塊用法實(shí)例分析
這篇文章主要介紹了Python3.5內(nèi)置模塊之time與datetime模塊用法,結(jié)合實(shí)例形式分析了Python3.5 time與datetime模塊日期時(shí)間相關(guān)操作技巧,需要的朋友可以參考下2019-04-04
基于Python實(shí)現(xiàn)最新房價(jià)信息的獲取
這篇文章主要為大家介紹了如何利用Python獲取房價(jià)信息(以北京為例),整個(gè)數(shù)據(jù)獲取的信息是通過房源平臺(tái)獲取的,通過下載網(wǎng)頁元素并進(jìn)行數(shù)據(jù)提取分析完成整個(gè)過程,需要的可以參考一下2022-04-04
Python中猜拳游戲與猜篩子游戲的實(shí)現(xiàn)方法
這篇文章主要給大家介紹了關(guān)于Python中猜拳游戲與猜篩子游戲的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09
Python簡單實(shí)現(xiàn)子網(wǎng)掩碼轉(zhuǎn)換的方法
這篇文章主要介紹了Python簡單實(shí)現(xiàn)子網(wǎng)掩碼轉(zhuǎn)換的方法,涉及Python字符串相關(guān)操作技巧,需要的朋友可以參考下2016-04-04
Python入門教程(二十九)Python的RegEx正則表達(dá)式
這篇文章主要介紹了Python入門教程(二十九)Python的RegEx,RegEx 或正則表達(dá)式是形成搜索模式的字符序列。RegEx 可用于檢查字符串是否包含指定的搜索模式,需要的朋友可以參考下2023-04-04

