Swin?Transformer圖像處理深度學(xué)習(xí)模型
Swin Transformer
Swin Transformer是一種用于圖像處理的深度學(xué)習(xí)模型,它可以用于各種計算機視覺任務(wù),如圖像分類、目標(biāo)檢測和語義分割等。它的主要特點是采用了分層的窗口機制,可以處理比較大的圖像,同時也減少了模型參數(shù)的數(shù)量,提高了計算效率。Swin Transformer在圖像處理領(lǐng)域取得了很好的表現(xiàn),成為了最先進(jìn)的模型之一。
Swin Transformer通過從小尺寸的圖像塊(用灰色輪廓線框出)開始,并逐漸合并相鄰塊,構(gòu)建了一個分層的表示形式,在更深層的Transformer中實現(xiàn)。

整體架構(gòu)


Swin Transformer 模塊

Swin Transformer模塊是基于Transformer塊中標(biāo)準(zhǔn)的多頭自注意力模塊(MSA)進(jìn)行替換構(gòu)建的,用的是一種基于滑動窗口的模塊(在后面細(xì)說),而其他層保持不變。如上圖所示,Swin Transformer模塊由基于滑動窗口的多頭注意力模塊組成,后跟一個2層MLP,在中間使用GELU非線性激活函數(shù)。在每個MSA模塊和每個MLP之前都應(yīng)用了LayerNorm(LN)層,并在每個模塊之后應(yīng)用了殘差連接。
滑動窗口機制


Cyclic Shift
Cyclic Shift是Swin Transformer中一種有效的處理局部特征的方法。在Swin Transformer中,為了處理高分辨率的輸入特征圖,需要將輸入特征圖分割成小塊(一個patch可能有多個像素)進(jìn)行處理。然而,這樣會導(dǎo)致局部特征在不同塊之間被分割開來,影響了局部特征的提取。Cyclic Shift將輸入特征圖沿著寬度和高度方向分別平移一個固定的距離,使得每個塊的局部特征可以與相鄰塊的局部特征進(jìn)行交互,從而增強了局部特征的表達(dá)能力。另外,Cyclic Shift還可以通過多次平移來增加塊之間的交互,進(jìn)一步提升了模型的性能。需要注意的是,Cyclic Shift只在訓(xùn)練過程中使用,因為它會改變輸入特征圖的分布。在測試過程中,輸入特征圖的大小和分布與訓(xùn)練時相同,因此不需要使用Cyclic Shift操作。
Efficient batch computation for shifted configuration
Cyclic Shift會將輸入特征圖沿著寬度和高度方向進(jìn)行平移操作,以便讓不同塊之間的局部特征進(jìn)行交互。這樣的操作會導(dǎo)致每個塊的特征值的位置發(fā)生改變,從而需要在每個塊上重新計算注意力機制。
為了加速計算過程,Swin Transformer中引入了"Efficient batch computation for shifted configuration"這一技巧。該技巧首先將每個塊的特征值復(fù)制多次,分別放置在Cyclic Shift平移后的不同位置上,使得每個塊都可以在平移后的不同的位置上參與到注意力機制的計算中。然后,將這些位置不同的塊的特征值進(jìn)行合并拼接,計算注意力。
需要注意的是,這種技巧只在訓(xùn)練時使用,因為它會增加計算量,而在測試時,可以將每個塊的特征值計算一次,然后在不同位置上進(jìn)行拼接,以得到最終的輸出。
Relative position bias
在傳統(tǒng)的Transformer模型中,為了考慮單詞之間的位置關(guān)系,通常采用絕對位置編碼(Absolute Positional Encoding)的方式。這種方法是在每個單詞的embedding中添加位置編碼向量,以表示該單詞在序列中的絕對位置。但是,當(dāng)序列長度很長時,絕對位置編碼會面臨兩個問題:
- 編碼向量的大小會隨著序列長度的增加而增加,導(dǎo)致模型參數(shù)量增大,訓(xùn)練難度加大;
- 當(dāng)序列長度超過一定限制時,模型的性能會下降。
為了解決這些問題,Swin Transformer采用了Relative Positional Encoding,它通過編碼單詞之間的相對位置信息來代替絕對位置編碼。相對位置編碼是由每個單詞對其它單詞的相對位置關(guān)系計算得出的。在計算相對位置時,Swin Transformer引入了Relative Position Bias,即相對位置偏置,它是一個可學(xué)習(xí)的參數(shù)矩陣,用于調(diào)整不同位置之間的相對位置關(guān)系。這樣做可以有效地減少相對位置編碼的參數(shù)量,同時提高模型的性能和效率。相對位置編碼可以通過以下公式計算:

最終,相對位置編碼和相對位置偏置的結(jié)果會被加到點積注意力機制中,用于計算不同位置之間的相關(guān)性,從而實現(xiàn)序列的建模。
代碼實現(xiàn):
下面是一個用PyTorch實現(xiàn)Swin B模型的示例代碼,其中包含了相對位置編碼和相對位置偏置的實現(xiàn):
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
class SwinBlock(nn.Module):
def __init__(self, in_channels, out_channels, window_size=7, shift_size=0):
super(SwinBlock, self).__init__()
self.window_size = window_size
self.shift_size = shift_size
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.norm1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=window_size, stride=1, padding=window_size//2, groups=out_channels)
self.norm2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.norm3 = nn.BatchNorm2d(out_channels)
if in_channels == out_channels:
self.downsample = None
else:
self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.norm_downsample = nn.BatchNorm2d(out_channels)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.norm1(out)
out = nn.functional.relu(out)
out = Rearrange(out, 'b c h w -> b (h w) c')
out = self.shift_window(out)
out = Rearrange(out, 'b (h w) c -> b c h w', h=int(x.shape[2]), w=int(x.shape[3]))
out = self.conv2(out)
out = self.norm2(out)
out = nn.functional.relu(out)
out = self.conv3(out)
out = self.norm3(out)
if self.downsample is not None:
residual = self.downsample(x)
residual = self.norm_downsample(residual)
out += residual
out = nn.functional.relu(out)
return out
def shift_window(self, x):
# x: (B, L, C)
B, L, C = x.shape
if self.shift_size == 0:
shifted_x = torch.zeros_like(x)
shifted_x[:, self.window_size//2:L-self.window_size//2, :] = x[:, self.window_size//2:L-self.window_size//2, :]
return shifted_x
else:
# pad feature maps to shift window
left_pad = self.window_size // 2 + self.shift_size
right_pad = left_pad - self.shift_size
x = nn.functional.pad(x, (0, 0, left_pad, right_pad), mode='constant', value=0)
# Reshape X to (B, H, W, C)
H = W = int(x.shape[1] ** 0.5)
x = Rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
# Shift window
x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
# Reshape back to (B, L, C)
x = Rearrange(x, 'b c h w -> b (h w) c')
return x[:, self.window]
class SwinTransformer(nn.Module):
def __init__(self, in_channels=3, num_classes=1000, num_layers=12, embed_dim=96, window_sizes=(7, 3, 3, 3), shift_sizes=(0, 1, 2, 3)):
super(SwinTransformer, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.num_layers = num_layers
self.embed_dim = embed_dim
self.window_sizes = window_sizes
self.shift_sizes = shift_sizes
self.conv1 = nn.Conv2d(in_channels, embed_dim, kernel_size=4, stride=4, padding=0)
self.norm1 = nn.BatchNorm2d(embed_dim)
self.blocks = nn.ModuleList()
for i in range(num_layers):
self.blocks.append(SwinBlock(embed_dim * 2**i, embed_dim * 2**(i+1), window_size=window_sizes[i%4], shift_size=shift_sizes[i%4]))
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(embed_dim * 2**num_layers, num_classes)
# add relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * (2 * window_sizes[-1] - 1), embed_dim // 8, embed_dim // 8)),
requires_grad=True)
nn.init.kaiming_uniform_(self.relative_position_bias_table, a=1)
# add relative position encoding
self.pos_embed = nn.Parameter(
torch.zeros(1, embed_dim * 2**num_layers, 7, 7),
requires_grad=True)
nn.init.kaiming_uniform_(self.pos_embed, a=1)
def forward(self, x):
out = self.conv1(x)
out = self.norm1(out)
out = nn.functional.relu(out)
for block in self.blocks:
out = block(out)
out = self.avgpool(out)
out = Rearrange(out, 'b c h w -> b (c h w)')
out = self.fc(out)
return out
def get_relative_position_bias(self, H, W):
# H, W: height and width of feature maps in the last block
# output: (2HW-1, 8, 8)
relative_position_bias_h = self.relative_position_bias_table[:,
:(2 * H - 1), :(2 * W - 1)].transpose(0, 1)
relative_position_bias_w = self.relative_position_bias_table[:,
(2 * H - 1):, (2 * W - 1):].transpose(0, 1)
relative_position_bias = torch.cat([relative_position_bias_h, relative_position_bias_w], dim=0)
return relative_position_bias
def get_relative_position_encoding(self, H, W):
# H, W: height and width of feature maps in the last block
# output: (1, HW, C)
pos_x, pos_y = torch.meshgrid(torch.arange(H), torch.arange(W))
pos_x, pos_y = pos_x.float(), pos_y.float()
pos_x = pos_x / (H-1) * 2 - 1
pos_y = pos_y / (W-1) * 2 - 1
pos_encoding = torch.stack((pos_y, pos_x), dim=-1)
pos_encoding = pos_encoding.reshape(1, -1, 2)
pos_encoding = pos_encoding.repeat(1, 1, embed_dim // 2)
pos_encoding = pos_encoding.transpose(1, 2)
return pos_encoding
以上就是Swin Transformer圖像處理深度學(xué)習(xí)模型的詳細(xì)內(nèi)容,更多關(guān)于Swin Transformer深度學(xué)習(xí)的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
對Python通過pypyodbc訪問Access數(shù)據(jù)庫的方法詳解
今天小編就為大家分享一篇對Python通過pypyodbc訪問Access數(shù)據(jù)庫的方法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10
PyCharm 創(chuàng)建指定版本的 Django(超詳圖解教程)
這篇文章主要介紹了PyCharm 創(chuàng)建指定版本的 Django,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-06-06
python 時間的訪問和轉(zhuǎn)換 time示例小結(jié)
Python 的 time 模塊提供了各種與時間處理相關(guān)的功能,包括獲取當(dāng)前時間、操作日期/時間以及執(zhí)行與時間相關(guān)的各種其它功能,這篇文章主要介紹了python 時間的訪問和轉(zhuǎn)換 time,需要的朋友可以參考下2024-05-05
利用 Python 實現(xiàn)隨機相對強弱指數(shù) StochRSI
隨機相對強弱指數(shù)簡稱為StochRSI,是一種技術(shù)分析指標(biāo),用于確定資產(chǎn)是否處于超買或超賣狀態(tài),也用于確定當(dāng)前市場的態(tài)勢。本篇文章小編九來為大家介紹隨機相對強弱指數(shù)簡稱為StochRSI,需要的朋友可以參考下面文章的具體內(nèi)容2021-09-09

