Approximating the sin Function with PyTorch
1. Introduction
This article approximates the sin function in two ways. Both treat the problem as a small machine-learning task: using Python's torch module, we learn the coefficients of a polynomial that approximates sin by gradient descent.
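Why a third-degree polynomial? On [-π, π], sin is dominated by low odd powers: its Taylor expansion begins sin(x) ≈ x - x^3/6, so a model of the form y = a + b x + c x^2 + d x^3 has enough freedom to track it. The fitted coefficients will not match the Taylor coefficients (1 and -1/6) exactly, because gradient descent minimizes the squared error over the whole interval rather than just around x = 0.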
2. Method One: Manual Gradient Descent
# This example uses torch to approximate the sin function.
# A third-degree polynomial is fitted to sin(x), mimicking a basic machine-learning workflow.
import torch
import math
dtype = torch.float
# data type of the tensors
device = torch.device("cpu")
# device to run the computation on
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
# works like numpy's linspace
y = torch.sin(x)
# y is a tensor holding sin(x)
# Randomly initialize weights
# drawn from a standard Gaussian distribution.
# Each parameter starts out random and is improved through learning.
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
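# Note: the loss sums over all 2000 points and x ** 3 reaches about 31 in
# absolute value, so the summed gradients can be large; the very small
# learning rate keeps the updates from overshooting.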
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    # y_pred is also a tensor; the model is a third-degree polynomial.

    # Compute and print loss (the summed squared error)
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
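    # Derivation: loss = sum((y_pred - y) ** 2), so by the chain rule
    # dloss/da = sum(2 * (y_pred - y)), dloss/db = sum(2 * (y_pred - y) * x),
    # and likewise for c and d with x ** 2 and x ** 3.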
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent; all four are updated every iteration.
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

# Print the final result
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
Output:
99 676.0404663085938
199 478.38140869140625
299 339.39117431640625
399 241.61537170410156
499 172.80801391601562
599 124.37007904052734
699 90.26084899902344
799 66.23435974121094
899 49.30537033081055
999 37.37403106689453
1099 28.96288299560547
1199 23.031932830810547
1299 18.848905563354492
1399 15.898048400878906
1499 13.81600570678711
1599 12.34669017791748
1699 11.309612274169922
1799 10.57749080657959
1899 10.060576438903809
1999 9.695555686950684
Result: y = -0.03098311647772789 + 0.852223813533783 x + 0.005345103796571493 x^2 + -0.09268788248300552 x^3
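Note that a and c come out close to zero: sin is an odd function, so on a symmetric interval the even-power terms contribute almost nothing to the best fit. As a quick sanity check, here is a small sketch that plugs the coefficients printed above back in (rounded; your values will differ from run to run) and measures how far the learned cubic strays from sin on the training interval:
import torch
import math
x = torch.linspace(-math.pi, math.pi, 2000)
# Coefficients copied (rounded) from the sample run above; they are
# illustrative, not canonical.
a, b, c, d = -0.0310, 0.8522, 0.0053, -0.0927
y_fit = a + b * x + c * x ** 2 + d * x ** 3
err = (y_fit - torch.sin(x)).abs()
print(f'max |error| = {err.max().item():.4f}, mean |error| = {err.mean().item():.4f}')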
3. Method Two: Autograd
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y using operations on Tensors.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss using operations on Tensors.
    # Now loss is a scalar Tensor;
    # loss.item() gets the scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call a.grad, b.grad, c.grad and d.grad will be Tensors holding
    # the gradient of the loss with respect to a, b, c, d respectively.
    loss.backward()

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
Output:
99 1702.320556640625
199 1140.3609619140625
299 765.3402709960938
399 514.934326171875
499 347.6383972167969
599 235.80038452148438
699 160.98876953125
799 110.91152954101562
899 77.36819458007812
999 54.883243560791016
1099 39.79965591430664
1199 29.673206329345703
1299 22.869291305541992
1399 18.293842315673828
1499 15.214327812194824
1599 13.1397705078125
1699 11.740955352783203
1799 10.796865463256836
1899 10.159022331237793
1999 9.727652549743652
Result: y = 0.019909318536520004 + 0.8338049650192261 x + -0.0034346890170127153 x^2 + -0.09006795287132263 x^3
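For comparison, the manual update and gradient zeroing in method two are exactly what a plain SGD optimizer does, so the loop can be written more idiomatically with torch.optim.SGD. The following is a minimal sketch of that variant:
import torch
import math
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# The same four scalar weights as above, tracked by autograd.
params = [torch.randn((), requires_grad=True) for _ in range(4)]
a, b, c, d = params
optimizer = torch.optim.SGD(params, lr=1e-6)
for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = (y_pred - y).pow(2).sum()
    optimizer.zero_grad()  # clear gradients from the previous step
    loss.backward()        # autograd fills p.grad for every parameter
    optimizer.step()       # applies p -= lr * p.grad for plain SGD
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')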
4. Summary
Both methods fit only a third-degree polynomial, so the approximation is reasonable only for relatively small |x|: the fit is trained on [-π, π] and degrades quickly outside that interval. In addition, since the weights are randomly initialized, the results will differ somewhat between runs.