將pytorch的網(wǎng)絡(luò)等轉(zhuǎn)移到cuda
神經(jīng)網(wǎng)絡(luò)一般用GPU來跑,我們的神經(jīng)網(wǎng)絡(luò)框架一般也都安裝的GPU版本,本文就簡單記錄一下GPU使用的編寫。
GPU的設(shè)置不在model,而是在Train的初始化上。
第一步是查看是否可以使用GPU
self.GPU_IN_USE = torch.cuda.is_available()
就是返回這個可不可以用GPU的函數(shù),當(dāng)你的pytorch是cpu版本的時候,他就會返回False。
然后是:
self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
torch.device是代表將torch.tensor分配到哪個設(shè)備的函數(shù)
接著是,我看到了一篇文章,原來就是將網(wǎng)絡(luò)啊、數(shù)據(jù)啊、隨機(jī)種子啊、損失函數(shù)啊、等等等等直接轉(zhuǎn)移到CUDA上就好了!
于是下面就好理解多了:
轉(zhuǎn)移模型:
self.model = Net(num_channels=1, upscale_factor=self.upscale_factor, base_channel=64, num_residuals=4).to(self.device)
設(shè)置cuda的隨機(jī)種子:
torch.cuda.manual_seed(self.seed)
轉(zhuǎn)移損失函數(shù):
self.criterion.cuda()
轉(zhuǎn)移數(shù)據(jù):
data, target = data.to(self.device), target.to(self.device)
pytorch 網(wǎng)絡(luò)定義參數(shù)的后面無法加.cuda()
pytorch定義網(wǎng)絡(luò)__init__()的時候,參數(shù)不能加“cuda()", 不然參數(shù)不包含在state_dict()中,比如下面這種寫法是錯誤的
self.W1 = nn.Parameter(torch.FloatTensor(3,3), requires_grad=True).cuda()
應(yīng)該去掉".cuda()"
self.W1 = nn.Parameter(torch.FloatTensor(3,3), requires_grad=True)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
- python?windows安裝cuda+cudnn+pytorch教程
- 顯卡驅(qū)動CUDA?和?pytorch?CUDA?之間的區(qū)別
- pytorch?cuda安裝報錯的解決方法
- PyTorch中的CUDA的操作方法
- PyTorch?device與cuda.device用法介紹
- pytorch 如何用cuda處理數(shù)據(jù)
- pytorch中.to(device) 和.cuda()的區(qū)別說明
- PyTorch CUDA環(huán)境配置及安裝的步驟(圖文教程)
- Linux安裝Pytorch1.8GPU(CUDA11.1)的實現(xiàn)
- 詳解win10下pytorch-gpu安裝以及CUDA詳細(xì)安裝過程
- Pytorch使用CUDA流(CUDA?stream)的實現(xiàn)
相關(guān)文章
Python實現(xiàn)將MySQL數(shù)據(jù)庫查詢結(jié)果導(dǎo)出到Excel
在實際工作中,我們經(jīng)常需要將數(shù)據(jù)庫中的數(shù)據(jù)導(dǎo)出到Excel表格中進(jìn)行進(jìn)一步的分析和處理,Python中的pymysql和xlsxwriter庫提供了很好的解決方案,下面我們就來看看具體操作方法吧2023-11-11
在樹莓派2或樹莓派B+上安裝Python和OpenCV的教程
這篇文章主要介紹了在樹莓派2或樹莓派B+上安裝Python和OpenCV的教程,主要基于GTK庫,并以Python2.7和OpenCV 2.4.X版本的安裝作為示例,需要的朋友可以參考下2015-03-03
Python使用xlrd模塊操作Excel數(shù)據(jù)導(dǎo)入的方法
這篇文章主要介紹了Python使用xlrd模塊操作Excel數(shù)據(jù)導(dǎo)入的方法,涉及Python操作xlrd模塊的技巧,需要的朋友可以參考下2015-05-05

