在pytorch中如何查看模型model參數(shù)parameters
pytorch查看模型model參數(shù)parameters
示例1:pytorch自帶的faster r-cnn模型
import torch
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
示例2:自定義網(wǎng)絡(luò)模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.features = self._vgg_layers(cfg)
def _vgg_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)
]
in_channels = x
return nn.Sequential(*layers)
def forward(self, data):
out_map = self.features(data)
return out_map
Model = Net()
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
在自定義網(wǎng)絡(luò)中,model.parameters()方法繼承自nn.Module
pytorch查看模型參數(shù)總結(jié)
1:DNN_printer
其中(3, 32, 32)是輸入的大小,其他方法中的參數(shù)同理
from DNN_printer import DNN_printer
batch_size = 512
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
// put the code here and you can get the result
DNN_printer(net, (3, 32, 32),batch_size)結(jié)果

2:parameters
def cnn_paras_count(net):
"""cnn參數(shù)量統(tǒng)計(jì), 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)直接輸出參數(shù)量,然后自己計(jì)算
需要注意的是,一般模型中參數(shù)是以float32保存的,也就是一個參數(shù)由4個bytes表示,那么就可以將參數(shù)量轉(zhuǎn)化為存儲大小。
例如:
- 44426個參數(shù)*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info from torchvision import models net = models.mobilenet_v2() ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True, verbose=True)

4:torchstat
from torchstat import stat import torchvision.models as models model = models.resnet152() stat(model, (3, 224, 224))
輸出

以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
教你利用Python玩轉(zhuǎn)histogram直方圖的五種方法
這篇文章主要給大家介紹了關(guān)于如何利用Python玩轉(zhuǎn)histogram直方圖的五種方法,文中通過示例代碼介紹的非常詳細(xì),對大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2018-07-07
通過shell+python實(shí)現(xiàn)企業(yè)微信預(yù)警
這篇文章主要介紹了通過shell+python實(shí)現(xiàn)企業(yè)微信預(yù)警,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-03-03
Python實(shí)現(xiàn)在Windows平臺修改文件屬性
這篇文章主要介紹了Python實(shí)現(xiàn)在Windows平臺修改文件屬性,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03
在Python中執(zhí)行和調(diào)用JavaScript的多種方法小結(jié)
JavaScript(JS)是一種常用的腳本語言,通常用于網(wǎng)頁開發(fā),但有時(shí)也需要在Python中執(zhí)行或調(diào)用JavaScript代碼,本文將詳細(xì)介紹Python中執(zhí)行和調(diào)用JavaScript的多種方法,每種方法都將附有示例代碼,方便理解如何在Python中與JavaScript進(jìn)行互動,需要的朋友可以參考下2023-11-11
Python如何讀寫二進(jìn)制數(shù)組數(shù)據(jù)
這篇文章主要介紹了Python如何讀寫二進(jìn)制數(shù)組數(shù)據(jù),文中講解非常細(xì)致,代碼幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下2020-08-08
TensorFlow實(shí)現(xiàn)自定義Op方式
今天小編就為大家分享一篇TensorFlow實(shí)現(xiàn)自定義Op方式,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02
Python設(shè)置matplotlib.plot的坐標(biāo)軸刻度間隔以及刻度范圍
這篇文章主要介紹了Python設(shè)置matplotlib.plot的坐標(biāo)軸刻度間隔以及刻度范圍,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-06-06
python通過paramiko復(fù)制遠(yuǎn)程文件及文件目錄到本地
這篇文章主要為大家詳細(xì)介紹了python通過paramiko復(fù)制遠(yuǎn)程文件及文件目錄到本地,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-04-04

