pytorch如何定義新的自動(dòng)求導(dǎo)函數(shù)
pytorch定義新的自動(dòng)求導(dǎo)函數(shù)
在pytorch中想自定義求導(dǎo)函數(shù),通過實(shí)現(xiàn)torch.autograd.Function并重寫forward和backward函數(shù),來定義自己的自動(dòng)求導(dǎo)運(yùn)算。參考官網(wǎng)上的demo:傳送門
直接上代碼,定義一個(gè)ReLu來實(shí)現(xiàn)自動(dòng)求導(dǎo)
import torch
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# 我們使用ctx上下文對(duì)象來緩存,以便在反向傳播中使用,ctx存儲(chǔ)時(shí)候只能存tensor
# 在正向傳播中,我們接收一個(gè)上下文對(duì)象ctx和一個(gè)包含輸入的張量input;
# 我們必須返回一個(gè)包含輸出的張量,
# input.clamp(min = 0)表示講輸入中所有值范圍規(guī)定到0到正無窮,如input=[-1,-2,3]則被轉(zhuǎn)換成input=[0,0,3]
ctx.save_for_backward(input)
# 返回幾個(gè)值,backward接受參數(shù)則包含ctx和這幾個(gè)值
return input.clamp(min = 0)
@staticmethod
def backward(ctx, grad_output):
# 把ctx中存儲(chǔ)的input張量讀取出來
input, = ctx.saved_tensors
# grad_output存放反向傳播過程中的梯度
grad_input = grad_output.clone()
# 這兒就是ReLu的規(guī)則,表示原始數(shù)據(jù)小于0,則relu為0,因此對(duì)應(yīng)索引的梯度都置為0
grad_input[input < 0] = 0
return grad_input進(jìn)行輸入數(shù)據(jù)并測(cè)試
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 使用torch的generator定義隨機(jī)數(shù),注意產(chǎn)生的是cpu隨機(jī)數(shù)還是gpu隨機(jī)數(shù)
generator=torch.Generator(device).manual_seed(42)
# N是Batch, H is hidden dimension,
# D_in is input dimension;D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype,generator=generator)
y = torch.randn(N, D_out, device=device, dtype=dtype, generator=generator)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True, generator=generator)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True, generator=generator)
learning_rate = 1e-6
for t in range(500):
relu = MyRelu.apply
# 使用函數(shù)傳入?yún)?shù)運(yùn)算
y_pred = relu(x.mm(w1)).mm(w2)
# 計(jì)算損失
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# 傳播
loss.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w1.grad.zero_()
w2.grad.zero_()pytorch自動(dòng)求導(dǎo)與邏輯回歸
自動(dòng)求導(dǎo)

retain_graph設(shè)為True,可以進(jìn)行兩次反向傳播


邏輯回歸


import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)
#========生成數(shù)據(jù)=============
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
x0 = torch.normal(mean_value*n_data,1)+bias#類別0數(shù)據(jù)
y0 = torch.zeros(sample_nums)#類別0標(biāo)簽
x1 = torch.normal(-mean_value*n_data,1)+bias#類別1數(shù)據(jù)
y1 = torch.ones(sample_nums)#類別1標(biāo)簽
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)
#==========選擇模型===========
class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.features = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features(x)
x = self.sigmoid(x)
return x
lr_net = LR()#實(shí)例化邏輯回歸模型
#==============選擇損失函數(shù)===============
loss_fn = nn.BCELoss()
#==============選擇優(yōu)化器=================
lr = 0.01
optimizer = torch.optim.SGD(lr_net.parameters(),lr = lr,momentum=0.9)
#===============模型訓(xùn)練==================
for iteration in range(1000):
#前向傳播
y_pred = lr_net(train_x)#模型的輸出
#計(jì)算loss
loss = loss_fn(y_pred.squeeze(),train_y)
#反向傳播
loss.backward()
#更新參數(shù)
optimizer.step()
#繪圖
if iteration % 20 == 0:
mask = y_pred.ge(0.5).float().squeeze() #以0.5分類
correct = (mask==train_y).sum()#正確預(yù)測(cè)樣本數(shù)
acc = correct.item()/train_y.size(0)#分類準(zhǔn)確率
plt.scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c='r',label='class0')
plt.scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c='b',label='class1')
w0,w1 = lr_net.features.weight[0]
w0,w1 = float(w0.item()),float(w1.item())
plot_b = float(lr_net.features.bias[0].item())
plot_x = np.arange(-6,6,0.1)
plot_y = (-w0*plot_x-plot_b)/w1
plt.xlim(-5,7)
plt.ylim(-7,7)
plt.plot(plot_x,plot_y)
plt.text(-5,5,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.title('Iteration:{}\nw0:{:.2f} w1:{:.2f} b{:.2f} accuracy:{:2%}'.format(iteration,w0,w1,plot_b,acc))
plt.legend()
plt.show()
plt.pause(0.5)
if acc > 0.99:
break總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- 使用pytorch進(jìn)行張量計(jì)算、自動(dòng)求導(dǎo)和神經(jīng)網(wǎng)絡(luò)構(gòu)建功能
- 在?pytorch?中實(shí)現(xiàn)計(jì)算圖和自動(dòng)求導(dǎo)
- Pytorch自動(dòng)求導(dǎo)函數(shù)詳解流程以及與TensorFlow搭建網(wǎng)絡(luò)的對(duì)比
- 淺談Pytorch中的自動(dòng)求導(dǎo)函數(shù)backward()所需參數(shù)的含義
- pytorch中的自定義反向傳播,求導(dǎo)實(shí)例
- 關(guān)于PyTorch 自動(dòng)求導(dǎo)機(jī)制詳解
- Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法
- 關(guān)于pytorch求導(dǎo)總結(jié)(torch.autograd)
相關(guān)文章
Python實(shí)現(xiàn)基于TCP UDP協(xié)議的IPv4 IPv6模式客戶端和服務(wù)端功能示例
這篇文章主要介紹了Python實(shí)現(xiàn)基于TCP UDP協(xié)議的IPv4 IPv6模式客戶端和服務(wù)端功能,結(jié)合實(shí)例形式分析了Python基于TCP UDP協(xié)議的IPv4 IPv6模式客戶端和服務(wù)端數(shù)據(jù)發(fā)送與接收相關(guān)操作技巧,需要的朋友可以參考下2018-03-03
python 從csv讀數(shù)據(jù)到mysql的實(shí)例
今天小編就為大家分享一篇python 從csv讀數(shù)據(jù)到mysql的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-06-06
解決python3.6用cx_Oracle庫連接Oracle的問題
這篇文章主要介紹了解決python3.6用cx_Oracle庫連接Oracle的問題,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-12-12
詳細(xì)解析Python當(dāng)中的數(shù)據(jù)類型和變量
這篇文章主要介紹了Python當(dāng)中的數(shù)據(jù)類型和變量,是Python學(xué)習(xí)當(dāng)中的基礎(chǔ)知識(shí),需要的朋友可以參考下2015-04-04
Python編程使用PyQt5庫實(shí)現(xiàn)動(dòng)態(tài)水波進(jìn)度條示例
這篇文章主要介紹了Python編程使用PyQt5庫實(shí)現(xiàn)動(dòng)態(tài)水波進(jìn)度條的示例代碼解析,有需要的朋友可以借鑒參考下希望能夠有所幫助,祝大家多多進(jìn)步早日升職加薪2021-10-10
python實(shí)現(xiàn)爬取百度圖片的方法示例
這篇文章主要介紹了python實(shí)現(xiàn)爬取百度圖片的方法,涉及Python基于requests、urllib等模塊的百度圖片抓取相關(guān)操作技巧,需要的朋友可以參考下2019-07-07
對(duì)numpy中二進(jìn)制格式的數(shù)據(jù)存儲(chǔ)與讀取方法詳解
今天小編就為大家分享一篇對(duì)numpy中二進(jìn)制格式的數(shù)據(jù)存儲(chǔ)與讀取方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-11-11

