pytorch中使用cuda擴(kuò)展的實(shí)現(xiàn)示例
以下面這個(gè)例子作為教程,實(shí)現(xiàn)功能是element-wise add;
(pytorch中想調(diào)用cuda模塊,還是用另外使用C編寫接口腳本)
第一步:cuda編程的源文件和頭文件
// mathutil_cuda_kernel.cu
// 頭文件,最后一個(gè)是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"
// 獲取GPU線程通道信息
dim3 cuda_gridsize(int n)
{
int k = (n - 1) / BLOCK + 1;
int x = k;
int y = 1;
if(x > 65535) {
x = ceil(sqrt(k));
y = (n - 1) / (x * BLOCK) + 1;
}
dim3 d(x, y, 1);
return d;
}
// 這個(gè)函數(shù)是cuda執(zhí)行函數(shù),可以看到細(xì)化到了每一個(gè)元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
if(i >= size) return;
int j = i % x; i = i / x;
int k = i % y;
a[IDX2D(j, k, y)] += b[k];
}
// 這個(gè)函數(shù)是與c語(yǔ)言函數(shù)鏈接的接口函數(shù)
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
int size = x * y;
cudaError_t err;
// 上面定義的函數(shù)
broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);
err = cudaGetLastError();
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL
#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))
#define BLOCK 512
#define MAX_STREAMS 512
#ifdef __cplusplus
extern "C" {
#endif
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
第二步:C編程的源文件和頭文件(接口函數(shù))
// mathutil_cuda.c
// THC是pytorch底層GPU庫(kù)
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"
extern THCState *state;
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
float *a = THCudaTensor_data(state, a_tensor);
float *b = THCudaTensor_data(state, b_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
// 這里調(diào)用之前在cuda中編寫的接口函數(shù)
broadcast_sum_cuda(a, b, x, y, stream);
return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);
第三步:編譯,先編譯cuda模塊,再編譯接口函數(shù)模塊(不能放在一起同時(shí)編譯)
nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension
this_file = os.path.dirname(__file__)
sources = []
headers = []
defines = []
with_cuda = False
if torch.cuda.is_available():
print('Including CUDA code.')
sources += ['src/mathutil_cuda.c']
headers += ['src/mathutil_cuda.h']
defines += [('WITH_CUDA', None)]
with_cuda = True
this_file = os.path.dirname(os.path.realpath(__file__))
extra_objects = ['src/mathutil_cuda_kernel.cu.o'] # 這里是編譯好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
ffi = create_extension(
'_ext.cuda_util',
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
extra_objects=extra_objects
)
if __name__ == '__main__':
ffi.build()
第四步:調(diào)用cuda模塊
from _ext import cuda_util #從對(duì)應(yīng)路徑中調(diào)用編譯好的模塊 a = torch.randn(3, 5).cuda() b = torch.randn(3, 1).cuda() mathutil.broadcast_sum(a, b, *map(int, a.size())) # 上面等價(jià)于下面的效果: a = torch.randn(3, 5) b = torch.randn(3, 1) a += b
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
windows上安裝Anaconda和python的教程詳解
本文主要給大家介紹windows上安裝Anaconda和python的教程詳解,非常不錯(cuò),具有參考借鑒價(jià)值,需要的朋友參考下2017-03-03
Python批量實(shí)現(xiàn)Word/EXCEL/PPT轉(zhuǎn)PDF
在日常辦公和文檔處理中,有時(shí)我們需要將多個(gè)Word文檔、Excel表格或PPT演示文稿轉(zhuǎn)換為PDF文件,本文將介紹如何使用Python編程語(yǔ)言批量實(shí)現(xiàn)將多個(gè)Word、Excel和PPT文件轉(zhuǎn)換為PDF文件,需要的可以參考下2023-09-09
python在前端頁(yè)面使用?MySQLdb?連接數(shù)據(jù)
這篇文章主要介紹了MySQLdb?連接數(shù)據(jù)的使用,文章主要介紹的相關(guān)內(nèi)容又插入數(shù)據(jù),刪除數(shù)據(jù),更新數(shù)據(jù),搜索數(shù)據(jù),需要的小伙伴可以參考一下2022-03-03
Python利用Flask動(dòng)態(tài)生成漢字頭像
這篇文章主要為大家詳細(xì)介紹了Python如何利用Flask動(dòng)態(tài)生成一個(gè)漢字頭像,文中的示例代碼講解詳細(xì),對(duì)我們學(xué)習(xí)Python有一定的幫助,需要的可以參考一下2023-01-01
詳解如何用OpenCV + Python 實(shí)現(xiàn)人臉識(shí)別
這篇文章主要介紹了詳解如何用OpenCV + Python 實(shí)現(xiàn)人臉識(shí)別,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2017-10-10
Python爬蟲常用庫(kù)的安裝及其環(huán)境配置
今天小編就為大家分享一篇關(guān)于python爬蟲常用庫(kù)的安裝及其環(huán)境配置的文章,小編覺(jué)得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來(lái)看看吧2018-09-09
Python開(kāi)發(fā)之os與os.path的使用小結(jié)
這篇文章主要介紹了Python開(kāi)發(fā)之os與os.path的使用小結(jié),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),感興趣的朋友一起看看吧2024-05-05
Python中PyExecJS(執(zhí)行JS代碼庫(kù))的具體使用
pyexecjs是一個(gè)用Python來(lái)執(zhí)行JavaScript代碼的工具庫(kù),本文主要介紹了Python中PyExecJS(執(zhí)行JS代碼庫(kù))的具體使用,具有一定的參考價(jià)值,感興趣的可以了解一下2024-02-02

