Pytorch模型參數(shù)的保存和加載
一、前言
在模型訓(xùn)練完成后,我們需要保存模型參數(shù)值用于后續(xù)的測試過程。由于保存整個模型將耗費大量的存儲,故推薦的做法是只保存參數(shù),使用時只需在建好模型的基礎(chǔ)上加載。
通常來說,保存的對象包括網(wǎng)絡(luò)參數(shù)值、優(yōu)化器參數(shù)值、epoch值等。本文將簡單介紹保存和加載模型參數(shù)的方法,同時也給出保存整個模型的方法供大家參考。
二、參數(shù)保存
在這里我們使用 torch.save() 函數(shù)保存模型參數(shù):
import torch path = './model.pth' torch.save(model.state_dict(), path)
model——指定義的模型實例變量,如model=net( )
state_dict()——state_dict( )是一個可以輕松地保存、更新、修改和恢復(fù)的python字典對象, 對于model來說,表示模型的每一層的權(quán)重及偏置等參數(shù)信息;對于 optimizer 來說,其包含了優(yōu)化器的狀態(tài)以及被使用的超參數(shù)(如lr, momentum,weight_decay等)
path——path是保存參數(shù)的路徑,一般設(shè)置為 path='./model.pth' , path='./model.pkl'等形式。
此外,如果想保存某一次訓(xùn)練采用的optimizer、epochs等信息,可將這些信息組合起來構(gòu)成一個字典保存起來:
import torch
path = './model.pth'
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)三、參數(shù)的加載
使用 load_state_dict()函數(shù)加載參數(shù)到模型中, 當(dāng)僅保存了模型參數(shù),而沒有optimizer、epochs等信息時:
model.load_state_dict(torch.load(path))
model——事先定義好的跟原模型一致的模型
path——之前保存的模型參數(shù)文件
如若保存了optimizer、epochs等信息,我們這樣載入信息:
# 使用torch.load()函數(shù)將文件中字典信息載入 state_dict 變量中 state_dict = torch.load(path) # 分布加載參數(shù)到模型和優(yōu)化器 model.load_state_dict(state_dict['model']) optimizer.load_state_dict(state_dict['optimizer']) epoch = state_dict(['epoch'])
我們還可以在每n個epoch后保存一次參數(shù),以觀察不同迭代次數(shù)模型的表現(xiàn)。此時我們可設(shè)置不同的path,如 path='./model' + str(epoch) +'.pth',這樣,不同epoch的參數(shù)就能保存在不同的文件中。
四、保存和加載整個模型
使用上文提到的方法即可:
torch.save(model, path) model = torch.load(path)
五、總結(jié)
pytorch中state_dict()和load_state_dict()函數(shù)配合使用可以實現(xiàn)狀態(tài)的獲取與重載,load()和save()函數(shù)配合使用可以實現(xiàn)參數(shù)的存儲與讀取。掌握對應(yīng)的函數(shù)使用方法就可以游刃有余地進(jìn)行運用。
到此這篇關(guān)于Pytorch模型參數(shù)的保存和加載的文章就介紹到這了,更多相關(guān)Pytorch模型參數(shù)保存內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實現(xiàn)的遠(yuǎn)程文件自動打包并下載功能示例
這篇文章主要介紹了Python實現(xiàn)的遠(yuǎn)程文件自動打包并下載功能,結(jié)合實例形式分析了Python使用spawn()方法執(zhí)行ssh、scp 命令實現(xiàn)遠(yuǎn)程文件的相關(guān)操作技巧,需要的朋友可以參考下2019-07-07
Python+Pygame實戰(zhàn)之吃豆豆游戲的實現(xiàn)
這篇文章主要為大家介紹了如何利用Python中的Pygame模塊實現(xiàn)仿吃豆豆游戲,文中的示例代碼講解詳細(xì),對我們學(xué)習(xí)Python游戲開發(fā)有一定幫助,需要的可以參考一下2022-06-06
用python給csv里的數(shù)據(jù)排序的具體代碼
在本文里小編給大家分享的是關(guān)于用python給csv里的數(shù)據(jù)排序的具體代碼內(nèi)容,需要的朋友們可以學(xué)習(xí)下。2020-07-07
python神經(jīng)網(wǎng)絡(luò)Keras構(gòu)建CNN網(wǎng)絡(luò)訓(xùn)練
這篇文章主要為大家介紹了python神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)使用Keras構(gòu)建CNN網(wǎng)絡(luò)訓(xùn)練,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05
Python 數(shù)據(jù)庫操作 SQLAlchemy的示例代碼
這篇文章主要介紹了Python 數(shù)據(jù)庫操作 SQLAlchemy的示例代碼,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-02-02
Python采用socket模擬TCP通訊的實現(xiàn)方法
這篇文章主要介紹了Python采用socket模擬TCP通訊的實現(xiàn)方法,程序分為TCP的server端與client端兩部分,分別對這兩部分進(jìn)行了較為深入的分析,需要的朋友可以參考下2014-11-11
Flask框架URL管理操作示例【基于@app.route】
這篇文章主要介紹了Flask框架URL管理操作,結(jié)合實例形式分析了@app.route進(jìn)行URL控制的相關(guān)操作技巧,需要的朋友可以參考下2018-07-07

