Pytorch實現(xiàn)簡單自定義網(wǎng)絡(luò)層的方法
前言
Pytorch、Tensoflow等許多深度學(xué)習(xí)框架集成了大量常見的網(wǎng)絡(luò)層,為我們搭建神經(jīng)網(wǎng)絡(luò)提供了諸多便利。但在實際工作中,因為項目要求、研究需要或者發(fā)論文需要等等,大家一般都會需要自己發(fā)明一個現(xiàn)在在深度學(xué)習(xí)框架中還不存在的層。 在這些情況下,就必須構(gòu)建自定義層。
博主在學(xué)習(xí)了沐神的動手學(xué)深度學(xué)習(xí)這本書之后,學(xué)到了許多東西。這里記錄一下書中基于Pytorch實現(xiàn)簡單自定義網(wǎng)絡(luò)層的方法,僅供參考。
一、不帶參數(shù)的層
首先,我們構(gòu)造一個沒有任何參數(shù)的自定義層,要構(gòu)建它,只需繼承基礎(chǔ)層類并實現(xiàn)前向傳播功能。
import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X):
return X - X.mean()
輸入一些數(shù)據(jù),驗證一下網(wǎng)絡(luò)是否能正常工作:
layer = CenteredLayer() print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))
輸出結(jié)果如下:
tensor([-2., -1., 0., 1., 2.])
運行正常,表明網(wǎng)絡(luò)沒有問題。
現(xiàn)在將我們自建的網(wǎng)絡(luò)層作為組件合并到更復(fù)雜的模型中,并輸入數(shù)據(jù)進行驗證:
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer()) Y = net(torch.rand(4, 8)) print(Y.mean()) # 因為模型參數(shù)較多,輸出也較多,所以這里輸出Y的均值,驗證模型可運行即可
結(jié)果如下:
tensor(-5.5879e-09, grad_fn=<MeanBackward0>)
二、帶參數(shù)的層
這里使用內(nèi)置函數(shù)來創(chuàng)建參數(shù),這些函數(shù)可以提供一些基本的管理功能,使用更加方便。
這里實現(xiàn)了一個簡單的自定義的全連接層,大家可根據(jù)需要自行修改即可。
class MyLinear(nn.Module):
def __init__(self, in_units, units):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_units, units))
self.bias = nn.Parameter(torch.randn(units,))
def forward(self, X):
linear = torch.matmul(X, self.weight.data) + self.bias.data
return F.relu(linear)
接下來實例化類并訪問其模型參數(shù):
linear = MyLinear(5, 3) print(linear.weight)
結(jié)果如下:
Parameter containing:
tensor([[-0.3708, 1.2196, 1.3658],
[ 0.4914, -0.2487, -0.9602],
[ 1.8458, 0.3016, -0.3956],
[ 0.0616, -0.3942, 1.6172],
[ 0.7839, 0.6693, -0.8890]], requires_grad=True)
而后輸入一些數(shù)據(jù),查看模型輸出結(jié)果:
print(linear(torch.rand(2, 5)))
# 結(jié)果如下
tensor([[1.2394, 0.0000, 0.0000],
[1.3514, 0.0968, 0.6667]])
我們還可以使用自定義層構(gòu)建模型,使用方法與使用內(nèi)置的全連接層相同。
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 結(jié)果如下
tensor([[4.1416],
[0.2567]])
三、總結(jié)
我們可以通過基本層類設(shè)計自定義層。這允許我們定義靈活的新層,其行為與深度學(xué)習(xí)框架中的任何現(xiàn)有層不同。
在自定義層定義完成后,我們就可以在任意環(huán)境和網(wǎng)絡(luò)架構(gòu)中調(diào)用該自定義層。
層可以有局部參數(shù),這些參數(shù)可以通過內(nèi)置函數(shù)創(chuàng)建。
四、參考
《動手學(xué)深度學(xué)習(xí)》 — 動手學(xué)深度學(xué)習(xí) 2.0.0-beta0 documentation
附:pytorch獲取網(wǎng)絡(luò)的層數(shù)和每層的名字
#創(chuàng)建自己的網(wǎng)絡(luò) import models model = models.__dict__["resnet50"](pretrained=True) for index ,(name, param) in enumerate(model.named_parameters()): ? ? print( str(index) + " " +name)
結(jié)果如下:
0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight
到此這篇關(guān)于Pytorch實現(xiàn)簡單自定義網(wǎng)絡(luò)層的文章就介紹到這了,更多相關(guān)Pytorch自定義網(wǎng)絡(luò)層內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python實現(xiàn)的DES加密算法和3DES加密算法實例
這篇文章主要介紹了python實現(xiàn)的DES加密算法和3DES加密算法,以實例形式較為詳細(xì)的分析了DES加密算法和3DES加密算法的原理與實現(xiàn)技巧,需要的朋友可以參考下2015-06-06
python requests 庫請求帶有文件參數(shù)的接口實例
今天小編就為大家分享一篇python requests 庫請求帶有文件參數(shù)的接口實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01
Python將Excel轉(zhuǎn)換為多種圖片格式的方法(PNG, JPG, BMP, SVG)
有時,你可能希望以圖片形式分享Excel數(shù)據(jù),以防止他人對數(shù)據(jù)進行修改或編輯,將Excel轉(zhuǎn)換為圖片可以將數(shù)據(jù)鎖定為靜態(tài)圖片,確保數(shù)據(jù)的完整性和準(zhǔn)確性,這篇文章將探討如何使用Python實現(xiàn)將Excel工作表轉(zhuǎn)換為多種圖片格式,如PNG,JPG,BMP和SVG,需要的朋友可以參考下2025-03-03
Python使用Dijkstra算法實現(xiàn)求解圖中最短路徑距離問題詳解
這篇文章主要介紹了Python使用Dijkstra算法實現(xiàn)求解圖中最短路徑距離問題,簡單描述了Dijkstra算法的原理并結(jié)合具體實例形式分析了Python使用Dijkstra算法實現(xiàn)求解圖中最短路徑距離的相關(guān)步驟與操作技巧,需要的朋友可以參考下2018-05-05
Python實現(xiàn)樹莓派WiFi斷線自動重連的實例代碼
實現(xiàn) WiFi 斷線自動重連,原理是用 Python 監(jiān)測網(wǎng)絡(luò)是否斷線,如果斷線則重啟網(wǎng)絡(luò)服務(wù)。接下來給大家分享實現(xiàn)代碼,需要的朋友參考下2017-03-03

