PyTorch的自適應(yīng)池化Adaptive Pooling實(shí)例
簡(jiǎn)介
自適應(yīng)池化Adaptive Pooling是PyTorch含有的一種池化層,在PyTorch的中有六種形式:
自適應(yīng)最大池化Adaptive Max Pooling:
torch.nn.AdaptiveMaxPool1d(output_size)
torch.nn.AdaptiveMaxPool2d(output_size)
torch.nn.AdaptiveMaxPool3d(output_size)
自適應(yīng)平均池化Adaptive Average Pooling:
torch.nn.AdaptiveAvgPool1d(output_size)
torch.nn.AdaptiveAvgPool2d(output_size)
torch.nn.AdaptiveAvgPool3d(output_size)
具體可見官方文檔。
官方給出的例子: >>> # target output size of 5x7 >>> m = nn.AdaptiveMaxPool2d((5,7)) >>> input = torch.randn(1, 64, 8, 9) >>> output = m(input) >>> output.size() torch.Size([1, 64, 5, 7]) >>> # target output size of 7x7 (square) >>> m = nn.AdaptiveMaxPool2d(7) >>> input = torch.randn(1, 64, 10, 9) >>> output = m(input) >>> output.size() torch.Size([1, 64, 7, 7]) >>> # target output size of 10x7 >>> m = nn.AdaptiveMaxPool2d((None, 7)) >>> input = torch.randn(1, 64, 10, 9) >>> output = m(input) >>> output.size() torch.Size([1, 64, 10, 7])
Adaptive Pooling特殊性在于,輸出張量的大小都是給定的output_size output\_sizeoutput_size。例如輸入張量大小為(1, 64, 8, 9),設(shè)定輸出大小為(5,7),通過(guò)Adaptive Pooling層,可以得到大小為(1, 64, 5, 7)的張量。
原理

>>> inputsize = 9 >>> outputsize = 4 >>> input = torch.randn(1, 1, inputsize) >>> input tensor([[[ 1.5695, -0.4357, 1.5179, 0.9639, -0.4226, 0.5312, -0.5689, 0.4945, 0.1421]]]) >>> m1 = nn.AdaptiveMaxPool1d(outputsize) >>> m2 = nn.MaxPool1d(kernel_size=math.ceil(inputsize / outputsize), stride=math.floor(inputsize / outputsize), padding=0) >>> output1 = m1(input) >>> output2 = m2(input) >>> output1 tensor([[[1.5695, 1.5179, 0.5312, 0.4945]]]) torch.Size([1, 1, 4]) >>> output2 tensor([[[1.5695, 1.5179, 0.5312, 0.4945]]]) torch.Size([1, 1, 4])
通過(guò)實(shí)驗(yàn)發(fā)現(xiàn):

下面是Adaptive Average Pooling的c++源碼部分。
template <typename scalar_t>
static void adaptive_avg_pool2d_out_frame(
scalar_t *input_p,
scalar_t *output_p,
int64_t sizeD,
int64_t isizeH,
int64_t isizeW,
int64_t osizeH,
int64_t osizeW,
int64_t istrideD,
int64_t istrideH,
int64_t istrideW)
{
int64_t d;
#pragma omp parallel for private(d)
for (d = 0; d < sizeD; d++)
{
/* loop over output */
int64_t oh, ow;
for(oh = 0; oh < osizeH; oh++)
{
int istartH = start_index(oh, osizeH, isizeH);
int iendH = end_index(oh, osizeH, isizeH);
int kH = iendH - istartH;
for(ow = 0; ow < osizeW; ow++)
{
int istartW = start_index(ow, osizeW, isizeW);
int iendW = end_index(ow, osizeW, isizeW);
int kW = iendW - istartW;
/* local pointers */
scalar_t *ip = input_p + d*istrideD + istartH*istrideH + istartW*istrideW;
scalar_t *op = output_p + d*osizeH*osizeW + oh*osizeW + ow;
/* compute local average: */
scalar_t sum = 0;
int ih, iw;
for(ih = 0; ih < kH; ih++)
{
for(iw = 0; iw < kW; iw++)
{
scalar_t val = *(ip + ih*istrideH + iw*istrideW);
sum += val;
}
}
/* set output to local average */
*op = sum / kW / kH;
}
}
}
}
以上這篇PyTorch的自適應(yīng)池化Adaptive Pooling實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python+ChatGPT制作一個(gè)AI實(shí)用百寶箱
ChatGPT最近在互聯(lián)網(wǎng)掀起了一陣熱潮,其高度智能化的功能能夠給我們現(xiàn)實(shí)生活帶來(lái)諸多的便利。本文就來(lái)用Python和ChatGPT制作一個(gè)AI實(shí)用百寶箱吧2023-02-02
python smtplib模塊實(shí)現(xiàn)發(fā)送郵件帶附件sendmail
這篇文章主要為大家詳細(xì)介紹了python smtplib模塊實(shí)現(xiàn)發(fā)送郵件帶附件sendmail,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-05-05
Python2.7簡(jiǎn)單連接與操作MySQL的方法
這篇文章主要介紹了Python2.7簡(jiǎn)單連接與操作MySQL的方法,涉及Python使用MySQLdb模塊操作MySQL連接及命令運(yùn)行的相關(guān)技巧,需要的朋友可以參考下2016-04-04
利用Python繪制Jazz網(wǎng)絡(luò)圖的例子
今天小編就為大家分享一篇利用Python繪制Jazz網(wǎng)絡(luò)圖的例子,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-11-11
Python3 常用數(shù)據(jù)標(biāo)準(zhǔn)化方法詳解
這篇文章主要介紹了Python3 常用數(shù)據(jù)標(biāo)準(zhǔn)化方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-03-03
Python利用partial偏函數(shù)生成不同的聚合函數(shù)
本文主要介紹了Python利用partial偏函數(shù)生成不同的聚合函數(shù),利用偏函數(shù)的概念,可以生成一些新的函數(shù),在調(diào)用這些新函數(shù)時(shí),不用再傳遞固定值的參數(shù),這樣可以使代碼更簡(jiǎn)潔,感興趣的可以了解一下2024-03-03
通過(guò)Python將MP4視頻轉(zhuǎn)換為GIF動(dòng)畫
Python可用于讀取常見的MP4視頻格式并將其轉(zhuǎn)換為GIF動(dòng)畫。本文將詳細(xì)為大家介紹實(shí)現(xiàn)的過(guò)程,文中的代碼具有一定的參考價(jià)值,感興趣的小伙伴可以學(xué)習(xí)一下2021-12-12

