pytorch中nn.Flatten()函數(shù)詳解及示例
torch.nn.Flatten(start_dim=1, end_dim=- 1)
作用:將連續(xù)的維度范圍展平為張量。 經(jīng)常在nn.Sequential()中出現(xiàn),一般寫在某個(gè)神經(jīng)網(wǎng)絡(luò)模型之后,用于對(duì)神經(jīng)網(wǎng)絡(luò)模型的輸出進(jìn)行處理,得到tensor類型的數(shù)據(jù)。

有倆個(gè)參數(shù),start_dim和end_dim,分別表示開始的維度和終止的維度,默認(rèn)值分別是1和-1,其中1表示第一維度,-1表示最后的維度。結(jié)合起來看意思就是從第一維度到最后一個(gè)維度全部給展平為張量。(注意:數(shù)據(jù)的維度是從0開始的,也就是存在第0維度,第一維度并不是真正意義上的第一個(gè))
同理,如果我這么寫:
self.flat = nn.Flatten(start_dim=2, end_dim=3)
那么意思就是從第二維度開始,到第三維度全部給展平,也就是將2、3兩個(gè)維度展平。
官網(wǎng)給出的示例:
input = torch.randn(32, 1, 5, 5) # With default parameters m = nn.Flatten() output = m(input) output.size() #torch.Size([32, 25]) # With non-default parameters m = nn.Flatten(0, 2) output = m(input) output.size() #torch.Size([160, 5])
#開頭的代碼是注釋
整段代碼的意思是:給定一個(gè)維度為(32,1,5,5)的隨機(jī)數(shù)據(jù)。
1.先使用一次nn.Flatten(),使用默認(rèn)參數(shù):
m = nn.Flatten()
也就是說從第一維度展平到最后一個(gè)維度,數(shù)據(jù)的維度是從0開始的,第一維度實(shí)際上是數(shù)據(jù)的第二個(gè)位置代表的維度,也就是樣例中的1。
因此進(jìn)行展平后的結(jié)果也就是[32,1×5×5]?[32,25]
2.接著再使用一次指定參數(shù)的nn.Flatten(),即
m = nn.Flatten(0, 2)
也就是說從第0維度展平到第2維度,0~2,對(duì)應(yīng)的也就是前三個(gè)維度。
因此結(jié)果就是[32×1×5,5]?[160,5]
因此進(jìn)行展平后的結(jié)果也就是[32,1*5*5]?[32,25]
示例1
卷積公式

import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
nn.Conv2d(1, 32, 5, 1, 1), # 通過卷積,得到torch.size([32, 32, 3, 3]
nn.Flatten())
output = m(input)
print(output.size())
>> torch.Size([32, 288])
示例2
import torch
import torch.nn as nn
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
nn.Conv2d(1, 32, 5, 1, 1), # 通過卷積,得到torch.size([32, 32, 3, 3]
nn.Flatten(start_dim=0))
output = m(input)
print(output.size())
>>torch.Size([9216])
總結(jié)
到此這篇關(guān)于pytorch中nn.Flatten()函數(shù)詳解的文章就介紹到這了,更多相關(guān)pytorch nn.Flatten()函數(shù)詳解內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
完美解決torch.cuda.is_available()一直返回False的玄學(xué)方法
這篇文章主要介紹了完美解決torch.cuda.is_available()一直返回False的玄學(xué)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-02-02
python報(bào)錯(cuò)TypeError: Input z must be
大家好,本篇文章主要講的是python報(bào)錯(cuò)TypeError: Input z must be 2D, not 3D的解決方法,感興趣的同學(xué)趕快來看一看吧,對(duì)你有幫助的話記得收藏一下2021-12-12
使用Python制作讀單詞視頻的實(shí)現(xiàn)代碼
我們經(jīng)常在B站或其他視頻網(wǎng)站上看到那種逐條讀單詞的視頻,但他們的視頻多多少少和我們的預(yù)期都不太一致,然而,網(wǎng)上很難找到和自己需求符合的視頻,所以本文給大家介紹了使用Python制作讀單詞視頻的實(shí)現(xiàn),需要的朋友可以參考下2024-04-04
Python pandas 的索引方式 data.loc[],data[][]示例詳解
這篇文章主要介紹了Python pandas 的索引方式 data.loc[], data[][]的相關(guān)資料,其中data.loc[index,column]使用.loc[ ]第一個(gè)參數(shù)是行索引,第二個(gè)參數(shù)是列索引,本文結(jié)合實(shí)例代碼講解的非常詳細(xì),需要的朋友可以參考下2023-02-02
快速解釋如何使用pandas的inplace參數(shù)的使用
這篇文章主要介紹了快速解釋如何使用pandas的inplace參數(shù)的使用,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-07-07
Pytorch torch.repeat_interleave()用法示例詳解
torch.repeat_interleave() 是 PyTorch 中的一個(gè)函數(shù),用于按指定的方式重復(fù)張量中的元素,這篇文章主要介紹了Pytorch torch.repeat_interleave()用法示例詳解,需要的朋友可以參考下2024-01-01

