PyTorch學(xué)習(xí)筆記之回歸實(shí)戰(zhàn)
本文主要是用PyTorch來實(shí)現(xiàn)一個(gè)簡(jiǎn)單的回歸任務(wù)。
編輯器:spyder
1.引入相應(yīng)的包及生成偽數(shù)據(jù)
import torch import torch.nn.functional as F # 主要實(shí)現(xiàn)激活函數(shù) import matplotlib.pyplot as plt # 繪圖的工具 from torch.autograd import Variable # 生成偽數(shù)據(jù) x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1) y = x.pow(2) + 0.2 * torch.rand(x.size()) # 變?yōu)閂ariable x, y = Variable(x), Variable(y)
其中torch.linspace是為了生成連續(xù)間斷的數(shù)據(jù),第一個(gè)參數(shù)表示起點(diǎn),第二個(gè)參數(shù)表示終點(diǎn),第三個(gè)參數(shù)表示將這個(gè)區(qū)間分成平均幾份,即生成幾個(gè)數(shù)據(jù)。因?yàn)閠orch只能處理二維的數(shù)據(jù),所以我們用torch.unsqueeze給偽數(shù)據(jù)添加一個(gè)維度,dim表示添加在第幾維。torch.rand返回的是[0,1)之間的均勻分布。
2.繪制數(shù)據(jù)圖像
在上述代碼后面加下面的代碼,然后運(yùn)行可得偽數(shù)據(jù)的圖形化表示:
# 繪制數(shù)據(jù)圖像 plt.scatter(x.data.numpy(), y.data.numpy()) plt.show()

3.建立神經(jīng)網(wǎng)絡(luò)
class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x): x = F.relu(self.hidden(x)) # activation function for hidden layer x = self.predict(x) # linear output return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network print(net) # net architecture
一般神經(jīng)網(wǎng)絡(luò)的類都繼承自torch.nn.Module,__init__()和forward()兩個(gè)函數(shù)是自定義類的主要函數(shù)。在__init__()中都要添加一句super(Net, self).__init__(),這是固定的標(biāo)準(zhǔn)寫法,用于繼承父類的初始化函數(shù)。__init__()中只是對(duì)神經(jīng)網(wǎng)絡(luò)的模塊進(jìn)行了聲明,真正的搭建是在forwad()中實(shí)現(xiàn)。自定義類中的成員都通過self指針來進(jìn)行訪問,所以參數(shù)列表中都包含了self。
如果想查看網(wǎng)絡(luò)結(jié)構(gòu),可以用print()函數(shù)直接打印網(wǎng)絡(luò)。本文的網(wǎng)絡(luò)結(jié)構(gòu)輸出如下:
Net ( (hidden): Linear (1 -> 10) (predict): Linear (10 -> 1) )
4.訓(xùn)練網(wǎng)絡(luò)
# 訓(xùn)練100次 for t in range(100): prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # 一定要是輸出在前,標(biāo)簽在后 (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients
訓(xùn)練網(wǎng)絡(luò)之前我們需要先定義優(yōu)化器和損失函數(shù)。torch.optim包中包括了各種優(yōu)化器,這里我們選用最常見的SGD作為優(yōu)化器。因?yàn)槲覀円獙?duì)網(wǎng)絡(luò)的參數(shù)進(jìn)行優(yōu)化,所以我們要把網(wǎng)絡(luò)的參數(shù)net.parameters()傳入優(yōu)化器中,并設(shè)置學(xué)習(xí)率(一般小于1)。
由于這里是回歸任務(wù),我們選擇torch.nn.MSELoss()作為損失函數(shù)。
由于優(yōu)化器是基于梯度來優(yōu)化參數(shù)的,并且梯度會(huì)保存在其中。所以在每次優(yōu)化前要通過optimizer.zero_grad()把梯度置零,然后再后向傳播及更新。
5.可視化訓(xùn)練過程
plt.ion() # something about plotting
for t in range(100):
...
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
6.運(yùn)行結(jié)果

以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Python爬蟲解析網(wǎng)頁的4種方式實(shí)例及原理解析
這篇文章主要介紹了Python爬蟲解析網(wǎng)頁的4種方式實(shí)例及原理解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-12-12
python實(shí)現(xiàn)的按要求生成手機(jī)號(hào)功能示例
這篇文章主要介紹了python實(shí)現(xiàn)的按要求生成手機(jī)號(hào)功能,涉及Python流程控制、隨機(jī)數(shù)操作及數(shù)學(xué)運(yùn)算相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2019-10-10
Python函數(shù)中*args和**kwargs來傳遞變長(zhǎng)參數(shù)的用法
這篇文章主要介紹了Python編程中使用*args和**kwargs來傳遞可變參數(shù)的用法,文中舉了變長(zhǎng)參數(shù)的例子,需要的朋友可以參考下2016-01-01
pytest自定義命令行參數(shù)的實(shí)現(xiàn)
本文主要介紹了在使用pytest運(yùn)行測(cè)試用例時(shí),通過傳遞自定義命令行參數(shù)來啟動(dòng)mitmdump進(jìn)程進(jìn)行抓包,具有一定的參考價(jià)值,感興趣的可以了解一下2024-12-12
python3.6 如何將list存入txt后再讀出list的方法
這篇文章主要介紹了python3.6 如何將list存入txt后再讀出list的方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07

