pytorch模型轉(zhuǎn)onnx模型的方法詳解
學(xué)習(xí)目標(biāo)
1.掌握pytorch模型轉(zhuǎn)換到onnx模型
2.順利運(yùn)行onnx模型
3.比對onnx模型和pytorch模型的輸出結(jié)果
學(xué)習(xí)大綱
- pytorch模型轉(zhuǎn)換onnx模型
- 運(yùn)行onnx模型
- onnx模型輸出與pytorch模型比對
學(xué)習(xí)內(nèi)容
前提條件:需要安裝onnx 和 onnxruntime,可以通過 pip install onnx 和 pip install onnxruntime 進(jìn)行安裝
1 . pytorch 轉(zhuǎn) onnx
pytorch 轉(zhuǎn) onnx 只需要一個(gè)函數(shù) torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
參數(shù)說明:
- model——需要導(dǎo)出的pytorch模型
- args——模型的輸入?yún)?shù),滿足輸入層的shape正確即可。
- path——輸出的onnx模型的位置。例如‘yolov5.onnx’。
- export_params——輸出模型是否可訓(xùn)練。default=True,表示導(dǎo)出trained model,否則untrained。
- verbose——是否打印模型轉(zhuǎn)換信息。default=False。
- input_names——輸入節(jié)點(diǎn)名稱。default=None。
- output_names——輸出節(jié)點(diǎn)名稱。default=None。
- do_constant_folding——是否使用常量折疊(不了解),默認(rèn)即可。default=True。
- dynamic_axes——模型的輸入輸出有時(shí)是可變的,如Rnn,或者輸出圖像的batch可變,可通過該參數(shù)設(shè)置。如輸入層的shape為(b,3,h,w),batch,height,width是可變的,但是chancel是固定三通道。
格式如下 :
1)僅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
2)僅dict<int, string> dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:{0:‘batch’,1:‘c’}}
3)mixed dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:[0,1]} - opset_version——opset的版本,低版本不支持upsample等操作。
import torch
import torch.nn
import onnx
model = torch.load('best.pt')
model.eval()
input_names = ['input']
output_names = ['output']
x = torch.randn(1,3,32,32,requires_grad=True)
torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')2 . 運(yùn)行onnx模型
檢查onnx模型,并使用onnxruntime運(yùn)行。
import onnx
import onnxruntime as ort
model = onnx.load('best.onnx')
onnx.checker.check_model(model)
session = ort.InferenceSession('best.onnx')
x=np.random.randn(1,3,32,32).astype(np.float32) # 注意輸入type一定要np.float32!!!!!
# x= torch.randn(batch_size,chancel,h,w)
outputs = session.run(None,input = { 'input' : x })
參數(shù)說明:
- output_names: default=None
用來指定輸出哪些,以及順序
若為None,則按序輸出所有的output,即返回[output_0,output_1]
若為[‘output_1’,‘output_0’],則返回[output_1,output_0]
若為[‘output_0’],則僅返回[output_0:tensor] - input:dict
可以通過session.get_inputs().name獲得名稱
其中key值要求與torch.onnx.export中設(shè)定的一致
3.onnx模型輸出與pytorch模型比對
import numpy as np np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)
如前所述,經(jīng)驗(yàn)表明,ONNX 模型的運(yùn)行效率明顯優(yōu)于原 PyTorch 模型,這似乎是源于 ONNX 模型生成過程中的優(yōu)化,這也導(dǎo)致了模型的生成過程比較耗時(shí),但整體效率依舊可觀。
此外,根據(jù)對 ONNX 模型和 PyTorch 模型運(yùn)行結(jié)果的統(tǒng)計(jì)分析(誤差的均值和標(biāo)準(zhǔn)差),可以看出 ONNX 模型的運(yùn)行結(jié)果誤差很小、基本可靠。
內(nèi)容參考:https://zhuanlan.zhihu.com/p/422290231
總結(jié)
到此這篇關(guān)于pytorch模型轉(zhuǎn)onnx模型的文章就介紹到這了,更多相關(guān)pytorch模型轉(zhuǎn)onnx模型內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python計(jì)算波峰波谷值的方法(極值點(diǎn))
這篇文章主要介紹了python求極值點(diǎn)(波峰波谷)求極值點(diǎn)主要用到了scipy庫,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-02-02
python list等分并從等分的子集中隨機(jī)選取一個(gè)數(shù)
這篇文章主要介紹了python list等分并從等分的子集中隨機(jī)選取一個(gè)數(shù),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11
Python利用matplotlib實(shí)現(xiàn)繪制密度散點(diǎn)圖
這篇文章主要介紹了如何基于Python語言的matplotlib模塊,對Excel表格文件中的指定數(shù)據(jù)加以密度散點(diǎn)圖繪制的方法,有需要的小伙伴可以參考下2024-04-04
Python 保持登錄狀態(tài)進(jìn)行接口測試的方法示例
這篇文章主要介紹了Python 保持登錄狀態(tài)進(jìn)行接口測試的方法示例,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2019-08-08
使用Python可設(shè)置抽獎(jiǎng)?wù)邫?quán)重的抽獎(jiǎng)腳本代碼
這篇文章主要介紹了Python可設(shè)置抽獎(jiǎng)?wù)邫?quán)重的抽獎(jiǎng)腳本,抽獎(jiǎng)系統(tǒng)包含可給不同抽獎(jiǎng)?wù)咴O(shè)置不同的權(quán)重,先從價(jià)值高的獎(jiǎng)品開始抽,已經(jīng)中獎(jiǎng)的人,不再參與后續(xù)的抽獎(jiǎng),本文通過實(shí)例代碼給大家介紹的非常詳細(xì),需要的朋友可以參考下2022-11-11
python將天數(shù)轉(zhuǎn)換為日期字符串的方法實(shí)例
這篇文章主要給大家介紹了關(guān)于python將天數(shù)轉(zhuǎn)換為日期字符串的相關(guān)資料,以及將將字符串的時(shí)間轉(zhuǎn)換為時(shí)間戳的實(shí)例代碼,需要的朋友可以參考下2022-01-01

