PyTorch-GPU加速實(shí)例
硬件:NVIDIA-GTX1080
軟件:Windows7、python3.6.5、pytorch-gpu-0.4.1
一、基礎(chǔ)知識
將數(shù)據(jù)和網(wǎng)絡(luò)都推到GPU,接上.cuda()
二、代碼展示
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
self.out = nn.Linear(32 * 7 * 7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
output = self.out(x)
return output
cnn = CNN()
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda() # Moves all model parameters and buffers to the GPU.
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
for step, (x, y) in enumerate(train_loader):
# !!!!!!!! Change in here !!!!!!!!! #
b_x = x.cuda() # Tensor on GPU
b_y = y.cuda() # Tensor on GPU
output = cnn(b_x)
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 50 == 0:
test_output = cnn(test_x)
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
test_output = cnn(test_x[:10])
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')
三、結(jié)果展示

補(bǔ)充知識:pytorch使用gpu對網(wǎng)絡(luò)計(jì)算進(jìn)行加速
1.基本要求
你的電腦里面有合適的GPU顯卡(NVIDA),并且需要支持CUDA模塊
你必須安裝GPU版的Torch,(詳細(xì)安裝方法請移步pytorch官網(wǎng))
2.使用GPU訓(xùn)練CNN
利用pytorch使用GPU進(jìn)行加速方法主要就是將數(shù)據(jù)的形式變成GPU能讀的形式,然后將CNN也變成GPU能讀的形式,具體辦法就是在后面加上.cuda()。
例如:
#如何檢查自己電腦是否支持cuda print torch.cuda.is_available() # 返回True代表支持,F(xiàn)alse代表不支持 ''' 注意在進(jìn)行某種運(yùn)算的時候使用.cuda() ''' test_data=test_data.test_labels[:2000].cuda() ''' 對于CNN與損失函數(shù)利用cuda加速 ''' class CNN(nn.Module): ... cnn=CNN() cnn.cuda() loss_f = t.nn.CrossEntropyLoss() loss_f = loss_f.cuda()
而在train時,對于train_data訓(xùn)練過程進(jìn)行GPU加速。也同樣+.cuda()。
for epoch ..: for step, ...: 1 ''' 若你的train_data在訓(xùn)練時需要進(jìn)行操作 若沒有其他操作僅僅只利用cnn()則無需另加.cuda() ''' #eg train_data = torch.max(teain_data, 1)[1].cuda()
補(bǔ)充:取出數(shù)據(jù)需要從GPU切換到CPU上進(jìn)行操作
eg:
loss = loss.cpu()
acc = acc.cpu()
理解并不全,如有紕漏或者錯誤還望各位大佬指點(diǎn)迷津
以上這篇PyTorch-GPU加速實(shí)例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python運(yùn)維自動化之paramiko模塊應(yīng)用實(shí)例
paramiko是一個基于SSH用于連接遠(yuǎn)程服務(wù)器并執(zhí)行相關(guān)操作,使用該模塊可以對遠(yuǎn)程服務(wù)器進(jìn)行命令或文件操作,這篇文章主要給大家介紹了關(guān)于Python運(yùn)維自動化之paramiko模塊應(yīng)用的相關(guān)資料,需要的朋友可以參考下2022-09-09
Python face_recognition實(shí)現(xiàn)AI識別圖片中的人物
最近碰到了照片識別的場景,正好使用了face_recognition項(xiàng)目,給大家分享分享。face_recognition項(xiàng)目能做的很多,人臉檢測功能也是有的,是一個比較成熟的項(xiàng)目。感興趣的可以了解一下2022-01-01
Kali Linux安裝ipython2 和 ipython3的方法
今天小編就為大家分享一篇Kali Linux安裝ipython2 和 ipython3的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07
Python進(jìn)行Restful?API開發(fā)實(shí)例詳解
這篇文章主要介紹了Python進(jìn)行Restful?API開發(fā)實(shí)例,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2022-03-03
使用Python設(shè)計(jì)一個代碼統(tǒng)計(jì)工具
這篇文章主要介紹了使用Python設(shè)計(jì)一個代碼統(tǒng)計(jì)工具的相關(guān)資料,包括文件個數(shù),代碼行數(shù),注釋行數(shù),空行行數(shù)。感興趣的朋友跟隨腳本之家小編一起看看吧2018-04-04
Python?xmltodict實(shí)現(xiàn)簡化XML數(shù)據(jù)處理
Python社區(qū)為提供了xmltodict庫,它專為簡化XML與Python數(shù)據(jù)結(jié)構(gòu)的轉(zhuǎn)換而設(shè)計(jì),本文主要來為大家介紹一下如何使用xmltodict實(shí)現(xiàn)簡化XML數(shù)據(jù)處理,希望對大家有所幫助2025-01-01

