pytorch 使用加載訓(xùn)練好的模型做inference
前提: 模型參數(shù)和結(jié)構(gòu)是分別保存的
1、 構(gòu)建模型(# load model graph)
model = MODEL()
2、加載模型參數(shù)(# load model state_dict)
model.load_state_dict
(
{
k.replace('module.',''):v for k,v in
torch.load(config.model_path, map_location=config.device).items()
}
)
model = self.model.to(config.device)
* config.device 指定使用哪塊GPU或者CPU
*k.replace('module.',''):v 防止torch.DataParallel訓(xùn)練的模型出現(xiàn)加載錯(cuò)誤
(解決RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1問題)
3、設(shè)置當(dāng)前階段為inference(# predict)
model.eval()
以上這篇pytorch 使用加載訓(xùn)練好的模型做inference就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
NumPy實(shí)現(xiàn)ndarray多維數(shù)組操作
NumPy一個(gè)非常重要的作用就是可以進(jìn)行多維數(shù)組的操作,這篇文章主要介紹了NumPy實(shí)現(xiàn)ndarray多維數(shù)組操作,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-05-05
Python使用Transformers實(shí)現(xiàn)機(jī)器翻譯功能
近年來,機(jī)器翻譯技術(shù)飛速發(fā)展,從傳統(tǒng)的基于規(guī)則的翻譯到統(tǒng)計(jì)機(jī)器翻譯,再到如今流行的神經(jīng)網(wǎng)絡(luò)翻譯模型,尤其是基于Transformer架構(gòu)的模型,翻譯效果已經(jīng)有了質(zhì)的飛躍,本文將詳細(xì)介紹如何使用Transformers庫(kù)來實(shí)現(xiàn)一個(gè)機(jī)器翻譯模型,需要的朋友可以參考下2024-11-11
python實(shí)現(xiàn)獲取序列中最小的幾個(gè)元素
Python+Sklearn實(shí)現(xiàn)異常檢測(cè)
Python字符串中的單詞反轉(zhuǎn)的實(shí)現(xiàn)示例
關(guān)于python tushare Tkinter構(gòu)建的簡(jiǎn)單股票可視化查詢系統(tǒng)(Beta v0.13)

