深入理解Pytorch中的torch. matmul()
torch.matmul()
語法
torch.matmul(input, other, *, out=None) → Tensor
作用
兩個(gè)張量的矩陣乘積
行為取決于張量的維度,如下所示:
- 如果兩個(gè)張量都是一維的,則返回點(diǎn)積(標(biāo)量)。
- 如果兩個(gè)參數(shù)都是二維的,則返回矩陣-矩陣乘積。
- 如果第一個(gè)參數(shù)是一維的,第二個(gè)參數(shù)是二維的,為了矩陣乘法的目的,在它的維數(shù)前面加上一個(gè) 1。在矩陣相乘之后,前置維度被移除。
- 如果第一個(gè)參數(shù)是二維的,第二個(gè)參數(shù)是一維的,則返回矩陣向量積。
- 如果兩個(gè)參數(shù)至少為一維且至少一個(gè)參數(shù)為 N 維(其中 N > 2),則返回批處理矩陣乘法
- 如果第一個(gè)參數(shù)是一維的,則將 1 添加到其維度,以便批量矩陣相乘并在之后刪除。如果第二個(gè)參數(shù)是一維的,則將 1 附加到其維度以用于批量矩陣倍數(shù)并在之后刪除
- 非矩陣(即批次)維度是廣播的(因此必須是可廣播的)
- 例如,如果輸入是( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) 張量
- 另一個(gè)是 ( k × n × n ) (k \times n \times n)(k×n×n)張量,
- out 將是一個(gè) ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) 張量
請(qǐng)注意,廣播邏輯在確定輸入是否可廣播時(shí)僅查看批處理維度,而不是矩陣維度
例如
- 如果輸入是 ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) 張量
- 另一個(gè)是 ( k × m × p ) (k \times m \times p)(k×m×p) 張量
- 即使最后兩個(gè)維度(即矩陣維度)不同,這些輸入對(duì)于廣播也是有效的
- out 將是一個(gè) ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) 張量
該運(yùn)算符支持 TensorFloat32。
在某些 ROCm 設(shè)備上,當(dāng)使用 float16 輸入時(shí),此模塊將使用不同的向后精度

舉例
情形1: 一維 * 一維
如果兩個(gè)張量都是一維的,則返回點(diǎn)積(標(biāo)量)
tensor1 = torch.Tensor([1,2,3])
tensor2 =torch.Tensor([4,5,6])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
ans = 1 * 4 + 2 * 5 + 3 * 6 = 32
情形2: 二維 * 二維
如果兩個(gè)參數(shù)都是二維的,則返回矩陣-矩陣乘積
也就是 正常的矩陣乘法 (m * n) * (n * k) = (m * k)
tensor1 = torch.Tensor([[1,2,3],[1,2,3]])
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
情形3: 一維 * 二維
如果第一個(gè)參數(shù)是一維的,第二個(gè)參數(shù)是二維的,為了矩陣乘法的目的,在它的維數(shù)前面加上一個(gè) 1
在矩陣相乘之后,前置維度被移除
tensor1 = torch.Tensor([1,2,3]) # 注意這里是一維
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
tensor1 = torch.Tensor([1,2,3]) 修改為 tensor1 = torch.Tensor([[1,2,3]])

發(fā)現(xiàn)一個(gè)結(jié)果是[24., 30.] 一個(gè)是[[24., 30.]]
所以,當(dāng)一維 * 二維時(shí), 開始變成 1 * m(一維的維度),也就是一個(gè)二維, 再進(jìn)行正常的矩陣運(yùn)算,得到[[24., 30.]], 然后再去掉開始增加的一個(gè)維度,得到[24., 30.]
想象為二維 * 二維(前置維度為1),最后結(jié)果去掉一個(gè)維度即可
情形4: 二維 * 一維
如果第一個(gè)參數(shù)是二維的,第二個(gè)參數(shù)是一維的,則返回矩陣向量積
tensor1 =torch.Tensor([[4,5,6],[7,8,9]])
tensor2 = torch.Tensor([1,2,3])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
理解為:
- 把第一個(gè)二維中,想象為多個(gè)行向量
- 第二個(gè)一維想象為一個(gè)列向量
- 行向量與列向量進(jìn)行矩陣乘法,得到一個(gè)標(biāo)量
- 再按照行堆疊起來即可

情形5:兩個(gè)參數(shù)至少為一維且至少一個(gè)參數(shù)為 N 維(其中 N > 2),則返回批處理矩陣乘法
第一個(gè)參數(shù)為N維,第二個(gè)參數(shù)為一維時(shí)
tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4) print(torch.matmul(tensor1, tensor2).size())

(4) 先添加一個(gè)維度 (4 * 1)
得到(10 * 3 * 4) *( 4 * 1) = (10 * 3 * 1)
再刪除最后一個(gè)維度(添加的那個(gè))
得到結(jié)果(10 * 3)
tensor1 = torch.randn(10,2, 3, 4) # tensor2 = torch.randn(4) print(torch.matmul(tensor1, tensor2).size())

(10 * 2 * 3 * 4) * (4 * 1) = (10 * 2 * 3) 【抵消4,刪1】
第一個(gè)參數(shù)為一維,第二個(gè)參數(shù)為二維時(shí)
tensor1 = torch.randn(4) tensor2 = torch.randn(10, 4, 3) print(torch.matmul(tensor1, tensor2).size())

tensor2 中第一個(gè)10理解為批次, 10個(gè)(4 * 3)
(1 * 4)與每個(gè)(4 * 3) 相乘得到(1,3),去除1,得到(3)
批次為10,得到(10,3)
tensor1 = torch.randn(4) tensor2 = torch.randn(10,2, 4, 3) print(torch.matmul(tensor1, tensor2).size())

這里批次理解為[10, 2]即可
tensor1 = torch.randn(4) tensor2 = torch.randn(10,4, 2,4,1) print(torch.matmul(tensor1, tensor2).size())

個(gè)人理解:當(dāng)一個(gè)參數(shù)為一維時(shí),它要去匹配另一個(gè)參數(shù)的最后兩個(gè)維度(二維 * 二維)
比如上面的例子就是(1 * 4) 匹配 (4,1), 批次為(10,4,2)
高維 * 高維時(shí)


注:這不太好理解 … 感覺就是要找準(zhǔn)批次,再進(jìn)行乘法(靠感覺了 哈哈 離譜)
參考 https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
到此這篇關(guān)于深入理解Pytorch中的torch. matmul()的文章就介紹到這了,更多相關(guān)Pytorch torch. matmul()內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python構(gòu)建指數(shù)平滑預(yù)測(cè)模型示例
今天小編就為大家分享一篇python構(gòu)建指數(shù)平滑預(yù)測(cè)模型示例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-11-11
基于Python __dict__與dir()的區(qū)別詳解
下面小編就為大家?guī)硪黄赑ython __dict__與dir()的區(qū)別詳解。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-10-10
Python實(shí)現(xiàn)根據(jù)指定端口探測(cè)服務(wù)器/模塊部署的方法
這篇文章主要介紹了Python根據(jù)指定端口探測(cè)服務(wù)器/模塊部署的方法,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2014-08-08
Python實(shí)現(xiàn)堡壘機(jī)模式下遠(yuǎn)程命令執(zhí)行操作示例
這篇文章主要介紹了Python實(shí)現(xiàn)堡壘機(jī)模式下遠(yuǎn)程命令執(zhí)行操作,結(jié)合實(shí)例形式分析了Python堡壘機(jī)模式執(zhí)行遠(yuǎn)程命令的原理與相關(guān)操作技巧,需要的朋友可以參考下2019-05-05
關(guān)于pytorch中網(wǎng)絡(luò)loss傳播和參數(shù)更新的理解
今天小編就為大家分享一篇關(guān)于pytorch中網(wǎng)絡(luò)loss傳播和參數(shù)更新的理解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-08-08
Python實(shí)現(xiàn)去除列表中重復(fù)元素的方法總結(jié)【7種方法】
今天小編就為大家分享一篇關(guān)于Python實(shí)現(xiàn)去除列表中重復(fù)元素的方法總結(jié)【7種方法】,小編覺得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧2019-02-02

