Python反向傳播實(shí)現(xiàn)線性回歸步驟詳細(xì)講解
1. 導(dǎo)入包
我們這次的任務(wù)是隨機(jī)生成一些離散的點(diǎn),然后用直線(y = w *x + b )去擬合
首先看一下我們需要導(dǎo)入的包有

torch 包為我們生成張量,可以使用反向傳播
matplotlib.pyplot 包幫助我們繪制曲線,實(shí)現(xiàn)可視化
2. 生成數(shù)據(jù)
這里我們通過rand隨機(jī)生成數(shù)據(jù),因?yàn)樯傻臄?shù)據(jù)在0~1之間,這里我們擴(kuò)大10倍。
我們?cè)O(shè)置的batch_size,也就是數(shù)據(jù)的個(gè)數(shù)為20個(gè),所以這里會(huì)產(chǎn)生維度是(20,1)個(gè)訓(xùn)練樣本
我們假設(shè)大概的回歸是 y = 2 * x + 3 的,為了保證損失不一直為0 ,這里我們添加一點(diǎn)噪音
最后返回x作為輸入,y作為真實(shí)值label
rand [0,1]均勻分布

如果想要每次產(chǎn)生的隨機(jī)數(shù)是一樣的,可以在代碼的前面設(shè)置一下隨機(jī)數(shù)種子

3. 訓(xùn)練數(shù)據(jù)
首先,我們要建立的模型是線性的y = w * x + b ,所以我們需要先初始化w ,b
使用randn 標(biāo)準(zhǔn)正態(tài)分布隨機(jī)初始化權(quán)重w,將偏置b初始化為0
為什么將權(quán)重w隨機(jī)初始化?
- 首先,為了抑制過擬合,提高模型的泛化能力,我們可以采用權(quán)重衰減來抑制權(quán)重w的大小。因?yàn)闄?quán)重過大,對(duì)應(yīng)的輸入x的特征就越重要,但是如果對(duì)應(yīng)x是噪音的話,那么系統(tǒng)就會(huì)陷入過擬合中。所以我們希望得到的模型曲線是一條光滑的,對(duì)輸入不敏感的曲線,所以w越小越好
- 那這樣為什么不直接把權(quán)重初始化為0,或者說很小很小的數(shù)字呢。因?yàn)?,w太小的話,那么在反向傳播的時(shí)候,由于我們習(xí)慣學(xué)習(xí)率lr 設(shè)置很小,那在更新w的時(shí)候基本就不更新了。而不把權(quán)重設(shè)置為0,是因?yàn)闊o論訓(xùn)練多久,在更新權(quán)重的時(shí)候,所有權(quán)重都會(huì)被更新成相同的值,這樣多層隱藏層就沒有意義了。嚴(yán)格來說,是為了瓦解權(quán)重的對(duì)稱結(jié)構(gòu)

接下來可以訓(xùn)練我們的模型了

1. 將輸入的特征x和對(duì)應(yīng)真實(shí)值label y通過zip函數(shù)打包。將輸入x經(jīng)過模型 w *x + b 的預(yù)測(cè)輸出預(yù)測(cè)值y
2. 計(jì)算損失函數(shù)loss,因?yàn)橹皩、b都是設(shè)置成會(huì)計(jì)算梯度的,那么loss.backward() 會(huì)自動(dòng)計(jì)算w和b的梯度。用w的值data,減去梯度的值grad.data 乘上 學(xué)習(xí)率lr完成一次更新
3. 當(dāng)w、b梯度不為零的話,要清零。這里有兩種解釋,第一種是每次計(jì)算完梯度后,值會(huì)和之前計(jì)算的梯度值進(jìn)行累加,而我們只是需要當(dāng)前這步的梯度值,所有我們需要將之前的值清零。第二種是,因?yàn)樘荻鹊睦奂樱敲聪喈?dāng)于實(shí)現(xiàn)一個(gè)很大的batch訓(xùn)練。假如一個(gè)epoch里面,梯度不進(jìn)行清零的話,相當(dāng)于把所有的樣本求和后在進(jìn)行梯度下降,而不是我們?cè)仁褂玫尼槍?duì)單個(gè)樣本進(jìn)行下降的SGD算法
4. 每100次迭代后,我們打印一下?lián)p失
4. 繪制圖像

scatter 相當(dāng)于離散點(diǎn)的繪圖
要繪制連續(xù)的圖像,只需要給個(gè)定義域然后通過表達(dá)式 w * x +b 計(jì)算y就可以了,最后輸出一下w和b,看看是不是和我們?cè)O(shè)置的w = 2,b =3 接近
5. 代碼
import torch
import matplotlib.pyplot as plt
def trainSet(batch_size = 20): # 定義訓(xùn)練集
x = torch.rand(batch_size,1) * 10
y = x * 2 + 3 + torch.randn(batch_size,1) # y = x * 2 + 3(近似)
return x,y
train_x, train_y = trainSet() # 訓(xùn)練集
w =torch.randn(1,requires_grad= True)
b = torch.zeros(1,requires_grad= True)
lr = 0.001
for epoch in range(1000):
for x,y in zip(train_x,train_y): # SGD算法,如果是BSGD的話,不需要這個(gè)for
y_pred = w*x + b
loss = (y - y_pred).pow(2) / 2
loss.backward()
w.data -= w.grad.data * lr
b.data -= b.grad.data * lr
if w.data is not True: # 梯度值不為零的話,要清零
w.grad.data.zero_() # 否則相當(dāng)于一個(gè)大的batch訓(xùn)練
if b.data is not True:
b.grad.data.zero_()
if epoch % 100 ==0:
print('loss:',loss.data)
plt.scatter(train_x,train_y)
x = torch.arange(0,11).view(-1,1)
y = x * w.data + b.data
plt.plot(x,y)
plt.show()
print(w.data,b.data)輸出的圖像

輸出的結(jié)果為

這里可以看的最后的w = 1.9865和b = 2.9857 和我們?cè)O(shè)置的2,3是接近的
到此這篇關(guān)于Python反向傳播實(shí)現(xiàn)線性回歸步驟詳細(xì)講解的文章就介紹到這了,更多相關(guān)Python線性回歸內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- python實(shí)現(xiàn)線性回歸算法
- python深度總結(jié)線性回歸
- python機(jī)器學(xué)習(xí)基礎(chǔ)線性回歸與嶺回歸算法詳解
- Python線性回歸圖文實(shí)例詳解
- python實(shí)現(xiàn)線性回歸的示例代碼
- python數(shù)據(jù)分析之線性回歸選擇基金
- python基于numpy的線性回歸
- Python實(shí)現(xiàn)多元線性回歸的梯度下降法
- Python構(gòu)建簡(jiǎn)單線性回歸模型
- python繪制y關(guān)于x的線性回歸線性方程圖像實(shí)例
- python實(shí)現(xiàn)線性回歸的示例代碼
相關(guān)文章
在Python中f-string的幾個(gè)技巧,你都知道嗎
f-string想必很多Python用戶都基礎(chǔ)性的使用過,但是百分之九十的人不知道?在Python中f-string的幾個(gè)技巧,今天就帶大家一起看看Python f-string技巧大全,需要的朋友參考下吧2021-10-10
Python實(shí)現(xiàn)在Excel文件中寫入圖表
這篇文章主要為大家介紹了如何利用Python語言實(shí)現(xiàn)在Excel文件中寫入一個(gè)比較簡(jiǎn)單的圖表,文中的實(shí)現(xiàn)方法講解詳細(xì),快動(dòng)手嘗試一下吧2022-05-05
Python selenium爬取微博數(shù)據(jù)代碼實(shí)例
這篇文章主要介紹了Python selenium爬取微博數(shù)據(jù)代碼實(shí)例,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-05-05
Pytest單元測(cè)試框架生成HTML測(cè)試報(bào)告及優(yōu)化的步驟
本文主要介紹了Pytest單元測(cè)試框架生成HTML測(cè)試報(bào)告及優(yōu)化的步驟,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-01-01
Matplotlib控制坐標(biāo)軸刻度間距與標(biāo)簽實(shí)例代碼
在matplotlib中,記號(hào)是圖形兩個(gè)軸上的小標(biāo)記,到目前為止,我們讓matplotlib處理軸圖例上記號(hào)的位置,下面這篇文章主要給大家介紹了關(guān)于Matplotlib控制坐標(biāo)軸刻度間距與標(biāo)簽的相關(guān)資料,需要的朋友可以參考下2021-10-10

