pytorch 中的重要模塊化接口nn.Module的使用
torch.nn 是專門為神經(jīng)網(wǎng)絡(luò)設(shè)計的模塊化接口,nn構(gòu)建于autgrad之上,可以用來定義和運行神經(jīng)網(wǎng)絡(luò)
nn.Module 是nn中重要的類,包含網(wǎng)絡(luò)各層的定義,以及forward方法
查看源碼
初始化部分:
def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._modules = OrderedDict() self.training = True
屬性解釋:
- _parameters:字典,保存用戶直接設(shè)置的 Parameter
- _modules:子 module,即子類構(gòu)造函數(shù)中的內(nèi)容
- _buffers:緩存
- _backward_hooks與_forward_hooks:鉤子技術(shù),用來提取中間變量
- training:判斷值來決定前向傳播策略
方法定義:
def forward(self, *input): raise NotImplementedError
沒有實際內(nèi)容,用于被子類的 forward() 方法覆蓋
且 forward 方法在 __call__ 方法中被調(diào)用:
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
...
...
對于自己定義的網(wǎng)絡(luò),需要注意以下幾點:
1)需要繼承nn.Module類,并實現(xiàn)forward方法,只要在nn.Module的子類中定義forward方法,backward函數(shù)就會被自動實現(xiàn)(利用autograd機(jī)制)
2)一般把網(wǎng)絡(luò)中可學(xué)習(xí)參數(shù)的層放在構(gòu)造函數(shù)中__init__(),沒有可學(xué)習(xí)參數(shù)的層如Relu層可以放在構(gòu)造函數(shù)中,也可以不放在構(gòu)造函數(shù)中(在forward函數(shù)中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函數(shù),在整個pytorch構(gòu)建的圖中,是Variable在流動,也可以使用for,print,log等
4)基于nn.Module構(gòu)建的模型中,只支持mini-batch的Variable的輸入方式,如,N*C*H*W
代碼示例:
class LeNet(nn.Module):
def __init__(self):
# nn.Module的子類函數(shù)必須在構(gòu)造函數(shù)中執(zhí)行父類的構(gòu)造函數(shù)
super(LeNet, self).__init__() # 等價與nn.Module.__init__()
# nn.Conv2d返回的是一個Conv2d class的一個對象,該類中包含forward函數(shù)的實現(xiàn)
# 當(dāng)調(diào)用self.conv1(input)的時候,就會調(diào)用該類的forward函數(shù)
self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
self.conv2 = nn.Conv2d(6, 16, (5, 5))
self.fc1 = nn.Linear(256, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# F.max_pool2d的返回值是一個Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# input:(10, 6, 12, 12) output:(10,6,4,4)
x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
# 固定樣本個數(shù),將其他維度的數(shù)據(jù)平鋪,無論你是幾通道,最終都會變成參數(shù), output:(10, 256)
x = x.view(x.size()[0], -1)
# 全連接
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
# 返回值也是一個Variable對象
return x
def output_name_and_params(net):
for name, parameters in net.named_parameters():
print('name: {}, param: {}'.format(name, parameters))
if __name__ == '__main__':
net = LeNet()
print('net: {}'.format(net))
params = net.parameters() # generator object
print('params: {}'.format(params))
output_name_and_params(net)
input_image = torch.FloatTensor(10, 1, 28, 28)
# 和tensorflow不一樣,pytorch中模型的輸入是一個Variable,而且是Variable在圖中流動,不是Tensor。
# 這可以從forward中每一步的執(zhí)行結(jié)果可以看出
input_image = Variable(input_image)
output = net(input_image)
print('output: {}'.format(output))
print('output.size: {}'.format(output.size()))
到此這篇關(guān)于pytorch 中的重要模塊化接口nn.Module的使用的文章就介紹到這了,更多相關(guān)pytorch nn.Module內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python3實現(xiàn)爬取簡書首頁文章標(biāo)題和文章鏈接的方法【測試可用】
這篇文章主要介紹了Python3實現(xiàn)爬取簡書首頁文章標(biāo)題和文章鏈接的方法,結(jié)合實例形式分析了Python3基于urllib及bs4庫針對簡書網(wǎng)進(jìn)行文章抓取相關(guān)操作技巧,需要的朋友可以參考下2018-12-12
使用wxPython和pandas模塊生成Excel文件的代碼實現(xiàn)
在Python編程中,有時我們需要根據(jù)特定的數(shù)據(jù)生成Excel文件,本文將介紹如何使用wxPython和pandas模塊來實現(xiàn)這個目標(biāo),文中通過代碼示例給大家講解的非常詳細(xì),具有一定的參考價值,需要的朋友可以參考下2024-05-05
淺談Python中的函數(shù)(def)及參數(shù)傳遞操作
這篇文章主要介紹了淺談Python中的函數(shù)(def)及參數(shù)傳遞操作,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2021-05-05
Python實現(xiàn)實時增量數(shù)據(jù)加載工具的解決方案
這篇文章主要分享結(jié)合單例模式實際應(yīng)用案例:實現(xiàn)實時增量數(shù)據(jù)加載工具的解決方案。最關(guān)鍵的是實現(xiàn)一個可進(jìn)行添加、修改、刪除等操作的增量ID記錄表。需要的可以參考一下2022-02-02

