PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸
學(xué)習(xí)總結(jié)
(1)和上一講的模型訓(xùn)練是類似的,只是在線性模型的基礎(chǔ)上加個sigmoid,然后loss函數(shù)改為交叉熵BCE函數(shù)(當(dāng)然也可以用其他函數(shù)),另外一開始的數(shù)據(jù)y_data也從數(shù)值改為類別0和1(本例為二分類,注意x_data和y_data這里也是矩陣的形式)。
一、sigmoid函數(shù)
logistic function是一種sigmoid函數(shù)(還有其他sigmoid函數(shù)),但由于使用過于廣泛,pytorch默認logistic function叫為sigmoid函數(shù)。還有如下的各種sigmoid函數(shù):

二、和Linear的區(qū)別
邏輯斯蒂和線性模型的unit區(qū)別如下圖:

sigmoid函數(shù)是不需要參數(shù)的,所以不用對其初始化(直接調(diào)用nn.functional.sigmoid即可)。
另外loss函數(shù)從MSE改用交叉熵BCE:盡可能和真實分類貼近。

如下圖右方表格所示,當(dāng) y ^ \hat{y} y^越接近y時則BCE Loss值越小。

三、邏輯斯蒂回歸(分類)PyTorch實現(xiàn)
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 18 08:35:00 2021
@author: 86493
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
# 準備數(shù)據(jù)
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
losslst = []
class LogisticRegressionModel(nn.Module):
def __init__(self):
super(LogisticRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
# 和線性模型的網(wǎng)絡(luò)的唯一區(qū)別在這句,多了F.sigmoid
y_predict = F.sigmoid(self.linear(x))
return y_predict
model = LogisticRegressionModel()
# 使用交叉熵作損失函數(shù)
criterion = torch.nn.BCELoss(size_average = False)
optimizer = torch.optim.SGD(model.parameters(),
lr = 0.01)
# 訓(xùn)練
for epoch in range(1000):
y_predict = model(x_data)
loss = criterion(y_predict, y_data)
# 打印loss對象會自動調(diào)用__str__
print(epoch, loss.item())
losslst.append(loss.item())
# 梯度清零后反向傳播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 畫圖
plt.plot(range(1000), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()
# test
# 每周學(xué)習(xí)的時間,200個點
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
# 畫 probability of pass = 0.5的紅色橫線
plt.plot([0, 10], [0.5, 0.5], c = 'r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

可以看出處于通過和不通過的分界線是Hours=2.5。

Reference
到此這篇關(guān)于PyTorch零基礎(chǔ)入門之邏輯斯蒂回歸的文章就介紹到這了,更多相關(guān)PyTorch 邏輯斯蒂回歸內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Django中如何使用celery異步發(fā)送短信驗證碼詳解
Celery是Python開發(fā)的分布式任務(wù)調(diào)度模塊,這篇文章主要給大家介紹了關(guān)于Django中如何使用celery異步發(fā)送短信驗證碼的相關(guān)資料,主要內(nèi)容包括基礎(chǔ)介紹、工作原理、完整代碼等方面,需要的朋友可以參考下2021-09-09
對Python Class之間函數(shù)的調(diào)用關(guān)系詳解
今天小編就為大家分享一篇對Python Class之間函數(shù)的調(diào)用關(guān)系詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01
解決pip安裝報錯required?to?install?pyproject.toml-based?projec
這篇文章主要介紹了解決pip安裝報錯required?to?install?pyproject.toml-based?projects問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-05-05

