怎樣保存模型權(quán)重和checkpoint
概述
在pytorch中有兩種方式可以保存推理模型,第一種是只保存模型的參數(shù),比如parameters和buffers;另外一種是保存整個模型;
1.保存模型 - 權(quán)重參數(shù)
我們可以用torch.save()函數(shù)來保存model.state_dict();state_dict()里面包含模型的parameters&buffers;這種方法只保存模型中必要的訓(xùn)練參數(shù)。
你可以用pytorch中的pickle來保存模型;使用這種方法可以生成最直觀的語法,并涉及最少的代碼;這種方法的缺點是,序列化的數(shù)據(jù)被綁定到特定的類和保存模型時使用的確切的目錄結(jié)構(gòu)。
這樣做的原因是pickle并不保存模型類本身。相反,它保存包含類的文件的路徑,在加載期間使用;因此,當(dāng)在其他項目中使用或重構(gòu)后,您的代碼可能以各種方式中斷。
我們將探討如何保存和加載模型進(jìn)行推斷的兩種方法。
步驟:
(1)導(dǎo)入所有必要的庫來加載我們的數(shù)據(jù)
(2)定義和初始化神經(jīng)網(wǎng)絡(luò)
(3)初始化優(yōu)化器
(4)保存并通過state_dict加載模型
(5)保存并加載整個模型
1.1代碼
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: Neural_Network_test
# @Create time: 2022/3/19 15:33
# 1.導(dǎo)入相關(guān)數(shù)據(jù)庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
# 2.定義神經(jīng)網(wǎng)絡(luò)模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 3. 實例化神經(jīng)網(wǎng)絡(luò)
net = Net()
# 4. 實例化優(yōu)化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 5. 保存模型參數(shù)
# Specify a path
PATH = "state_dict_model.pt"
# 6. 保存模型的參數(shù)字典:parameters and buffers
torch.save(net.state_dict(), PATH)
# 7. 實例化新的模型
model = Net()
# 8. 給新的實例加載之前的模型參數(shù)
model.load_state_dict(torch.load(PATH))
# 9. 設(shè)置模型為評估模式
model.eval()
注意(1):
pytorch中常用的慣例是將model.state_dict()保存為"state_dict_model.pt",即文件的格式一般是.pt或者.pth格式文件;注意load_state_dict加載的是一個字典,而不是路徑。
注意(2):
模型參數(shù)在推理階段一定要設(shè)置model.eval();這樣可以讓dropout和batchnorm失效,如果沒設(shè)置推理模式,會得到不一樣的結(jié)果。
2.保存模型 - 整個模型
將模型所有的內(nèi)容都保存下來。
# Specify a path PATH = "entire_model.pt" # Save torch.save(net, PATH) # Load model = torch.load(PATH) model.eval()
3.保存模型 - checkpoints
我們按照checkpoints模式來保存模型,本質(zhì)上就是按照字典的模式進(jìn)行分門別類的保存,我們可以通過鍵值進(jìn)行加載。
epoch:訓(xùn)練周期model_state_dict:模型可訓(xùn)練參數(shù)optimizer_state_dict:模型優(yōu)化器參數(shù)loss:模型的損失函數(shù)
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
torch.save({
'epoch': EPOCH,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, PATH)
保存和加載通用的檢查點模型以進(jìn)行推斷或恢復(fù)訓(xùn)練,這有助于您從上一個地方繼續(xù)進(jìn)行。
當(dāng)保存一個常規(guī)檢查點時,您必須保存模型的state_dict之外的更多信息。
保存優(yōu)化器的state_dict也很重要,因為它包含緩沖區(qū)和參數(shù),隨著模型的運行而更新。
您可能希望保存的其他項目是您離開的時期,最新記錄的訓(xùn)練損失,外部torch.nn.嵌入層,以及更多,基于自己的算法
3.1代碼
# 1.導(dǎo)入相關(guān)數(shù)據(jù)庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
# 2. 定義神經(jīng)網(wǎng)絡(luò)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 3. 實例化神經(jīng)網(wǎng)絡(luò)
net = Net()
# 4. 實例化優(yōu)化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Additional information
# 5. 定義超參數(shù)
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
# 6. 以checkpoints形式保存模型的相關(guān)數(shù)據(jù)
torch.save({
'epoch': EPOCH,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, PATH)
# 7. 重新實例化一個模型
model = Net()
# 8. 實例化優(yōu)化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 9. 加載以前的checkpoint
checkpoint = torch.load(PATH)
# 10. 通過鍵值來加載相關(guān)參數(shù)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# 11.設(shè)置推理模式
model.eval()
# - or -
model.train()
4.保存雙模型
當(dāng)保存有多個神經(jīng)網(wǎng)絡(luò)模型組成的神經(jīng)網(wǎng)絡(luò)時,比如GAN對抗模型,sequence-to-sequence序列到序列模型,或者一個組合模型,你必須為每一個模型保存狀態(tài)字典state_dict()和其對應(yīng)的優(yōu)化器參數(shù)optimizer.state_dict();您還可以保存任何其他項目,可能會幫助您恢復(fù)訓(xùn)練,只需將它們添加到字典;為了加載模型,第一步是初始化神經(jīng)網(wǎng)絡(luò)模型和優(yōu)化器,然后用torch.load()去加載checkpoint對應(yīng)的數(shù)據(jù),因為checkpoints是字典,所以我們可以通過鍵值進(jìn)行查詢導(dǎo)入;
4.1相關(guān)步驟
(1)導(dǎo)入所有相關(guān)的數(shù)據(jù)庫
(2)定義和實例化神經(jīng)網(wǎng)絡(luò)模型
(3)初始化優(yōu)化器
(4)保存多重模型
(5)加載多重模型
# 1.導(dǎo)入相關(guān)數(shù)據(jù)庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
# 2. 定義神經(jīng)網(wǎng)絡(luò)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 3. 實例化神經(jīng)網(wǎng)絡(luò)A,B
netA = Net()
netB = Net()
# 4. 實例化優(yōu)化器A,B
optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)
# 5. 保存模型
# Specify a path to save to
PATH = "model.pt"
torch.save({
'modelA_state_dict': netA.state_dict(),
'modelB_state_dict': netB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
# 6.重新實例化新的網(wǎng)絡(luò)模型A,B
modelA = Net()
modelB = Net()
# 7. 重新實例化新的網(wǎng)絡(luò)模型A,B
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
# 8. 將以前模型的參數(shù)重新加載到新的模型A,B中
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
# 9. 開啟預(yù)測模式
modelA.eval()
modelB.eval()
# - or -
# 10.開啟訓(xùn)練模式
modelA.train()
modelB.train()
5.機(jī)器學(xué)習(xí)流程圖

6.機(jī)器學(xué)習(xí)常用庫

總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python?Diagrams創(chuàng)建高質(zhì)量圖表和流程圖實例探究
Python?Diagrams是一個強(qiáng)大的Python庫,使創(chuàng)建這些圖表變得簡單且靈活,本文將深入介紹Python?Diagrams,包括其基本概念、安裝方法、示例代碼以及一些高級用法,以幫助大家充分利用這一工具來創(chuàng)建令人印象深刻的圖表2024-01-01
Python實現(xiàn)多并發(fā)訪問網(wǎng)站功能示例
這篇文章主要介紹了Python實現(xiàn)多并發(fā)訪問網(wǎng)站功能,結(jié)合具體實例形式分析了Python線程結(jié)合URL模塊并發(fā)訪問網(wǎng)站的相關(guān)操作技巧,需要的朋友可以參考下2017-06-06
解決import tensorflow導(dǎo)致jupyter內(nèi)核死亡的問題
這篇文章主要介紹了解決import tensorflow導(dǎo)致jupyter內(nèi)核死亡的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-02-02
Win10 GPU運算環(huán)境搭建(CUDA10.0+Cudnn 7.6.5+pytroch1.2+tensorflow1.
熟悉深度學(xué)習(xí)的人都知道,深度學(xué)習(xí)是需要訓(xùn)練的,本文主要介紹了Win10 GPU運算環(huán)境搭建,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2021-09-09
基于K.image_data_format() == ''channels_first'' 的理解
這篇文章主要介紹了基于K.image_data_format() == 'channels_first' 的理解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06
python sitk.show()與imageJ結(jié)合使用常見的問題
這篇文章主要介紹了python sitk.show()與imageJ結(jié)合使用常見的問題,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04
Python?asyncore?socket客戶端開發(fā)基本使用教程
asyncore庫是python的一個標(biāo)準(zhǔn)庫,提供了以異步的方式寫入套接字服務(wù)的客戶端和服務(wù)器的基礎(chǔ)結(jié)構(gòu),這篇文章主要介紹了Python?asyncore?socket客戶端開發(fā)基本使用,需要的朋友可以參考下2022-12-12

