Pytorch從0實現(xiàn)Transformer的實踐
摘要
With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.
隨著時序預測的不斷發(fā)展,Transformer類模型憑借強大的優(yōu)勢,在CV、NLP領(lǐng)域逐漸取代傳統(tǒng)模型。其中Informer在長時序預測上遠超傳統(tǒng)的RNN模型,Swin Transformer在圖像識別上明顯強于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領(lǐng)域的必然要求。本文將用Pytorch框架,實現(xiàn)Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網(wǎng)絡(luò)。
一、構(gòu)造數(shù)據(jù)
1.1 句子長度
# 關(guān)于word embedding,以序列建模為例 # 輸入句子有兩個,第一個長度為2,第二個長度為4 src_len = torch.tensor([2, 4]).to(torch.int32) # 目標句子有兩個。第一個長度為4, 第二個長度為3 tgt_len = torch.tensor([4, 3]).to(torch.int32) print(src_len) print(tgt_len)
輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標句子(tgt_len)有兩個。第一個長度為4, 第二個長度為3

1.2 生成句子
用隨機數(shù)生成句子,用0填充空白位置,保持所有句子長度一致
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) print(src_seq) print(tgt_seq)
src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數(shù)字?在做中英文翻譯時,每個中文或英文對應的也是一個數(shù)字,只有這樣才便于處理。

1.3 生成字典
在該字典中,總共有8個字(行),每個字對應8維向量(做了簡化了的)。注意在實際應用中,應當有幾十萬個字,每個字可能有512個維度。
# 構(gòu)造word embedding src_embedding_table = nn.Embedding(9, model_dim) tgt_embedding_table = nn.Embedding(9, model_dim) # 輸入單詞的字典 print(src_embedding_table) # 目標單詞的字典 print(tgt_embedding_table)
字典中,需要留一個維度給class token,故是9行。

1.4 得到向量化的句子
通過字典取出1.2中得到的句子
# 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)

該階段總程序
import torch # 句子長度 src_len = torch.tensor([2, 4]).to(torch.int32) tgt_len = torch.tensor([4, 3]).to(torch.int32) # 構(gòu)造句子,用0填充空白處 src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) # 構(gòu)造字典 src_embedding_table = nn.Embedding(9, 8) tgt_embedding_table = nn.Embedding(9, 8) # 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統(tǒng)RNN的時序信息,增強了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)

2.1 計算括號內(nèi)的值
# 得到分子pos的值 pos_mat = torch.arange(4).reshape((-1, 1)) # 得到分母值 i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8) print(pos_mat) print(i_mat)

2.2 得到位置編碼
# 初始化位置編碼矩陣 pe_embedding_table = torch.zeros(4, 8) # 得到偶數(shù)行位置編碼 pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat) # 得到奇數(shù)行位置編碼 pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat) pe_embedding = nn.Embedding(4, 8) # 設(shè)置位置編碼不可更新參數(shù) pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False) print(pe_embedding.weight)

三、多頭注意力
3.1 self mask
有些位置是空白用0填充的,訓練時不希望被這些位置所影響,那么就需要用到self mask。self mask的原理是令這些位置的值為無窮小,經(jīng)過softmax后,這些值會變?yōu)?,不會再影響結(jié)果。
3.1.1 得到有效位置矩陣
# 得到有效位置矩陣 vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2) valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2)) print(valid_encoder_pos_matrix)

3.1.2 得到無效位置矩陣
invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool) print(mask_encoder_self_attention)
True代表需要對該位置mask

3.1.3 得到mask矩陣
用極小數(shù)填充需要被mask的位置
# 初始化mask矩陣 score = torch.randn(2, max(src_len), max(src_len)) # 用極小數(shù)填充 mask_score = score.masked_fill(mask_encoder_self_attention, -1e9) print(mask_score)

算其softmat
mask_score_softmax = F.softmax(mask_score) print(mask_score_softmax)
可以看到,已經(jīng)達到預期效果

到此這篇關(guān)于Pytorch從0實現(xiàn)Transformer的實踐的文章就介紹到這了,更多相關(guān)Pytorch Transformer內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
分布式爬蟲scrapy-redis的實戰(zhàn)踩坑記錄
最近用scrapy-redis嘗試了分布式爬蟲,使用過程中也遇到了不少問題,下面這篇文章主要給大家介紹了關(guān)于分布式爬蟲scrapy-redis的實戰(zhàn)踩坑記錄,文中通過實例代碼介紹的非常詳細,需要的朋友可以參考下2022-08-08
使用python編寫批量卸載手機中安裝的android應用腳本
該腳本的功能是卸載android手機中安裝的所有第三方應用,主要是使用adb shell pm、adb uninstall 命令,需要的朋友可以參考下2014-07-07

