如何計(jì)算 tensorflow 和 pytorch 模型的浮點(diǎn)運(yùn)算數(shù)
本文主要討論如何計(jì)算 tensorflow 和 pytorch 模型的 FLOPs。如有表述不當(dāng)之處歡迎批評(píng)指正。歡迎任何形式的轉(zhuǎn)載,但請(qǐng)務(wù)必注明出處。
1. 引言
FLOPs 是 floating point operations 的縮寫,指浮點(diǎn)運(yùn)算數(shù),可以用來衡量模型/算法的計(jì)算復(fù)雜度。本文主要討論如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相關(guān)工具計(jì)算對(duì)應(yīng)模型的 FLOPs。
2. 模型結(jié)構(gòu)
為了說明方便,先搭建一個(gè)簡單的神經(jīng)網(wǎng)絡(luò)模型,其模型結(jié)構(gòu)以及主要參數(shù)如表1 所示。
表 1 模型結(jié)構(gòu)及主要參數(shù)
| Layers | channels | Kernels | Strides | Units | Activation |
|---|---|---|---|---|---|
| Conv2D | 32 | (4,4) | (1,2) | \ | relu |
| GRU | \ | \ | \ | 96 | \ |
| Dense | \ | \ | \ | 256 | sigmoid |
用 tensorflow(實(shí)際使用 tensorflow 中的 keras 模塊)實(shí)現(xiàn)該模型的代碼為:
from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model
def test_model_tf(Input_shape):
# shape: [B, C, T, F]
main_input = Input(batch_shape=Input_shape, name='main_inputs')
conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
# shape: [B, T, FC]
gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
output = Dense(256, activation='sigmoid', name='output')(gru)
model = Model(inputs=[main_input], outputs=[output])
return model用 pytorch 實(shí)現(xiàn)該模型的代碼為:
import torch
import torch.nn as nn
class test_model_torch(nn.Module):
def __init__(self):
super(test_model_torch, self).__init__()
self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
self.relu = nn.ReLU()
self.gru = nn.GRU(input_size=4064, hidden_size=96)
self.fc = nn.Linear(96, 256)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
# shape: [B, C, T, F]
out = self.conv2d(inputs)
out = self.relu(out)
# shape: [B, T, FC]
batch, channel, frame, freq = out.size()
out = torch.reshape(out, (batch, frame, freq*channel))
out, _ = self.gru(out)
out = self.fc(out)
out = self.sigmoid(out)
return out
3. 計(jì)算模型的 FLOPs
本節(jié)討論的版本具體為:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。
3.1. tensorflow 1.12.0
在 tensorflow 1.12.0 環(huán)境中,可以使用以下代碼計(jì)算模型的 FLOPs:
import tensorflow as tf
import tensorflow.keras.backend as K
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 1.12.0:', get_flops(model))3.2. tensorflow 2.3.1
在 tensorflow 2.3.1 環(huán)境中,可以使用以下代碼計(jì)算模型的 FLOPs :
import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 2.3.1:', get_flops(model))3.3. pytorch 1.10.1+cu102
在 pytorch 1.10.1+cu102 環(huán)境中,可以使用以下代碼計(jì)算模型的 FLOPs(需要安裝 thop):
import thop
x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代碼有乘 2 2 2 操作。
3.4. 結(jié)果對(duì)比
三者計(jì)算出的 FLOPs 分別為:
tensorflow 1.12.0:

tensorflow 2.3.1:

pytorch 1.10.1:

可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的結(jié)果基本在同一個(gè)量級(jí),而與 pytorch 1.10.1 計(jì)算出來的相差甚遠(yuǎn)。但如果將上述模型結(jié)構(gòu)改為只包含第一層 Conv2D,三者計(jì)算出來的 FLOPs 卻又是一致的。所以推斷差異主要來自于 GRU 的 FLOPs。如讀者知道其中詳情,還請(qǐng)不吝賜教。
4. 總結(jié)
本文給出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相關(guān)工具計(jì)算模型 FLOPs 的方法,但從本文所使用的測(cè)試模型來看, tensorflow 與 pytorch 統(tǒng)計(jì)出的結(jié)果相差甚遠(yuǎn)。當(dāng)然,也可以根據(jù)網(wǎng)絡(luò)層的類型及其對(duì)應(yīng)的參數(shù),推導(dǎo)計(jì)算出每個(gè)網(wǎng)絡(luò)層所需的 FLOPs。
到此這篇關(guān)于計(jì)算 tensorflow 和 pytorch 模型的浮點(diǎn)運(yùn)算數(shù)的文章就介紹到這了,更多相關(guān)tensorflow 和 pytorch浮點(diǎn)運(yùn)算數(shù)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)列表轉(zhuǎn)Excel表格的第一列
這篇文章主要為大家詳細(xì)介紹了如何將Python中的列表轉(zhuǎn)換為Excel表格的第一列,并通過案例和代碼展示具體的操作步驟,希望可以幫助大家快速掌握這一技能2024-04-04
Django利用Channels+websocket開發(fā)聊天室完整案例
Channels是Django團(tuán)隊(duì)研發(fā)的一個(gè)給Django提供websocket支持的框架,使用它我們可以輕松開發(fā)需要長鏈接的實(shí)時(shí)通訊應(yīng)用,下面這篇文章主要給大家介紹了關(guān)于Django利用Channels+websocket開發(fā)聊天室的相關(guān)資料,需要的朋友可以參考下2023-06-06
Python3 全自動(dòng)更新已安裝的模塊實(shí)現(xiàn)
這篇文章主要介紹了Python3 全自動(dòng)更新已安裝的模塊實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01
python將excel轉(zhuǎn)換為csv的代碼方法總結(jié)
在本篇文章里小編給大家分享了關(guān)于python如何將excel轉(zhuǎn)換為csv的實(shí)例方法和代碼內(nèi)容,需要的朋友們學(xué)習(xí)下。2019-07-07
Django中的CACHE_BACKEND參數(shù)和站點(diǎn)級(jí)Cache設(shè)置
這篇文章主要介紹了Django中的CACHE_BACKEND參數(shù)和站點(diǎn)級(jí)Cache設(shè)置,Python是最具人氣的Python web框架,需要的朋友可以參考下2015-07-07

