PyTorch中nn.Module示例詳解
直接print(dir(nn.Module)),得到如下內(nèi)容:

一、模型結(jié)構(gòu)與參數(shù)
parameters()- 用途:返回模塊的所有可訓(xùn)練參數(shù)(如權(quán)重、偏置)。
- 示例:
for param in model.parameters(): print(param.shape)
named_parameters()- 用途:返回帶名稱的參數(shù)迭代器,便于調(diào)試和訪問(wèn)特定參數(shù)。
- 示例:
for name, param in model.named_parameters(): if 'weight' in name: print(name, param.shape)
children()- 用途:返回直接子模塊的迭代器。
- 示例:
for child in model.children(): print(type(child))
modules()- 用途:遞歸返回所有子模塊(包括自身)。
- 示例:
for module in model.modules(): if isinstance(module, nn.Conv2d): print(module.kernel_size)
二、模型狀態(tài)與模式
train()和eval()- 用途:切換訓(xùn)練/推理模式(影響Dropout、BatchNorm等層)。
- 示例:
model.train() # 訓(xùn)練模式 model.eval() # 推理模式
training- 用途:布爾屬性,指示當(dāng)前模式(
True為訓(xùn)練,False為推理)。 - 示例:
print(model.training) # 輸出:True/False
- 用途:布爾屬性,指示當(dāng)前模式(
三、模型保存與加載
state_dict()- 用途:返回包含模型所有參數(shù)的字典(
OrderedDict)。 - 示例:
torch.save(model.state_dict(), 'model.pth')
- 用途:返回包含模型所有參數(shù)的字典(
load_state_dict()- 用途:從字典加載模型參數(shù)。
- 示例:
model.load_state_dict(torch.load('model.pth'))
四、設(shè)備與數(shù)據(jù)類型
to()- 用途:將模型移動(dòng)到指定設(shè)備(如GPU)或轉(zhuǎn)換數(shù)據(jù)類型。
- 示例:
model.to('cuda') # 移動(dòng)到GPU model.to(torch.float16) # 轉(zhuǎn)換為半精度
cpu()和cuda()- 用途:快捷方法,分別將模型移動(dòng)到CPU或GPU。
- 示例:
model.cuda() # 等價(jià)于 model.to('cuda')
五、前向傳播與計(jì)算
forward()- 用途:定義模型的前向傳播邏輯(需在自定義模塊中重寫)。
- 示例:
class MyModel(nn.Module): def forward(self, x): return self.layer(x)
__call__()- 用途:調(diào)用模型實(shí)例時(shí)觸發(fā)(內(nèi)部調(diào)用
forward(),支持鉤子函數(shù))。 - 示例:
output = model(x) # 等價(jià)于 output = model.forward(x)
- 用途:調(diào)用模型實(shí)例時(shí)觸發(fā)(內(nèi)部調(diào)用
六、參數(shù)初始化與優(yōu)化
zero_grad()- 用途:清空所有參數(shù)的梯度(通常在每個(gè)訓(xùn)練步驟前調(diào)用)。
- 示例:
optimizer.zero_grad() # 等價(jià)于 model.zero_grad()
requires_grad_()- 用途:設(shè)置參數(shù)是否需要梯度(用于凍結(jié)部分模型)。
- 示例:
for param in model.parameters(): param.requires_grad = False # 凍結(jié)所有參數(shù)
七、調(diào)試與信息
extra_repr()- 用途:自定義模塊打印信息(需在子類中重寫)。
- 示例:
class MyModel(nn.Module): def extra_repr(self): return f"hidden_size={self.hidden_size}"
dump_patches()- 用途:打印模型的補(bǔ)丁信息(用于調(diào)試版本差異)。
八、其他實(shí)用方法
apply()- 用途:遞歸應(yīng)用函數(shù)到所有子模塊(如初始化權(quán)重)。
- 示例:
def init_weights(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) model.apply(init_weights)
register_forward_hook()- 用途:注冊(cè)前向傳播鉤子(用于捕獲中間輸出,調(diào)試或特征提?。?。
總結(jié)
日常使用中,最頻繁的方法包括:
- 模型構(gòu)建:
parameters(),children(),modules() - 訓(xùn)練與推理:
train(),eval(),zero_grad(),forward() - 保存與加載:
state_dict(),load_state_dict() - 設(shè)備管理:
to(),cuda(),cpu()
其他方法根據(jù)具體需求選擇使用,例如鉤子函數(shù)用于高級(jí)調(diào)試,apply() 用于統(tǒng)一初始化。
與nn.Sequential對(duì)比:
1. 繼承關(guān)系與基礎(chǔ)屬性
nn.Module- 是所有神經(jīng)網(wǎng)絡(luò)模塊的基類,提供最基礎(chǔ)的功能(如參數(shù)管理、鉤子機(jī)制)。
- 包含核心屬性:
_parameters,_modules,_buffers等。
nn.Sequential- 是
nn.Module的子類,繼承了所有基礎(chǔ)功能。 - 額外添加了與順序執(zhí)行相關(guān)的屬性(如
__getitem__、append)。
- 是
2. 核心差異對(duì)比
| 功能類別 | nn.Module | nn.Sequential |
|---|---|---|
| 模塊構(gòu)建 | 需要手動(dòng)實(shí)現(xiàn) forward 方法 | 自動(dòng)按順序執(zhí)行子模塊,無(wú)需定義 forward |
| 子模塊訪問(wèn) | 通過(guò)屬性名(如 self.conv1) | 通過(guò)索引或命名(如 model[0]) |
| 動(dòng)態(tài)修改 | 需手動(dòng)管理子模塊 | 支持 append、extend、insert 等操作 |
| 適用場(chǎng)景 | 復(fù)雜網(wǎng)絡(luò)結(jié)構(gòu)(如ResNet、U-Net) | 簡(jiǎn)單順序結(jié)構(gòu)(如LeNet卷積部分) |
3. 具體方法對(duì)比
3.1 公共方法(兩者都有)
# 模型參數(shù)與結(jié)構(gòu) ['parameters', 'named_parameters', 'children', 'modules', 'named_children', 'named_modules'] # 模型狀態(tài) ['train', 'eval', 'training', 'zero_grad', 'requires_grad_'] # 設(shè)備與數(shù)據(jù)類型 ['to', 'cpu', 'cuda', 'float', 'double', 'half', 'bfloat16'] # 保存與加載 ['state_dict', 'load_state_dict'] # 鉤子機(jī)制 ['register_forward_hook', 'register_backward_hook']
3.2nn.Sequential特有的方法
# 列表操作(動(dòng)態(tài)修改模塊順序) ['__getitem__', '__setitem__', '__delitem__', '__len__', 'append', 'extend', 'insert', 'pop'] # 索引相關(guān) ['_get_item_by_idx']
3.3nn.Module特有的方法
# 自定義實(shí)現(xiàn) ['forward', 'extra_repr'] # 高級(jí)管理 ['add_module', 'register_module', 'register_parameter', 'register_buffer']
4. 示例對(duì)比
4.1 創(chuàng)建模型
# nn.Module(需自定義 forward)
class CustomModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
# nn.Sequential(自動(dòng)按順序執(zhí)行)
seq_model = nn.Sequential(
nn.Conv2d(3, 64, 3),
nn.ReLU()
)4.2 訪問(wèn)子模塊
# nn.Module custom_model.conv # 通過(guò)屬性名訪問(wèn) # nn.Sequential seq_model[0] # 通過(guò)索引訪問(wèn) seq_model.append(nn.MaxPool2d(2)) # 動(dòng)態(tài)添加模塊
5. 總結(jié)
| 特性 | nn.Module | nn.Sequential |
|---|---|---|
| 靈活性 | 高(自定義任意邏輯) | 低(僅支持順序執(zhí)行) |
| 代碼復(fù)雜度 | 較高(需手動(dòng)實(shí)現(xiàn) forward) | 低(自動(dòng)處理前向傳播) |
| 動(dòng)態(tài)修改 | 不支持直接操作(需手動(dòng)管理) | 支持 append、insert 等操作 |
| 適用場(chǎng)景 | 復(fù)雜網(wǎng)絡(luò)、分支結(jié)構(gòu)、自定義操作 | 簡(jiǎn)單堆疊模塊(如CNN的卷積部分) |
建議:
- 對(duì)于簡(jiǎn)單的順序網(wǎng)絡(luò),優(yōu)先使用
nn.Sequential以減少代碼量。 - 對(duì)于包含復(fù)雜邏輯(如殘差連接、多輸入輸出)的網(wǎng)絡(luò),使用
nn.Module自定義實(shí)現(xiàn)。
到此這篇關(guān)于PyTorch中nn.Module詳解的文章就介紹到這了,更多相關(guān)PyTorch nn.Module內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
django將網(wǎng)絡(luò)中的圖片,保存成model中的ImageField的實(shí)例
今天小編就為大家分享一篇django將網(wǎng)絡(luò)中的圖片,保存成model中的ImageField的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08
python實(shí)現(xiàn)通過(guò)flask和前端進(jìn)行數(shù)據(jù)收發(fā)
今天小編就為大家分享一篇python實(shí)現(xiàn)通過(guò)flask和前端進(jìn)行數(shù)據(jù)收發(fā),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08
使用python telnetlib批量備份交換機(jī)配置的方法
今天小編就為大家分享一篇使用python telnetlib批量備份交換機(jī)配置的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-07-07
DjangoUeditor圖片不顯示img的src沒(méi)有域名問(wèn)題
在使用DjangoUeditor過(guò)程中,可能遇到圖片上傳后不顯示問(wèn)題,解決辦法是修改源碼view.py,加入代碼使得保存的圖片URL帶有協(xié)議和域名,具體做法是在保存圖片代碼中添加request.scheme獲取協(xié)議,request.META['HTTP_HOST']獲取域名2024-09-09
python 裝飾器帶參數(shù)和不帶參數(shù)步驟詳解
裝飾器是Python語(yǔ)言中一種特殊的語(yǔ)法,用于在不修改原函數(shù)代碼的情況下,為函數(shù)添加額外的功能或修改函數(shù)的行為,這篇文章主要介紹了python裝飾器帶參數(shù)和不帶參數(shù)的相關(guān)知識(shí),需要的朋友可以參考下2024-05-05
python 讀取攝像頭數(shù)據(jù)并保存的實(shí)例
今天小編就為大家分享一篇python 讀取攝像頭數(shù)據(jù)并保存的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-08-08
使用Python3 poplib模塊刪除服務(wù)器多天前的郵件實(shí)現(xiàn)代碼
這篇文章主要介紹了使用Python3 poplib模塊刪除多天前的郵件的實(shí)現(xiàn)代碼,代碼簡(jiǎn)單易懂,非常不錯(cuò),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-04-04

