PyTorch模型的保存與加載方法實(shí)例
模型的保存與加載
首先,需要導(dǎo)入兩個(gè)包
import torch import torchvision.models as models
保存和加載模型參數(shù)
PyTorch模型將學(xué)習(xí)到的參數(shù)存儲(chǔ)在一個(gè)內(nèi)部狀態(tài)字典中,叫做state_dict。這可以通過torch.save方法來實(shí)現(xiàn)。
我們導(dǎo)入預(yù)訓(xùn)練好的VGG16模型,并將其保存。我們將state_dict字典保存在model_weights.pth文件中。
model = models.vgg16(pretrained=True) torch.save(model.state_dict(), 'model_weights.pth')
想要加載模型參數(shù),我們需要?jiǎng)?chuàng)建一個(gè)和原模型一樣的實(shí)例,然后通過load_state_dict()方法來加載模型參數(shù)
- 創(chuàng)建一個(gè)
VGG16模型實(shí)例(未經(jīng)過預(yù)訓(xùn)練的) - 加載本地參數(shù)
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
注意:在進(jìn)行測(cè)試前,如果模型中有dropout層和batch normalization層的話,一定要使用model.eval()將模型轉(zhuǎn)到測(cè)試模式。
- 在
train模式下,dropout網(wǎng)絡(luò)層會(huì)按照設(shè)定的參數(shù)p設(shè)置保留激活單元的概率(保留概率=p);batchnorm層會(huì)繼續(xù)計(jì)算數(shù)據(jù)的mean和var等參數(shù)并更新。 - 在
val模式下,dropout層會(huì)讓所有的激活單元都通過,而batchnorm層會(huì)停止計(jì)算和更新mean和var,直接使用在訓(xùn)練階段已經(jīng)學(xué)出的mean和var值
當(dāng)然,相同的,在模型進(jìn)行訓(xùn)練之前,要使用model.train()來將模型轉(zhuǎn)為訓(xùn)練模式
保存和加載模型參數(shù)與結(jié)構(gòu)
當(dāng)加載模型權(quán)重時(shí),我們需要首先實(shí)例化模型類,因?yàn)轭惗x了網(wǎng)絡(luò)的結(jié)構(gòu)。我們可能希望將這個(gè)類的結(jié)構(gòu)與模型保存在一起。這樣的話,我們可以將model而不是model.state_dict()作為參數(shù)。
torch.save(model, 'model.pth')
這樣,我們加載模型的時(shí)候就不用再新建一個(gè)實(shí)例了。加載方式如下所示
model = torch.load('model.pth')這種方式在網(wǎng)絡(luò)比較大的時(shí)候可能比較慢,因?yàn)橄噍^于上面的方式多存儲(chǔ)了網(wǎng)絡(luò)的結(jié)構(gòu)
總結(jié)
到此這篇關(guān)于PyTorch模型的保存與加載方法的文章就介紹到這了,更多相關(guān)PyTorch模型保存加載內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中GPU計(jì)算的庫(kù)pycuda的使用
本文主要介紹了Python中GPU計(jì)算的庫(kù)pycuda的使用,詳細(xì)介紹了PyCUDA 庫(kù)的特性、用法,并通過豐富的示例代碼展示其在實(shí)際項(xiàng)目中的應(yīng)用,感興趣的可以了解一下2024-05-05
對(duì)DataFrame數(shù)據(jù)中的重復(fù)行,利用groupby累加合并的方法詳解
今天小編就為大家分享一篇對(duì)DataFrame數(shù)據(jù)中的重復(fù)行,利用groupby累加合并的方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01
Python隨機(jī)生成一個(gè)6位的驗(yàn)證碼代碼分享
這篇文章主要介紹了Python隨機(jī)生成一個(gè)6位的驗(yàn)證碼代碼分享,本文直接給出代碼實(shí)例,需要的朋友可以參考下2015-03-03
使用Matplotlib繪制平行坐標(biāo)系的示例詳解
平行坐標(biāo)系,是一種含有多個(gè)垂直平行坐標(biāo)軸的統(tǒng)計(jì)圖表,這篇文章主要為大家介紹了如何使用繪制平行坐標(biāo)系,需要的小伙伴可以參考一下2023-07-07
Python中enumerate()函數(shù)詳細(xì)分析(附多個(gè)Demo)
Python的enumerate()函數(shù)是一個(gè)內(nèi)置函數(shù),主要用于在遍歷循環(huán)中獲取每個(gè)元素的索引以及對(duì)應(yīng)的值,這篇文章主要介紹了Python中enumerate()函數(shù)的相關(guān)資料,需要的朋友可以參考下2024-10-10
Python中卷積神經(jīng)網(wǎng)絡(luò)(CNN)入門教程分分享
卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks, CNN)是一類特別適用于處理圖像數(shù)據(jù)的深度學(xué)習(xí)模型,本文介紹了如何使用Keras創(chuàng)建一個(gè)簡(jiǎn)單的CNN模型,并用它對(duì)手寫數(shù)字進(jìn)行分類,需要的可以參考一下2023-05-05

