PyTorch加載數(shù)據(jù)集梯度下降優(yōu)化
一、實現(xiàn)過程
1、準備數(shù)據(jù)
與PyTorch實現(xiàn)多維度特征輸入的邏輯回歸的方法不同的是:本文使用DataLoader方法,并繼承DataSet抽象類,可實現(xiàn)對數(shù)據(jù)集進行mini_batch梯度下降優(yōu)化。
代碼如下:
import torch
import numpy as np
from torch.utils.data import Dataset,DataLoader
class DiabetesDataSet(Dataset):
? ? def __init__(self, filepath):
? ? ? ? xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
? ? ? ? self.len = xy.shape[0]
? ? ? ? self.x_data = torch.from_numpy(xy[:,:-1])
? ? ? ? self.y_data = torch.from_numpy(xy[:,[-1]])
? ? ? ??
? ? def __getitem__(self, index):
? ? ? ? return self.x_data[index],self.y_data[index]
? ??
? ? def __len__(self):
? ? ? ? return self.len
dataset = DiabetesDataSet('G:/datasets/diabetes/diabetes.csv')
train_loader = DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)2、設計模型
class Model(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(Model,self).__init__() ? ? ? ? self.linear1 = torch.nn.Linear(8,6) ? ? ? ? self.linear2 = torch.nn.Linear(6,4) ? ? ? ? self.linear3 = torch.nn.Linear(4,1) ? ? ? ? self.activate = torch.nn.Sigmoid() ? ?? ? ? def forward(self, x): ? ? ? ? x = self.activate(self.linear1(x)) ? ? ? ? x = self.activate(self.linear2(x)) ? ? ? ? x = self.activate(self.linear3(x)) ? ? ? ? return x model = Model()
3、構造損失函數(shù)和優(yōu)化器
criterion = torch.nn.BCELoss(reduction='mean') optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
4、訓練過程
每次拿出mini_batch個樣本進行訓練,代碼如下:
epoch_list = [] loss_list = [] for epoch in range(100): ? ? count = 0 ? ? loss1 = 0 ? ? for i, data in enumerate(train_loader,0): ? ? ? ? # 1.Prepare data ? ? ? ? inputs, labels = data ? ? ? ? # 2.Forward ? ? ? ? y_pred = model(inputs) ? ? ? ? loss = criterion(y_pred,labels) ? ? ? ? print(epoch,i,loss.item()) ? ? ? ? count += 1 ? ? ? ? loss1 += loss.item() ? ? ? ? # 3.Backward ? ? ? ? optimizer.zero_grad() ? ? ? ? loss.backward() ? ? ? ? # 4.Update ? ? ? ? optimizer.step() ? ? ? ?? ? ? epoch_list.append(epoch) ? ? loss_list.append(loss1/count)
5、結果展示
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.grid()
plt.show()
二、參考文獻
到此這篇關于PyTorch加載數(shù)據(jù)集梯度下降優(yōu)化的文章就介紹到這了,更多相關PyTorch加載數(shù)據(jù)集內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
對django 2.x版本中models.ForeignKey()外鍵說明介紹
這篇文章主要介紹了對django 2.x版本中models.ForeignKey()外鍵說明介紹,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03
Pandas如何對Categorical類型字段數(shù)據(jù)統(tǒng)計實戰(zhàn)案例
這篇文章主要介紹了Pandas如何對Categorical類型字段數(shù)據(jù)統(tǒng)計實戰(zhàn)案例,文章圍繞主題展開詳細的內容介紹,具有一定的參考價值,需要的小伙伴可以參考一下2022-08-08
解決Django no such table: django_session的問題
這篇文章主要介紹了解決Django no such table: django_session的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04

