MxNet預(yù)訓(xùn)練模型到Pytorch模型的轉(zhuǎn)換方式
預(yù)訓(xùn)練模型在不同深度學(xué)習(xí)框架中的轉(zhuǎn)換是一種常見的任務(wù)。今天剛好DPN預(yù)訓(xùn)練模型轉(zhuǎn)換問題,順手將這個(gè)過程記錄一下。
核心轉(zhuǎn)換函數(shù)如下所示:
def convert_from_mxnet(model, checkpoint_prefix, debug=False):
_, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
remapped_state = {}
for state_key in model.state_dict().keys():
k = state_key.split('.')
aux = False
mxnet_key = ''
if k[0] == 'features':
if k[1] == 'conv1_1':
# input block
mxnet_key += 'conv1_x_1__'
if k[2] == 'bn':
mxnet_key += 'relu-sp__bn_'
aux, key_add = _convert_bn(k[3])
mxnet_key += key_add
else:
assert k[3] == 'weight'
mxnet_key += 'conv_' + k[3]
elif k[1] == 'conv5_bn_ac':
# bn + ac at end of features block
mxnet_key += 'conv5_x_x__relu-sp__bn_'
assert k[2] == 'bn'
aux, key_add = _convert_bn(k[3])
mxnet_key += key_add
else:
# middle blocks
if model.b and 'c1x1_c' in k[2]:
bc_block = True # b-variant split c-block special treatment
else:
bc_block = False
ck = k[1].split('_')
mxnet_key += ck[0] + '_x__' + ck[1] + '_'
ck = k[2].split('_')
mxnet_key += ck[0] + '-' + ck[1]
if ck[1] == 'w' and len(ck) > 2:
mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
mxnet_key += '__'
if k[3] == 'bn':
mxnet_key += 'bn_' if bc_block else 'bn__bn_'
aux, key_add = _convert_bn(k[4])
mxnet_key += key_add
else:
ki = 3 if bc_block else 4
assert k[ki] == 'weight'
mxnet_key += 'conv_' + k[ki]
elif k[0] == 'classifier':
if 'fc6-1k_weight' in mxnet_weights:
mxnet_key += 'fc6-1k_'
else:
mxnet_key += 'fc6_'
mxnet_key += k[1]
else:
assert False, 'Unexpected token'
if debug:
print(mxnet_key, '=> ', state_key, end=' ')
mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
if k[0] == 'classifier' and k[1] == 'weight':
torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
remapped_state[state_key] = torch_tensor
if debug:
print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
model.load_state_dict(remapped_state)
return model
從中可以看出,其轉(zhuǎn)換步驟如下:
(1)創(chuàng)建pytorch的網(wǎng)絡(luò)結(jié)構(gòu)模型,設(shè)為model
(2)利用mxnet來讀取其存儲(chǔ)的預(yù)訓(xùn)練模型,得到mxnet_weights;
(3)遍歷加載后模型mxnet_weights的state_dict().keys
(4)對(duì)一些指定的key值,需要進(jìn)行相應(yīng)的處理和轉(zhuǎn)換
(5)對(duì)修改鍵名之后的key利用numpy之間的轉(zhuǎn)換來實(shí)現(xiàn)加載。
為了實(shí)現(xiàn)上述轉(zhuǎn)換,首先pip安裝mxnet,現(xiàn)在新版的mxnet安裝還是非常方便的。

第二步,運(yùn)行轉(zhuǎn)換程序,實(shí)現(xiàn)預(yù)訓(xùn)練模型的轉(zhuǎn)換。

可以看到在相當(dāng)?shù)奈募A下已經(jīng)出現(xiàn)了轉(zhuǎn)換后的模型。
以上這篇MxNet預(yù)訓(xùn)練模型到Pytorch模型的轉(zhuǎn)換方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python 數(shù)字轉(zhuǎn)換為日期的三種實(shí)現(xiàn)方法
在Python中,我們經(jīng)常需要處理日期和時(shí)間,本文主要介紹了python 數(shù)字轉(zhuǎn)換為日期的三種實(shí)現(xiàn)方法,包含datetime模塊,strftime方法及pandas庫(kù),具有一定的參考價(jià)值,感興趣的可以了解一下2024-02-02
淺談keras中自定義二分類任務(wù)評(píng)價(jià)指標(biāo)metrics的方法以及代碼
這篇文章主要介紹了淺談keras中自定義二分類任務(wù)評(píng)價(jià)指標(biāo)metrics的方法以及代碼,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-06-06
Python虛擬機(jī)棧幀對(duì)象及獲取源碼學(xué)習(xí)
這篇文章主要為大家介紹了Python虛擬機(jī)棧幀對(duì)象及獲取源碼學(xué)習(xí),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-03-03
Python寫的Discuz7.2版faq.php注入漏洞工具
這篇文章主要介紹了Python寫的Discuz7.2版faq.php注入漏洞工具,全自動(dòng)的一款注入工具,針對(duì)Discuz7.2版,需要的朋友可以參考下2014-08-08
python使用xlsx和pandas處理Excel表格的操作步驟
python的神器pandas庫(kù)就可以非常方便地處理excel,csv,矩陣,表格 等數(shù)據(jù),下面這篇文章主要給大家介紹了關(guān)于python使用xlsx和pandas處理Excel表格的操作步驟,文中通過圖文介紹的非常詳細(xì),需要的朋友可以參考下2023-01-01
在Python中使用循環(huán)進(jìn)行迭代的方法小結(jié)
Python中的循環(huán)結(jié)構(gòu)是編程中的重要組成部分,本文詳細(xì)介紹這兩種循環(huán)的使用方法、它們之間的差異以及如何選擇合適的循環(huán)類型,此外,我還將介紹一些高級(jí)循環(huán)控制技巧,如列表推導(dǎo)式和生成器表達(dá)式,感興趣的朋友一起看看吧2024-01-01
python基于三階貝塞爾曲線的數(shù)據(jù)平滑算法
這篇文章主要介紹了python基于三階貝塞爾曲線的數(shù)據(jù)平滑算法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12

