詳解Pytorch 使用Pytorch擬合多項式(多項式回歸)
使用Pytorch來編寫神經(jīng)網(wǎng)絡(luò)具有很多優(yōu)勢,比起Tensorflow,我認(rèn)為Pytorch更加簡單,結(jié)構(gòu)更加清晰。
希望通過實戰(zhàn)幾個Pytorch的例子,讓大家熟悉Pytorch的使用方法,包括數(shù)據(jù)集創(chuàng)建,各種網(wǎng)絡(luò)層結(jié)構(gòu)的定義,以及前向傳播與權(quán)重更新方式。
比如這里給出
很顯然,這里我們只需要假定

這里我們只需要設(shè)置一個合適尺寸的全連接網(wǎng)絡(luò),根據(jù)不斷迭代,求出最接近的參數(shù)即可。
但是這里需要思考一個問題,使用全連接網(wǎng)絡(luò)結(jié)構(gòu)是毫無疑問的,但是我們的輸入與輸出格式是什么樣的呢?
只將一個x作為輸入合理嗎?顯然是不合理的,因為每一個神經(jīng)元其實模擬的是wx+b的計算過程,無法模擬冪運算,所以顯然我們需要將x,x的平方,x的三次方,x的四次方組合成一個向量作為輸入,假設(shè)有n個不同的x值,我們就可以將n個組合向量合在一起組成輸入矩陣。
這一步代碼如下:
def make_features(x): x = x.unsqueeze(1) return torch.cat([x ** i for i in range(1,4)] , 1)
我們需要生成一些隨機(jī)數(shù)作為網(wǎng)絡(luò)輸入:
def get_batch(batch_size=32): random = torch.randn(batch_size) x = make_features(random) '''Compute the actual results''' y = f(x) if torch.cuda.is_available(): return Variable(x).cuda(), Variable(y).cuda() else: return Variable(x), Variable(y)
其中的f(x)定義如下:
w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) b_target = torch.FloatTensor([0.9]) def f(x): return x.mm(w_target)+b_target[0]
接下來定義模型:
class poly_model(nn.Module): def __init__(self): super(poly_model, self).__init__() self.poly = nn.Linear(3,1) def forward(self, x): out = self.poly(x) return out
if torch.cuda.is_available(): model = poly_model().cuda() else: model = poly_model()
接下來我們定義損失函數(shù)和優(yōu)化器:
criterion = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr = 1e-3)
網(wǎng)絡(luò)部件定義完后,開始訓(xùn)練:
epoch = 0 while True: batch_x,batch_y = get_batch() output = model(batch_x) loss = criterion(output,batch_y) print_loss = loss.data[0] optimizer.zero_grad() loss.backward() optimizer.step() epoch+=1 if print_loss < 1e-3: break
到此我們的所有代碼就敲完了,接下來我們開始詳細(xì)了解一下其中的一些代碼。
在make_features()定義中,torch.cat是將計算出的向量拼接成矩陣。unsqueeze是作一個維度上的變化。
get_batch中,torch.randn是產(chǎn)生指定維度的隨機(jī)數(shù),如果你的機(jī)器支持GPU加速,可以將Variable放在GPU上進(jìn)行運算,類似語句含義相通。
x.mm是作矩陣乘法。
模型定義是重中之重,其實當(dāng)你掌握Pytorch之后,你會發(fā)現(xiàn)模型定義是十分簡單的,各種基本的層結(jié)構(gòu)都已經(jīng)為你封裝好了。所有的層結(jié)構(gòu)和損失函數(shù)都來自torch.nn,所有的模型構(gòu)建都是從這個基類 nn.Module繼承的。模型定義中,__init__與forward是有模板的,大家可以自己體會。
nn.Linear是做一個線性的運算,參數(shù)的含義代表了輸入層與輸出層的結(jié)構(gòu),即3*1;在訓(xùn)練階段,有幾行是Pytorch不同于別的框架的,首先loss是一個Variable,通過loss.data可以取出一個Tensor,再通過data[0]可以得到一個int或者float類型的值,我們才可以進(jìn)行基本運算或者顯示。每次計算梯度之前,都需要將梯度歸零,否則梯度會疊加。個人覺得別的語句還是比較好懂的,如果有疑問可以在下方評論。
下面是我們的擬合結(jié)果

其實效果肯定會很好,因為只是一個非常簡單的全連接網(wǎng)絡(luò),希望大家通過這個小例子可以學(xué)到Pytorch的一些基本操作。往后我們會繼續(xù)更新,完整代碼請戳,https://github.com/ZhichaoDuan/PytorchCourse
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
基于Python實現(xiàn)GeoServer矢量文件批量發(fā)布
由于矢量圖層文件較多,手動發(fā)布費時費力,python支持的關(guān)于geoserver包又由于年久失修,無法在較新的geoserver版本中正常使用。本文為大家準(zhǔn)備了Python自動化發(fā)布矢量文件的代碼,需要的可以參考一下2022-07-07
python簡單鼠標(biāo)自動點擊某區(qū)域的實例
今天小編就為大家分享一篇python簡單鼠標(biāo)自動點擊某區(qū)域的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06
apache部署python程序出現(xiàn)503錯誤的解決方法
這篇文章主要給大家介紹了關(guān)于在apahce部署python程序出現(xiàn)503錯誤的解決方法,文中通過示例代碼介紹的非常詳細(xì),對同樣遇到這個問題的朋友們具有一定的參考學(xué)習(xí)價值,需要的朋友們下面來一起看看吧。2017-07-07

