對pytorch中的梯度更新方法詳解
背景
使用pytorch時,有一個yolov3的bug,我認為涉及到學習率的調整。收集到tencent yolov3和mxnet開源的yolov3,兩個優(yōu)化器中的學習率設置不一樣,而且使用GPU數(shù)目和batch的更新也不太一樣。據(jù)此,我簡單的了解了下pytorch的權重梯度的更新策略,看看能否一窺究竟。
對代碼說明
共三個實驗,分布寫在代碼中的(一)(二)(三)三個地方。運行實驗時注釋掉其他兩個
實驗及其結果
實驗(三):
不使用zero_grad()時,grad累加在一起,官網(wǎng)是使用accumulate 來表述的,所以不太清楚是取的和還是均值(這兩種最有可能)。
不使用zero_grad()時,是直接疊加add的方式累加的。
tensor([[[ 1., 1.],……torch.Size([2, 2, 2]) 0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * tensor([[[ 2., 2.],…… torch.Size([2, 2, 2]) 1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * tensor([[[ 3., 3.],…… torch.Size([2, 2, 2]) 2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
實驗(二):
單卡上不同的batchsize對梯度是怎么作用的。 mini-batch SGD中的batch是加快訓練,同時保持一定的噪聲。但設置不同的batchsize的權重的梯度是怎么計算的呢。
設置運行實驗(二),可以看到結果如下:所以單卡batchsize計算梯度是取均值的
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
實驗(一):
多gpu情況下,梯度怎么合并在一起的。
在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是當設置g=2,實驗一運行時,結果也是取均值的,類同于實驗(二)
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
實驗代碼
import torch
import torch.nn as nn
from torch.autograd import Variable
class model(nn.Module):
def __init__(self, w):
super(model, self).__init__()
self.w = w
def forward(self, xx):
b, c, _, _ = xx.shape
# extra = xx.device.index + 1 ## 實驗(一)
y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra)
return y.reshape(len(xx), -1)
g = 1
x = Variable(torch.ones(2, 1, 2, 2))
# x[1] += 1 ## 實驗(二)
w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True)
# optim = torch.optim.SGD({'params': x},
lr = 0.01
momentum = 0.9
M = model(w)
M = torch.nn.DataParallel(M, device_ids=range(g))
for i in range(3):
b = len(x)
z = M(x)
zz = z.sum(1)
l = (zz - Variable(torch.ones(b).cuda())).mean()
# zz.backward(Variable(torch.ones(b).cuda()))
l.backward()
print(w.grad, w.grad.shape)
# w.grad.zero_() ## 實驗(三)
print(i, b, '* * ' * 20)
以上這篇對pytorch中的梯度更新方法詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
flask的orm框架SQLAlchemy查詢實現(xiàn)解析
這篇文章主要介紹了flask的orm框架SQLAlchemy查詢實現(xiàn)解析,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2019-12-12
python飛機大戰(zhàn) pygame游戲創(chuàng)建快速入門詳解
這篇文章主要介紹了python飛機大戰(zhàn) pygame游戲創(chuàng)建,結合實例形式詳細分析了Python使用pygame創(chuàng)建飛機大戰(zhàn)游戲的具體步驟與相關操作注意事項,需要的朋友可以參考下2019-12-12
Python中數(shù)字(Number)數(shù)據(jù)類型常用操作
本文主要介紹了Python中數(shù)字(Number)數(shù)據(jù)類型常用操作,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-02-02
python+pytest接口自動化參數(shù)關聯(lián)
這篇文章主要介紹了python+pytest接口自動化參數(shù)關聯(lián),參數(shù)關聯(lián),也叫接口關聯(lián),即接口之間存在參數(shù)的聯(lián)系或依賴,更多相關內容需要的小伙伴可可以參考一下2022-06-06

