pytorch使用Variable實(shí)現(xiàn)線性回歸
本文實(shí)例為大家分享了pytorch使用Variable實(shí)現(xiàn)線性回歸的具體代碼,供大家參考,具體內(nèi)容如下
一、手動(dòng)計(jì)算梯度實(shí)現(xiàn)線性回歸
#導(dǎo)入相關(guān)包 import torch as t import matplotlib.pyplot as plt #構(gòu)造數(shù)據(jù) def get_fake_data(batch_size = 8): #設(shè)置隨機(jī)種子數(shù),這樣每次生成的隨機(jī)數(shù)都是一樣的 t.manual_seed(10) #產(chǎn)生隨機(jī)數(shù)據(jù):y = 2*x+3,加上了一些噪聲 x = t.rand(batch_size,1) * 20 #randn生成期望為0方差為1的正態(tài)分布隨機(jī)數(shù) y = x * 2 + (1 + t.randn(batch_size,1)) * 3 return x,y #查看生成數(shù)據(jù)的分布 x,y = get_fake_data() plt.scatter(x.squeeze().numpy(),y.squeeze().numpy()) #線性回歸 #隨機(jī)初始化參數(shù) w = t.rand(1,1) b = t.zeros(1,1) #學(xué)習(xí)率 lr = 0.001 for i in range(10000): x,y = get_fake_data() #forward:計(jì)算loss y_pred = x.mm(w) + b.expand_as(y) #均方誤差作為損失函數(shù) loss = 0.5 * (y_pred - y)**2 loss = loss.sum() #backward:手動(dòng)計(jì)算梯度 dloss = 1 dy_pred = dloss * (y_pred - y) dw = x.t().mm(dy_pred) db = dy_pred.sum() #更新參數(shù) w.sub_(lr * dw) b.sub_(lr * db) if i%1000 == 0: #畫圖 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy()) x1 = t.arange(0,20).float().view(-1,1) y1 = x1.mm(w) + b.expand_as(x1) plt.plot(x1.numpy(),y1.numpy()) #predicted plt.show() #plt.pause(0.5) print(w.squeeze(),b.squeeze())

顯示的最后一張圖如下所示:

二、自動(dòng)梯度 計(jì)算梯度實(shí)現(xiàn)線性回歸
#導(dǎo)入相關(guān)包 import torch as t from torch.autograd import Variable as V import matplotlib.pyplot as plt #構(gòu)造數(shù)據(jù) def get_fake_data(batch_size=8): t.manual_seed(10) #設(shè)置隨機(jī)數(shù)種子 x = t.rand(batch_size,1) * 20 y = 2 * x +(1 + t.randn(batch_size,1)) * 3 return x,y #查看產(chǎn)生的x,y的分布是什么樣的 x,y = get_fake_data() plt.scatter(x.squeeze().numpy(),y.squeeze().numpy()) #線性回歸 #初始化隨機(jī)參數(shù) w = V(t.rand(1,1),requires_grad=True) b = V(t.rand(1,1),requires_grad=True) lr = 0.001 for i in range(8000): x,y = get_fake_data() x,y = V(x),V(y) y_pred = x * w + b loss = 0.5 * (y_pred-y)**2 loss = loss.sum() #自動(dòng)計(jì)算梯度 loss.backward() #更新參數(shù) w.data.sub_(lr * w.grad.data) b.data.sub_(lr * b.grad.data) #梯度清零,不清零梯度會(huì)累加的 w.grad.data.zero_() b.grad.data.zero_() if i%1000==0: #predicted x = t.arange(0,20).float().view(-1,1) y = x.mm(w.data) + b.data.expand_as(x) plt.plot(x.numpy(),y.numpy()) #true data x2,y2 = get_fake_data() plt.scatter(x2.numpy(),y2.numpy()) plt.show() print(w.data[0],b.data[0])

顯示的最后一張圖如下所示:

用autograd實(shí)現(xiàn)的線性回歸最大的不同點(diǎn)就在于利用autograd不需要手動(dòng)計(jì)算梯度,可以自動(dòng)微分。這一點(diǎn)不單是在深度在學(xué)習(xí)中,在許多機(jī)器學(xué)習(xí)的問題中都很有用。另外,需要注意的是每次反向傳播之前要記得先把梯度清零,因?yàn)閍utograd求得的梯度是自動(dòng)累加的。
以上就是本文的全部內(nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python StringIO及BytesIO包使用方法解析
這篇文章主要介紹了Python StringIO及BytesIO包使用方法解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06
python使用Pillow將照片轉(zhuǎn)換為1寸報(bào)名照片的教程分享
在現(xiàn)代科技時(shí)代,我們經(jīng)常需要調(diào)整和處理照片以適應(yīng)特定的需求和用途,本文將介紹如何使用wxPython和Pillow庫,通過一個(gè)簡單的圖形界面程序,將選擇的照片轉(zhuǎn)換為指定尺寸的JPG格式,并保存在桌面上,需要的朋友可以參考下2023-09-09
Pycharm插件(Grep Console)自定義規(guī)則輸出顏色日志的方法
這篇文章主要介紹了Pycharm插件(Grep Console)自定義規(guī)則輸出顏色日志的方法,本文通過圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-05-05
Python math庫 ln(x)運(yùn)算的實(shí)現(xiàn)及原理
這篇文章主要介紹了Python math庫 ln(x)運(yùn)算的實(shí)現(xiàn)及原理,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07
淺談django框架集成swagger以及自定義參數(shù)問題
這篇文章主要介紹了淺談django框架集成swagger以及自定義參數(shù)問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-07-07
基于Python實(shí)現(xiàn)ComicReaper漫畫自動(dòng)爬取腳本過程解析
這篇文章主要介紹了基于Python實(shí)現(xiàn)ComicReaper漫畫自動(dòng)爬取腳本過程解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-11-11

