PyTorch?模型?onnx?文件導(dǎo)出及調(diào)用詳情
前言
Open Neural Network Exchange (ONNX,開放神經(jīng)網(wǎng)絡(luò)交換) 格式,是一個(gè)用于表示深度學(xué)習(xí)模型的標(biāo)準(zhǔn),可使模型在不同框架之間進(jìn)行轉(zhuǎn)移
PyTorch 所定義的模型為動(dòng)態(tài)圖,其前向傳播是由類方法定義和實(shí)現(xiàn)的
但是 Python 代碼的效率是比較底下的,試想把動(dòng)態(tài)圖轉(zhuǎn)化為靜態(tài)圖,模型的推理速度應(yīng)當(dāng)有所提升
PyTorch 框架中,torch.onnx.export 可以將父類為 nn.Module 的模型導(dǎo)出到 onnx 文件中,
最重要的有三個(gè)參數(shù):
- model:父類為 nn.Module 的模型
- args:傳入 model 的 forward 方法的變量列表,類型應(yīng)為
- tuplef:onnx 文件名稱的字符串
import torch from torchvision.models import resnet50 file = 'resnet.onnx' # 聲明模型 resnet = resnet50(pretrained=False).eval() image = torch.rand([1, 3, 224, 224]) # 導(dǎo)出為 onnx 文件 torch.onnx.export(resnet, (image,), file)
onnx 文件可被 Netron 打開,以查看模型結(jié)構(gòu)

基本用法
要在 Python 中運(yùn)行 onnx 模型,需要下載 onnxruntime
# 選其一即可 pip install onnxruntime # CPU 版本 pip install onnxruntime-gpu # GPU 版本
推理時(shí)需要借助其中的 InferenceSession,其中較為重要的實(shí)例方法有:
- get_inputs():得到輸入變量的列表 (變量屬性:name、shape、type)
- get_outputs():得到輸入變量的列表 (變量屬性:name、shape、type)run(output_names, input_feed):輸入變量為 numpy.ndarray (注意 dtype 應(yīng)為 float32),使用模型推理并返回輸出
可得出 onnx 模型的基本用法:
import onnxruntime as ort
import numpy as np
file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
print('設(shè)備:', provider)
# 聲明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])
# 參考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
for node in node_list:
attr = {'name': node.name,
'shape': node.shape,
'type': node.type}
print(attr)
print('-' * 60)
# 得到輸入、輸出結(jié)點(diǎn)的名稱
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
input_feed={input_node_name: image}))高級(jí) API
為了簡(jiǎn)化使用步驟,使用類進(jìn)行封裝:
class Onnx_Module(ort.InferenceSession):
''' onnx 推理模型
provider: 優(yōu)先使用 GPU'''
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
def __init__(self, file):
super(Onnx_Module, self).__init__(file, providers=[self.provider])
# 參考: ort.NodeArg
self.inputs = [node_arg.name for node_arg in self.get_inputs()]
self.outputs = [node_arg.name for node_arg in self.get_outputs()]
def __call__(self, *arrays):
input_feed = {name: x for name, x in zip(self.inputs, arrays)}
return self.run(self.outputs, input_feed)在 PyTorch 中,對(duì)于卷積神經(jīng)網(wǎng)絡(luò) model 與圖像 image,推理的代碼為 "model(image)",而使用這個(gè)封裝的類也是類似:
import numpy as np file = 'resnet.onnx' model = Onnx_Module(file) image = np.random.random([1, 3, 224, 224]).astype(np.float32) print(model(image))
為了方便觀察 Torch 模型與 onnx 模型的速度差異,同時(shí)檢查兩個(gè)模型的輸出是否一致,又編寫了 test 函數(shù)
test 方法的參數(shù)與 torch.onnx.export 一致,其基本流程為:
- 得到 Torch 模型的輸出,并 print 推斷耗時(shí)
- 將 Torch 模型導(dǎo)出為 onnx 文件,將輸入變量中的 torch.tensor 轉(zhuǎn)化為 numpy.ndarray
- 初始化 onnx 模型,得到 onnx 模型的輸出,并 print 推斷耗時(shí)
- 計(jì)算 Torch 模型與 onnx 模型輸出的絕對(duì)誤差的均值
- 將 onnx 模型 return
class Timer:
repeat = 3
def __new__(cls, fun, *args, **kwargs):
import time
start = time.time()
for _ in range(cls.repeat): fun(*args, **kwargs)
cost = (time.time() - start) / cls.repeat
return cost * 1e3 # ms
class Onnx_Module(ort.InferenceSession):
''' onnx 推理模型
provider: 優(yōu)先使用 GPU'''
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
def __init__(self, file):
super(Onnx_Module, self).__init__(file, providers=[self.provider])
# 參考: ort.NodeArg
self.inputs = [node_arg.name for node_arg in self.get_inputs()]
self.outputs = [node_arg.name for node_arg in self.get_outputs()]
def __call__(self, *arrays):
input_feed = {name: x for name, x in zip(self.inputs, arrays)}
return self.run(self.outputs, input_feed)
@classmethod
def test(cls, model, args, file, **export_kwargs):
# 測(cè)試 Torch 的運(yùn)行時(shí)間
torch_output = model(*args).data.numpy()
print(f'Torch: {Timer(model, *args):.2f} ms')
# model: Torch -> onnx
torch.onnx.export(model, args, file, **export_kwargs)
# data: tensor -> array
args = tuple(map(lambda tensor: tensor.data.numpy(), args))
onnx_model = cls(file)
# 測(cè)試 onnx 的運(yùn)行時(shí)間
onnx_output = onnx_model(*args)
print(f'Onnx: {Timer(onnx_model, *args):.2f} ms')
# 計(jì)算 Torch 模型與 onnx 模型輸出的絕對(duì)誤差
abs_error = np.abs(torch_output - onnx_output).mean()
print(f'Mean Error: {abs_error:.2f}')
return onnx_model對(duì)于 ResNet50 而言,Torch 模型的推斷耗時(shí)為 172.67 ms,onnx 模型的推斷耗時(shí)為 36.56 ms,onnx 模型的推斷耗時(shí)僅為 Torch 模型的 21.17%
到此這篇關(guān)于PyTorch 模型 onnx 文件導(dǎo)出及調(diào)用詳情的文章就介紹到這了,更多相關(guān)PyTorch文件導(dǎo)出內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python爬蟲實(shí)現(xiàn)HTTP網(wǎng)絡(luò)請(qǐng)求多種實(shí)現(xiàn)方式
這篇文章主要介紹了Python爬蟲實(shí)現(xiàn)HTTP網(wǎng)絡(luò)請(qǐng)求多種實(shí)現(xiàn)方式,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06
python 列表,數(shù)組,矩陣兩兩轉(zhuǎn)換tolist()的實(shí)例
下面小編就為大家分享一篇python 列表,數(shù)組,矩陣兩兩轉(zhuǎn)換tolist()的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-04-04
Python的Tkinter點(diǎn)擊按鈕觸發(fā)事件的例子
今天小編就為大家分享一篇Python的Tkinter點(diǎn)擊按鈕觸發(fā)事件的例子,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-07-07
PYQT5 實(shí)現(xiàn)給listwidget的滾動(dòng)條添加滾動(dòng)信號(hào)
這篇文章主要介紹了PYQT5 實(shí)現(xiàn)給listwidget的滾動(dòng)條添加滾動(dòng)信號(hào),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-03-03
Python通過Django實(shí)現(xiàn)用戶注冊(cè)和郵箱驗(yàn)證功能代碼
這篇文章主要介紹了Python通過Django實(shí)現(xiàn)用戶注冊(cè)和郵箱驗(yàn)證功能代碼,具有一定借鑒價(jià)值,需要的朋友可以參考下。2017-12-12
Python利用pyHook實(shí)現(xiàn)監(jiān)聽用戶鼠標(biāo)與鍵盤事件
這篇文章主要介紹了Python利用pyHook實(shí)現(xiàn)監(jiān)聽用戶鼠標(biāo)與鍵盤事件,很有實(shí)用價(jià)值的一個(gè)技巧,需要的朋友可以參考下2014-08-08
Python實(shí)現(xiàn)接口自動(dòng)化測(cè)試的方法詳解
Python接口自動(dòng)化測(cè)試是一種高效、可重復(fù)的軟件質(zhì)量驗(yàn)證方法,尤其在現(xiàn)代軟件開發(fā)中,它已經(jīng)成為不可或缺的一部分,本文將深入探討如何使用Python進(jìn)行接口自動(dòng)化測(cè)試,文中通過代碼示例介紹的非常詳細(xì),需要的朋友可以參考下2024-08-08

