pytorch對(duì)可變長(zhǎng)度序列的處理方法詳解
主要是用函數(shù)torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()來(lái)進(jìn)行的,分別來(lái)看看這三個(gè)函數(shù)的用法。
1、torch.nn.utils.rnn.PackedSequence()
NOTE: 這個(gè)類的實(shí)例不能手動(dòng)創(chuàng)建。它們只能被 pack_padded_sequence() 實(shí)例化。
PackedSequence對(duì)象包括:
一個(gè)data對(duì)象:一個(gè)torch.Variable(令牌的總數(shù),每個(gè)令牌的維度),在這個(gè)簡(jiǎn)單的例子中有五個(gè)令牌序列(用整數(shù)表示):(18,1)
一個(gè)batch_sizes對(duì)象:每個(gè)時(shí)間步長(zhǎng)的令牌數(shù)列表,在這個(gè)例子中為:[6,5,2,4,1]
用pack_padded_sequence函數(shù)來(lái)構(gòu)造這個(gè)對(duì)象非常的簡(jiǎn)單:

如何構(gòu)造一個(gè)PackedSequence對(duì)象(batch_first = True)
PackedSequence對(duì)象有一個(gè)很不錯(cuò)的特性,就是我們無(wú)需對(duì)序列解包(這一步操作非常慢)即可直接在PackedSequence數(shù)據(jù)變量上執(zhí)行許多操作。特別是我們可以對(duì)令牌執(zhí)行任何操作(即對(duì)令牌的順序/上下文不敏感)。當(dāng)然,我們也可以使用接受PackedSequence作為輸入的任何一個(gè)pyTorch模塊(pyTorch 0.2)。
2、torch.nn.utils.rnn.pack_padded_sequence()
這里的pack,理解成壓緊比較好。 將一個(gè) 填充過(guò)的變長(zhǎng)序列 壓緊。(填充時(shí)候,會(huì)有冗余,所以壓緊一下)
輸入的形狀可以是(T×B×* )。T是最長(zhǎng)序列長(zhǎng)度,B是batch size,*代表任意維度(可以是0)。如果batch_first=True的話,那么相應(yīng)的 input size 就是 (B×T×*)。
Variable中保存的序列,應(yīng)該按序列長(zhǎng)度的長(zhǎng)短排序,長(zhǎng)的在前,短的在后。即input[:,0]代表的是最長(zhǎng)的序列,input[:, B-1]保存的是最短的序列。
NOTE: 只要是維度大于等于2的input都可以作為這個(gè)函數(shù)的參數(shù)。你可以用它來(lái)打包labels,然后用RNN的輸出和打包后的labels來(lái)計(jì)算loss。通過(guò)PackedSequence對(duì)象的.data屬性可以獲取 Variable。
參數(shù)說(shuō)明:
input (Variable) – 變長(zhǎng)序列 被填充后的 batch
lengths (list[int]) – Variable 中 每個(gè)序列的長(zhǎng)度。
batch_first (bool, optional) – 如果是True,input的形狀應(yīng)該是B*T*size。
返回值:
一個(gè)PackedSequence 對(duì)象。
3、torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence。
上面提到的函數(shù)的功能是將一個(gè)填充后的變長(zhǎng)序列壓緊。 這個(gè)操作和pack_padded_sequence()是相反的。把壓緊的序列再填充回來(lái)。
返回的Varaible的值的size是 T×B×*, T 是最長(zhǎng)序列的長(zhǎng)度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。
Batch中的元素將會(huì)以它們長(zhǎng)度的逆序排列。
參數(shù)說(shuō)明:
sequence (PackedSequence) – 將要被填充的 batch
batch_first (bool, optional) – 如果為True,返回的數(shù)據(jù)的格式為 B×T×*。
返回值: 一個(gè)tuple,包含被填充后的序列,和batch中序列的長(zhǎng)度列表。
例子:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
#forward
out, _ = rnn(pack, h0)
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)
輸出:
111 (Variable containing: (0 ,.,.) = 0.5406 0.3584 -0.1403 0.0308 (1 ,.,.) = -0.6855 -0.9307 0.0000 0.0000 [torch.FloatTensor of size 2x2x2] , [2, 1])
以上這篇pytorch對(duì)可變長(zhǎng)度序列的處理方法詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- 對(duì)pytorch網(wǎng)絡(luò)層結(jié)構(gòu)的數(shù)組化詳解
- pytorch 轉(zhuǎn)換矩陣的維數(shù)位置方法
- pytorch 調(diào)整某一維度數(shù)據(jù)順序的方法
- 對(duì)PyTorch torch.stack的實(shí)例講解
- 使用pytorch進(jìn)行圖像的順序讀取方法
- mac安裝pytorch及系統(tǒng)的numpy更新方法
- 淺談pytorch和Numpy的區(qū)別以及相互轉(zhuǎn)換方法
- pytorch + visdom CNN處理自建圖片數(shù)據(jù)集的方法
- PyTorch CNN實(shí)戰(zhàn)之MNIST手寫數(shù)字識(shí)別示例
- PyTorch 1.0 正式版已經(jīng)發(fā)布了
相關(guān)文章
使用PyTorch實(shí)現(xiàn)限制GPU顯存的可使用上限
從?PyTorch?1.4?版本開始,引入了一個(gè)新的功能,可以允許用戶為特定的?GPU?設(shè)備設(shè)置進(jìn)程可使用的顯存上限比例,下面我們就來(lái)看看具體實(shí)現(xiàn)方法吧2024-03-03
pytorch 實(shí)現(xiàn)cross entropy損失函數(shù)計(jì)算方式
今天小編就為大家分享一篇pytorch 實(shí)現(xiàn)cross entropy損失函數(shù)計(jì)算方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-01-01
OpenCV圖像卷積之cv.filter2D()函數(shù)詳解
在其官方文檔中,filter2D()函數(shù)在掩模板介紹中一筆帶過(guò),我認(rèn)為該函數(shù)應(yīng)該進(jìn)行詳細(xì)介紹,下面這篇文章主要給大家介紹了關(guān)于OpenCV圖像卷積之cv.filter2D()函數(shù)的相關(guān)資料,需要的朋友可以參考下2022-09-09
Python實(shí)現(xiàn)遍歷子文件夾并將文件復(fù)制到不同的目標(biāo)文件夾
這篇文章主要介紹了如何基于Python語(yǔ)言實(shí)現(xiàn)遍歷多個(gè)子文件夾,將每一個(gè)子文件夾中大量的文件,按照每一個(gè)文件的文件名稱的特點(diǎn)復(fù)制到不同的目標(biāo)文件夾中,感興趣的可以了解下2023-08-08
Django nginx配置實(shí)現(xiàn)過(guò)程詳解
這篇文章主要介紹了Django nginx配置實(shí)現(xiàn)過(guò)程詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-09-09
Python中最好用的json庫(kù)orjson用法詳解
orjson是一個(gè)用于python的快速、正確的json庫(kù),它的基準(zhǔn)是 json最快的python庫(kù),具有全面的單元、集成和互操作性測(cè)試,下面這篇文章主要給大家介紹了關(guān)于Python中最好用的json庫(kù)orjson用法的相關(guān)資料,需要的朋友可以參考下2022-06-06

