pytorch hook 鉤子函數(shù)的用法
鉤子編程(hooking),也稱作“掛鉤”,是計(jì)算機(jī)程序設(shè)計(jì)術(shù)語(yǔ),指通過(guò)攔截軟件模塊間的函數(shù)調(diào)用、消息傳遞、事件傳遞來(lái)修改或擴(kuò)展操作系統(tǒng)、應(yīng)用程序或其他軟件組件的行為的各種技術(shù)。處理被攔截的函數(shù)調(diào)用、事件、消息的代碼,被稱為鉤子(hook)。
Hook 是 PyTorch 中一個(gè)十分有用的特性。利用它,我們可以不必改變網(wǎng)絡(luò)輸入輸出的結(jié)構(gòu),方便地獲取、改變網(wǎng)絡(luò)中間層變量的值和梯度。這個(gè)功能被廣泛用于可視化神經(jīng)網(wǎng)絡(luò)中間層的 feature、gradient,從而診斷神經(jīng)網(wǎng)絡(luò)中可能出現(xiàn)的問(wèn)題,分析網(wǎng)絡(luò)有效性。
本文主要用 hook 函數(shù)輸出網(wǎng)絡(luò)執(zhí)行過(guò)程中 forward 和 backward 的執(zhí)行順序,以此找到了bug所在。
用法如下:
# 設(shè)置hook func
def hook_func(name, module):
? ? def hook_function(module, inputs, outputs):
? ? ? ? # 請(qǐng)依據(jù)使用場(chǎng)景自定義函數(shù)
? ? ? ? print(name+' inputs', inputs)
? ? ? ? print(name+' outputs', outputs)
? ? return hook_function
# 注冊(cè)正反向hook
for name, module in model.named_modules():
? ? module.register_forward_hook(hook_func('[forward]: '+name, module))
? ? module.register_backward_hook(hook_func('[backward]: '+name, module))如一個(gè)簡(jiǎn)單的 MNIST 手寫(xiě)數(shù)字識(shí)別的模型結(jié)構(gòu)如下:
class Net(nn.Module): ? ? def __init__(self): ? ? ? ? super(Net, self).__init__() ? ? ? ? self.conv1 = nn.Conv2d(1, 32, 3, 1) ? ? ? ? self.conv2 = nn.Conv2d(32, 64, 3, 1) ? ? ? ? self.dropout1 = nn.Dropout(0.25) ? ? ? ? self.dropout2 = nn.Dropout(0.5) ? ? ? ? self.fc1 = nn.Linear(9216, 128) ? ? ? ? self.fc2 = nn.Linear(128, 10) ? ? def forward(self, x): ? ? ? ? x = self.conv1(x) ? ? ? ? x = F.relu(x) ? ? ? ? x = self.conv2(x) ? ? ? ? x = F.relu(x) ? ? ? ? x = F.max_pool2d(x, 2) ? ? ? ? x = self.dropout1(x) ? ? ? ? x = torch.flatten(x, 1) ? ? ? ? x = self.fc1(x) ? ? ? ? x = F.relu(x) ? ? ? ? x = self.dropout2(x) ? ? ? ? x = self.fc2(x) ? ? ? ? output = F.log_softmax(x, dim=1) ? ? ? ? return output
打印模型:
Net( ? (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1)) ? (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) ? (dropout1): Dropout(p=0.25, inplace=False) ? (dropout2): Dropout(p=0.5, inplace=False) ? (fc1): Linear(in_features=9216, out_features=128, bias=True) ? (fc2): Linear(in_features=128, out_features=10, bias=True) )
構(gòu)建hook函數(shù):
# 設(shè)置hook func
def hook_func(name, module):
? ? def hook_function(module, inputs, outputs):
? ? ? ? with open("log_model.txt", 'a+') as f:
? ? ? ? ? ? # 請(qǐng)依據(jù)使用場(chǎng)景自定義函數(shù)
? ? ? ? ? ? f.write(name + ' ? len(inputs): ' + str(len(inputs)) + '\n')
? ? ? ? ? ? f.write(name + ' ? len(outputs): ?' + str(len(outputs)) + '\n')
? ? return hook_function
# 注冊(cè)正反向hook
for name, module in model.named_modules():
? ? module.register_forward_hook(hook_func('[forward]: '+name, module))
? ? module.register_backward_hook(hook_func('[backward]: '+name, module))輸出的前向和反向傳播過(guò)程:
[forward]: conv1 len(inputs): 1
[forward]: conv1 len(outputs): 8
[forward]: conv2 len(inputs): 1
[forward]: conv2 len(outputs): 8
[forward]: dropout1 len(inputs): 1
[forward]: dropout1 len(outputs): 8
[forward]: fc1 len(inputs): 1
[forward]: fc1 len(outputs): 8
[forward]: dropout2 len(inputs): 1
[forward]: dropout2 len(outputs): 8
[forward]: fc2 len(inputs): 1
[forward]: fc2 len(outputs): 8
[forward]: len(inputs): 1
[forward]: len(outputs): 8
[backward]: len(inputs): 2
[backward]: len(outputs): 1
[backward]: fc2 len(inputs): 3
[backward]: fc2 len(outputs): 1
[backward]: dropout2 len(inputs): 1
[backward]: dropout2 len(outputs): 1
[backward]: fc1 len(inputs): 3
[backward]: fc1 len(outputs): 1
[backward]: dropout1 len(inputs): 1
[backward]: dropout1 len(outputs): 1
[backward]: conv2 len(inputs): 2
[backward]: conv2 len(outputs): 1
[backward]: conv1 len(inputs): 2
[backward]: conv1 len(outputs): 1
因?yàn)橹灰P吞幱趖rain狀態(tài),hook_func 就會(huì)執(zhí)行,導(dǎo)致不斷輸出 [forward] 和 [backward],所以將輸出內(nèi)容建議寫(xiě)到文件中,而不是 print
到此這篇關(guān)于pytorch hook 鉤子函數(shù)的用法的文章就介紹到這了,更多相關(guān)pytorch hook 鉤子函數(shù)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python for循環(huán)如何實(shí)現(xiàn)控制步長(zhǎng)
這篇文章主要介紹了python for循環(huán)如何實(shí)現(xiàn)控制步長(zhǎng),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-05-05
Python多線程經(jīng)典問(wèn)題之乘客做公交車(chē)算法實(shí)例
這篇文章主要介紹了Python多線程經(jīng)典問(wèn)題之乘客做公交車(chē)算法,簡(jiǎn)單描述了乘客坐公交車(chē)問(wèn)題并結(jié)合實(shí)例形式分析了Python多線程實(shí)現(xiàn)乘客坐公交車(chē)算法的相關(guān)技巧,需要的朋友可以參考下2017-03-03
Python制作簡(jiǎn)易注冊(cè)登錄系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了Python簡(jiǎn)易注冊(cè)登錄系統(tǒng)的制作方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2016-12-12
Python控制臺(tái)輸出時(shí)刷新當(dāng)前行內(nèi)容而不是輸出新行的實(shí)現(xiàn)
今天小編就為大家分享一篇Python控制臺(tái)輸出時(shí)刷新當(dāng)前行內(nèi)容而不是輸出新行的實(shí)現(xiàn),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02

