Pytorch深度學(xué)習(xí)gather一些使用問(wèn)題解決方案
問(wèn)題場(chǎng)景描述
我在復(fù)現(xiàn)Faster-RCNN模型的過(guò)程中遇到這樣一個(gè)問(wèn)題:
有一個(gè)張量,它的形狀是 (128, 21, 4)
roi_loc.shape = (128, 21, 4)
與之對(duì)應(yīng)的還有一個(gè)label數(shù)據(jù)
gt_label.shape = (128)

我現(xiàn)在的需求是將label當(dāng)作第一個(gè)張量在dim=1上的索引,將其中的數(shù)據(jù)拿出來(lái)。
具體來(lái)說(shuō)就是,現(xiàn)在有128個(gè)樣本數(shù)據(jù),每個(gè)樣本中有21個(gè)長(zhǎng)度為4的向量。label也是128個(gè),每個(gè)值代表取出21個(gè)向量中的哪一個(gè)。
問(wèn)題的思考
我嘗試了很多辦法,包括布爾索引,index_select方法等,最后發(fā)現(xiàn)都不適用(也有可能我沒(méi)用好)。最后利用gather API解決了這個(gè)問(wèn)題。

這個(gè)API的說(shuō)明我看了很多遍都沒(méi)看懂,我相信絕大部分讀者也是因?yàn)榭床欢@個(gè)說(shuō)明才來(lái)這兒的。
下面我給出自己的一些理解:
gather的說(shuō)明
gather所需要的第一個(gè)參數(shù)是待索引的數(shù)據(jù),在我們的問(wèn)題中 roi_loc就是這個(gè)input。第二個(gè)參數(shù)dim,是你的索引數(shù)據(jù)要作用在哪個(gè)軸上,正如前面所言,我們想索引第二個(gè)軸(dim=1).
最難理解的是index,index就是我們想要用來(lái)索引的張量,對(duì)應(yīng)的是label??墒莑abel不能直接拿來(lái)用,得先做一定的變換,這也就是gather的難點(diǎn)。
我們先從簡(jiǎn)單的情況來(lái)看
input和gather必須在維度上相同,假設(shè)數(shù)據(jù)還是3 * 3,index也是1 * 3的(注意這里是二維的)

此時(shí)row至多取值0,col至多取值為2
如果我要對(duì)dim=0索引
那么data[0][0] = data[index[0][0]] [0] = data[1][0] = 2
data[0][1] = data[index[0][1]] [1] = data[0][1] = 5
data[0][2] = data[index[0][2]][2] = data[2][2] = 9
上面的過(guò)程可以描述為,第一列的元素我想選第二行的,第二列的元素我想選第一行的,第三列的元素我想選第三行的。
可以發(fā)現(xiàn)因?yàn)閕ndex是1 * 3的,所以最后的輸出也是31* 3,即輸出張量的shape取決于index的shape

以上過(guò)程我相信讀者好好體悟應(yīng)該可以理解。
問(wèn)題的解決
回到我們的問(wèn)題
roi_loc.shape = (128, 21, 4),gt_label.shape = (128)
我們想索引dim=1,最后的結(jié)果應(yīng)該是(128, 4)
由上面的說(shuō)明可以知道,input和index的dimension首先得相同
idx = gt_roi_labels.unsqueeze(-1).unsqueeze(-1) idx.shape = (128, 1, 1)
又因?yàn)槲覀兿胍敵龅慕Y(jié)果得是(128, 4),所以得讓idx在最后一個(gè)軸上重復(fù)4次
idx = idx.repeat_interleave(-1, 4) idx.shape = (128, 1, 4)
現(xiàn)在就可以利用gather在dim=1上索引了
result = roi_loc.gather(1, idx) result.shape = (128, 1, 4)
最后將長(zhǎng)度為1的軸壓縮(本身這個(gè)軸的出現(xiàn)是為了滿(mǎn)足input和index維度一樣的要求)
result = result.squeeze(1) result.shape(128, 4)
以上就是Pytorch深度學(xué)習(xí)gather一些使用問(wèn)題解決方案的詳細(xì)內(nèi)容,更多關(guān)于Pytorch學(xué)習(xí)gather使用問(wèn)題的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Pycharm配置遠(yuǎn)程SSH服務(wù)器實(shí)現(xiàn)(切換不同虛擬環(huán)境)
本文主要介紹了Pycharm配置遠(yuǎn)程SSH服務(wù)器實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02
Python實(shí)現(xiàn)將數(shù)據(jù)庫(kù)一鍵導(dǎo)出為Excel表格的實(shí)例
下面小編就為大家?guī)?lái)一篇Python實(shí)現(xiàn)將數(shù)據(jù)庫(kù)一鍵導(dǎo)出為Excel表格的實(shí)例。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2016-12-12
Tensorflow讀取并輸出已保存模型的權(quán)重?cái)?shù)值方式
今天小編就為大家分享一篇Tensorflow讀取并輸出已保存模型的權(quán)重?cái)?shù)值方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看不看2020-01-01
在Apache服務(wù)器上同時(shí)運(yùn)行多個(gè)Django程序的方法
這篇文章主要介紹了在Apache服務(wù)器上同時(shí)運(yùn)行多個(gè)Django程序的方法,Django是Python各色高人氣web框架中最為著名的一個(gè),需要的朋友可以參考下2015-07-07
通過(guò)數(shù)據(jù)庫(kù)向Django模型添加字段的示例
這篇文章主要介紹了通過(guò)數(shù)據(jù)庫(kù)向Django模型添加字段的示例,Django是人氣最高的Python web開(kāi)發(fā)框架,需要的朋友可以參考下2015-07-07
深入了解Python數(shù)據(jù)類(lèi)型之列表
下面小編就為大家?guī)?lái)一篇深入了解Python數(shù)據(jù)類(lèi)型之列表。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2016-06-06

