Pytorch統(tǒng)計參數(shù)網(wǎng)絡(luò)參數(shù)數(shù)量方式
更新時間:2023年02月20日 10:09:39 作者:qq_34535410
這篇文章主要介紹了Pytorch統(tǒng)計參數(shù)網(wǎng)絡(luò)參數(shù)數(shù)量方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
Pytorch統(tǒng)計參數(shù)網(wǎng)絡(luò)參數(shù)數(shù)量
def get_parameter_number(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
Pytorch如何計算網(wǎng)絡(luò)的參數(shù)量
本文以 Dense Block 為例,Pytorch 為 DL 框架,最終計算模塊參數(shù)量方法如下:
import torch
import torch.nn as nn
class Norm_Conv(nn.Module):
? ? def __init__(self,in_channel):
? ? ? ? super(Norm_Conv,self).__init__()
? ? ? ? self.layers = nn.Sequential(
? ? ? ? ? ? nn.Conv2d(in_channel,in_channel,3,1,1),
? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? nn.BatchNorm2d(in_channel),
? ? ? ? ? ? nn.Conv2d(in_channel,in_channel,3,1,1),
? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? nn.BatchNorm2d(in_channel),
? ? ? ? ? ? nn.Conv2d(in_channel,in_channel,3,1,1),
? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? nn.BatchNorm2d(in_channel))
? ? def forward(self,input):
? ? ? ? out = self.layers(input)
? ? ? ? return out
class DenseBlock_Norm(nn.Module):
? ? def __init__(self,in_channel):
? ? ? ? super(DenseBlock_Norm,self).__init__()
? ? ? ? self.first_layer = nn.Sequential(nn.Conv2d(in_channel,in_channel,3,1,1),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.BatchNorm2d(in_channel))
? ? ? ? self.second_layer = nn.Sequential(nn.Conv2d(in_channel*2,in_channel,3,1,1),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.BatchNorm2d(in_channel))
? ? ? ? self.third_layer = nn.Sequential(
? ? ? ? ? ? nn.Conv2d(in_channel*3,in_channel,3,1,1),
? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? nn.BatchNorm2d(in_channel))
? ? def forward(self,input):
? ? ? ? output1 = self.first_layer(input)
? ? ? ? output2 = self.second_layer(torch.cat((output1,input),dim=1))
? ? ? ? output3 = self.third_layer(torch.cat((input,output1,output2),dim=1))
? ? ? ? return output3
def count_param(model):
? ? param_count = 0
? ? for param in model.parameters():
? ? ? ? param_count += param.view(-1).size()[0]
? ? return param_count
# Get Parameter number of Network
in_channel = 128
net1 = Norm_Conv(in_channel)
print('Norm Conv parameter count is {}'.format(count_param(net1)))
net2 = DenseBlock_Norm(in_channel)
print('DenseBlock Norm parameter count is {}'.format(count_param(net2)))最終結(jié)果如下
Norm Conv parameter count is 443520
DenseBlock Norm parameter count is 885888
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python3.6 如何將list存入txt后再讀出list的方法
這篇文章主要介紹了python3.6 如何將list存入txt后再讀出list的方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
2019-07-07
Python中排序函數(shù)sorted()函數(shù)的使用實例
sorted()作為Python內(nèi)置函數(shù)之一,其功能是對序列(列表、元組、字典、集合、還包括字符串)進(jìn)行排序,下面這篇文章主要給大家介紹了關(guān)于Python中排序函數(shù)sorted()函數(shù)的相關(guān)資料,需要的朋友可以參考下
2022-11-11
Python3 Tensorlfow:增加或者減小矩陣維度的實現(xiàn)
這篇文章主要介紹了Python3 Tensorlfow:增加或者減小矩陣維度的實現(xiàn),具有好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
2020-05-05 
