Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼
本文介紹一下 Pytorch 中常用乘法的 TensorRT 實(shí)現(xiàn)。
pytorch 用于訓(xùn)練,TensorRT 用于推理是很多 AI 應(yīng)用開發(fā)的標(biāo)配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,這里拿比較常用的乘法運(yùn)算在兩種框架下的實(shí)現(xiàn)做一個(gè)對(duì)比,可能會(huì)有更加直觀一些的認(rèn)識(shí)。
1.乘法運(yùn)算總覽
先把 pytorch 中的一些常用的乘法運(yùn)算進(jìn)行一個(gè)總覽:
- torch.mm:用于兩個(gè)矩陣 (不包括向量) 的乘法,如維度 (m, n) 的矩陣乘以維度 (n, p) 的矩陣;
- torch.bmm:用于帶 batch 的三維向量的乘法,如維度 (b, m, n) 的矩陣乘以維度 (b, n, p) 的矩陣;
- torch.mul:用于同維度矩陣的逐像素點(diǎn)相乘,也即點(diǎn)乘,如維度 (m, n) 的矩陣點(diǎn)乘維度 (m, n) 的矩陣。該方法支持廣播,也即支持矩陣和元素點(diǎn)乘;
- torch.mv:用于矩陣和向量的乘法,矩陣在前,向量在后,如維度 (m, n) 的矩陣乘以維度為 (n) 的向量,輸出維度為 (m);
- torch.matmul:用于兩個(gè)張量相乘,或矩陣與向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
- @:作用相當(dāng)于 torch.matmul;
- *:作用相當(dāng)于 torch.mul;
如上進(jìn)行了一些具體羅列,可以歸納出,常用的乘法無(wú)非兩種:矩陣乘 和 點(diǎn)乘,所以下面分這兩類進(jìn)行介紹。
2.乘法算子實(shí)現(xiàn)
2.1矩陣乘算子實(shí)現(xiàn)
先來(lái)看看矩陣乘法的 pytorch 的實(shí)現(xiàn) (以下實(shí)現(xiàn)在終端):
>>> import torch >>> # torch.mm >>> a = torch.randn(66, 99) >>> b = torch.randn(99, 88) >>> c = torch.mm(a, b) >>> c.shape torch.size([66, 88]) >>> >>> # torch.bmm >>> a = torch.randn(3, 66, 99) >>> b = torch.randn(3, 99, 77) >>> c = torch.bmm(a, b) >>> c.shape torch.size([3, 66, 77]) >>> >>> # torch.mv >>> a = torch.randn(66, 99) >>> b = torch.randn(99) >>> c = torch.mv(a, b) >>> c.shape torch.size([66]) >>> >>> # torch.matmul >>> a = torch.randn(32, 3, 66, 99) >>> b = torch.randn(32, 3, 99, 55) >>> c = torch.matmul(a, b) >>> c.shape torch.size([32, 3, 66, 55]) >>> >>> # @ >>> d = a @ b >>> d.shape torch.size([32, 3, 66, 55])
來(lái)看 TensorRT 的實(shí)現(xiàn),以上乘法都可使用 addMatrixMultiply 方法覆蓋,對(duì)應(yīng) torch.matmul,先來(lái)看該方法的定義:
//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}
可以看到這個(gè)方法有四個(gè)傳參,對(duì)應(yīng)兩個(gè)張量和其 operation。來(lái)看這個(gè)算子在 TensorRT 中怎么添加:
// 構(gòu)造張量 Tensor0 nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0); // 構(gòu)造張量 Tensor1 nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1); // 添加矩陣乘法 nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type); // 獲取輸出 matmulOutput = Matmul_layer->getOputput(0);
2.2點(diǎn)乘算子實(shí)現(xiàn)
再來(lái)看看點(diǎn)乘的 pytorch 的實(shí)現(xiàn) (以下實(shí)現(xiàn)在終端):
>>> import torch >>> # torch.mul >>> a = torch.randn(66, 99) >>> b = torch.randn(66, 99) >>> c = torch.mul(a, b) >>> c.shape torch.size([66, 99]) >>> d = 0.125 >>> e = torch.mul(a, d) >>> e.shape torch.size([66, 99]) >>> # * >>> f = a * b >>> f.shape torch.size([66, 99])
來(lái)看 TensorRT 的實(shí)現(xiàn),以上乘法都可使用 addScale 方法覆蓋,這在圖像預(yù)處理中十分常用,先來(lái)看該方法的定義:
//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//! This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//! and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
return mImpl->addScale(input, mode, shift, scale, power);
}
可以看到有三個(gè)模式:
- kUNIFORM:weights 為一個(gè)值,對(duì)應(yīng)張量乘一個(gè)元素;
- kCHANNEL:weights 維度和輸入張量通道的 c 維度對(duì)應(yīng),可以做一些以通道為基準(zhǔn)的預(yù)處理;
- kELEMENTWISE:weights 維度和輸入張量的 c、h、w 對(duì)應(yīng),不考慮 batch,所以是輸入的后三維;
再來(lái)看這個(gè)算子在 TensorRT 中怎么添加:
// 構(gòu)造張量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);
// scalemode選擇,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;
// 構(gòu)建 Weights 類型的 shift、scale、power,其中 volume 為元素?cái)?shù)量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };
// !! 注意這里還需要對(duì) shift、scale、power 的 values 進(jìn)行賦值,若只是乘法只需要對(duì) scale 進(jìn)行賦值就行
// 添加張量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);
// 獲取輸出
scaleOutput = Scale_layer->getOputput(0);
有一點(diǎn)你可能會(huì)比較疑惑,既然是點(diǎn)乘,那么輸入只需要兩個(gè)張量就可以了,為啥這里有 input、shift、scale、power 四個(gè)張量這么多呢。解釋一下,input 不用說(shuō),就是輸入張量,而 shift 表示加法參數(shù)、scale 表示乘法參數(shù)、power 表示指數(shù)參數(shù),說(shuō)到這里,你應(yīng)該能發(fā)現(xiàn),這個(gè)函數(shù)除了我們上面講的點(diǎn)乘外還有其他更加豐富的運(yùn)算功能。
到此這篇關(guān)于Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼的文章就介紹到這了,更多相關(guān)Pytorch乘法算子TensorRT內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中operator模塊的操作符使用示例總結(jié)
operator模塊中包含了Python的各種內(nèi)置操作符,諸如邏輯、比較、計(jì)算等,這里我們針對(duì)一些常用的操作符來(lái)作一個(gè)Python中operator模塊的操作符使用示例總結(jié):2016-06-06
python爬取之json、pickle與shelve庫(kù)的深入講解
這篇文章主要給大家介紹了關(guān)于python爬取之json、pickle與shelve庫(kù)的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-03-03
對(duì)pandas中iloc,loc取數(shù)據(jù)差別及按條件取值的方法詳解
今天小編就為大家分享一篇對(duì)pandas中iloc,loc取數(shù)據(jù)差別及按條件取值的方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-11-11
Python with語(yǔ)句上下文管理器兩種實(shí)現(xiàn)方法分析
這篇文章主要介紹了Python with語(yǔ)句上下文管理器兩種實(shí)現(xiàn)方法,結(jié)合實(shí)例形式較為詳細(xì)的分析了Python上下文管理器的相關(guān)概念、功能、使用方法及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下2018-02-02
利用Python進(jìn)行網(wǎng)絡(luò)爬蟲和數(shù)據(jù)抓取的代碼示例
在當(dāng)今數(shù)字化時(shí)代,數(shù)據(jù)是無(wú)處不在的,從市場(chǎng)趨勢(shì)到個(gè)人偏好,從社交媒體活動(dòng)到商業(yè)智能,數(shù)據(jù)扮演著關(guān)鍵的角色,Python提供了一套強(qiáng)大而靈活的工具,使得網(wǎng)絡(luò)爬蟲和數(shù)據(jù)抓取成為可能,本文將深入探討如何利用Python進(jìn)行網(wǎng)絡(luò)爬蟲和數(shù)據(jù)抓取,為您打開數(shù)據(jù)世界的大門2024-05-05
Python中FTP服務(wù)與SSH登錄暴力破解的實(shí)現(xiàn)
本文學(xué)習(xí)了如何通過(guò) Python 腳本進(jìn)行 FTP、SSH 服務(wù)的登錄爆破,文中通過(guò)示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-08-08
深入理解python中函數(shù)傳遞參數(shù)是值傳遞還是引用傳遞
這篇文章主要介紹了深入理解python中函數(shù)傳遞參數(shù)是值傳遞還是引用傳遞,涉及具體代碼示例,具有一定參考價(jià)值,需要的朋友可以了解下。2017-11-11
pandas 使用apply同時(shí)處理兩列數(shù)據(jù)的方法
下面小編就為大家分享一篇pandas 使用apply同時(shí)處理兩列數(shù)據(jù)的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-04-04

