pytorch中的reshape()、view()、nn.flatten()和flatten()使用
在使用pytorch定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)時(shí),經(jīng)常會(huì)看到類(lèi)似如下的.view() / flatten()用法,這里對(duì)其用法做出講解與演示。

torch.reshape用法
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()調(diào)用,其作用是在不改變tensor元素?cái)?shù)目的情況下改變tensor的shape。
torch.reshape() 需要兩個(gè)參數(shù),一個(gè)是待被改變的張量tensor,一個(gè)是想要改變的形狀。
torch.reshape(input, shape) → Tensor
- input(Tensor)-要重塑的張量
- shape(python的元組:ints)-新形狀`
案例1.
輸入:
import torch a = torch.tensor([[0,1],[2,3]]) x = torch.reshape(a,(-1,)) print (x) b = torch.arange(4.) Y = torch.reshape(a,(2,2)) print(Y)
結(jié)果:
tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])
torch.view用法
view()的原理很簡(jiǎn)單,其實(shí)就是把原先tensor中的數(shù)據(jù)進(jìn)行排列,排成一行,然后根據(jù)所給的view()中的參數(shù)從一行中按順序選擇組成最終的tensor。
view()可以有多個(gè)參數(shù),這取決于你想要得到的是幾維的tensor,一般設(shè)置兩個(gè)參數(shù),也是神經(jīng)網(wǎng)絡(luò)中常用的(一般在全連接之前),代表二維。
view(h,w),h代表行(想要變?yōu)閹仔校?,?dāng)不知道要變?yōu)閹仔?,但知道要變?yōu)閹琢袝r(shí)可取-1;w代表的是列(想要變?yōu)閹琢校?,?dāng)不知道要變?yōu)閹琢?,但知道要變?yōu)閹仔袝r(shí)可取-1。
一、普通用法(手動(dòng)調(diào)整)
view()相當(dāng)于reshape、resize,重新調(diào)整Tensor的形狀。
案例2.
輸入
import torch a1 = torch.arange(0,16) print(a1)
輸出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
輸入
a2 = a1.view(8, 2) a3 = a1.view(2, 8) a4 = a1.view(4, 4) print(a2) print(a3) print(a4)
輸出
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
二、特殊用法:參數(shù)-1(自動(dòng)調(diào)整size)
view中一個(gè)參數(shù)定為-1,代表自動(dòng)調(diào)整這個(gè)維度上的元素個(gè)數(shù),以保證元素的總數(shù)不變。
輸入
import torch a1 = torch.arange(0,16) print(a1)
輸出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
輸入
a2 = a1.view(-1, 16) a3 = a1.view(-1, 8) a4 = a1.view(-1, 4) a5 = a1.view(-1, 2) a6 = a1.view(4*4, -1) a7 = a1.view(1*4, -1) a8 = a1.view(2*4, -1) print(a2) print(a3) print(a4) print(a5) print(a6) print(a7) print(a8)
輸出
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim與end_dim分別表示開(kāi)始的維度和終止的維度,默認(rèn)值為1和-1,其中1表示第一維度,-1表示最后的維度。結(jié)合起來(lái)看意思就是從第一維度到最后一個(gè)維度全部給展平為張量。(注意:數(shù)據(jù)的維度是從0開(kāi)始的,也就是存在第0維度,第一維度并不是真正意義上的第一個(gè))。
因?yàn)槠浔挥迷谏窠?jīng)網(wǎng)絡(luò)中,輸入為一批數(shù)據(jù),第 0 維為batch(輸入數(shù)據(jù)的個(gè)數(shù)),通常要把一個(gè)數(shù)據(jù)拉成一維,而不是將一批數(shù)據(jù)拉為一維。所以torch.nn.Flatten()默認(rèn)從第一維開(kāi)始平坦化。
使用nn.Flatten(),使用默認(rèn)參數(shù)
官方給出的示例:
input = torch.randn(32, 1, 5, 5) # With default parameters m = nn.Flatten() output = m(input) output.size() #torch.Size([32, 25]) # With non-default parameters m = nn.Flatten(0, 2) output = m(input) output.size() #torch.Size([160, 5])
#開(kāi)頭的代碼是注釋
整段代碼的意思是:給定一個(gè)維度為(32,1,5,5)的隨機(jī)數(shù)據(jù)。
1.先使用一次nn.Flatten(),使用默認(rèn)參數(shù):
m = nn.Flatten()
也就是說(shuō)從第一維度展平到最后一個(gè)維度,數(shù)據(jù)的維度是從0開(kāi)始的,第一維度實(shí)際上是數(shù)據(jù)的第二位置代表的維度,也就是樣例中的1。
因此進(jìn)行展平后的結(jié)果也就是[32,155]→[32,25]
2.接著再使用一次指定參數(shù)的nn.Flatten(),即
m = nn.Flatten(0,2)
也就是說(shuō)從第0維度展平到第2維度,0~2,對(duì)應(yīng)的也就是前三個(gè)維度。
因此結(jié)果就是[3215,5]→[160,25]
torch.flatten
torch.flatten()函數(shù)經(jīng)常用于寫(xiě)分類(lèi)神經(jīng)網(wǎng)絡(luò)的時(shí)候,經(jīng)過(guò)最后一個(gè)卷積層之后,一般會(huì)再接一個(gè)自適應(yīng)的池化層,輸出一個(gè)BCHW的向量。
這時(shí)候就需要用到torch.flatten()函數(shù)將這個(gè)向量拉平成一個(gè)Bx的向量(其中,x = CHW),然后送入到FC層中。

語(yǔ)句結(jié)構(gòu)
torch.flatten(input, start_dim=0, end_dim=-1)
input: 一個(gè) tensor,即要被“攤平”的 tensor。
- start_dim: “攤平”的起始維度。
- end_dim: “攤平”的結(jié)束維度。
作用與 torch.nn.flatten 類(lèi)似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是類(lèi),其默認(rèn)開(kāi)始維度為第 0 維。
例1:
import torch data_pool = torch.randn(2,2,3,3) # 模擬經(jīng)過(guò)最后一個(gè)池化層或自適應(yīng)池化層之后的輸出,Batchsize*c*h*w print(data_pool) y=torch.flatten(data_pool,1) print(y)
輸出結(jié)果:

結(jié)果是一個(gè)B*x的向量。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
使用Python打造一個(gè)專(zhuān)業(yè)的PDF文本提取工具
這篇文章主要為大家詳細(xì)介紹了如何使用Python開(kāi)發(fā)一個(gè)專(zhuān)業(yè)的PDF文本提取工具,幫助大家從PDF文檔中高效提取結(jié)構(gòu)化文本數(shù)據(jù),適用于數(shù)據(jù)分析,內(nèi)容歸檔和知識(shí)管理等場(chǎng)景2025-07-07
從入門(mén)到實(shí)戰(zhàn)詳解Python如何將Excel工作表轉(zhuǎn)換為PDF
這篇文章主要為大家詳細(xì)介紹了Python如何將Excel工作表轉(zhuǎn)換為PDF,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2025-11-11
使用pytorch加載并讀取COCO數(shù)據(jù)集的詳細(xì)操作
這篇文章主要介紹了使用pytorch加載并讀取COCO數(shù)據(jù)集,基礎(chǔ)知識(shí)包括元祖、字典、數(shù)組,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-05-05
Python 批量操作設(shè)備的實(shí)現(xiàn)步驟
本文將結(jié)合實(shí)例代碼,介紹Python 批量操作設(shè)備的實(shí)現(xiàn)步驟,文中通過(guò)示例代碼介紹的非常詳細(xì),需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-07-07
Python通過(guò)requests模塊實(shí)現(xiàn)抓取王者榮耀全套皮膚
只學(xué)書(shū)上的理論是遠(yuǎn)遠(yuǎn)不如實(shí)踐帶來(lái)的提升快,只有在實(shí)例中才能獲得能力的提升,本篇文章手把手帶你用Python實(shí)現(xiàn)抓取王者榮耀全套皮膚,大家可以在過(guò)程中查缺補(bǔ)漏,提升水平2021-10-10
python代碼如何實(shí)現(xiàn)切換中英文輸入法
這篇文章主要介紹了python代碼如何實(shí)現(xiàn)切換中英文輸入法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-11-11
Restful_framework視圖組件代碼實(shí)例解析
這篇文章主要介紹了Restful_framework視圖組件代碼實(shí)例解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-11-11
Django框架的使用教程路由請(qǐng)求響應(yīng)的方法
這篇文章主要介紹了Django框架的使用教程路由請(qǐng)求響應(yīng)的方法,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-07-07
Python實(shí)現(xiàn)時(shí)鐘顯示效果思路詳解
這篇文章主要介紹了Python實(shí)現(xiàn)時(shí)鐘顯示,需要的朋友可以參考下2018-04-04

