PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解
PyTorch中實(shí)現(xiàn)卷積的重要基礎(chǔ)函數(shù)
1、nn.Conv2d:
nn.Conv2d在pytorch中用于實(shí)現(xiàn)卷積。
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
)
1、in_channels為輸入通道數(shù)。
2、out_channels為輸出通道數(shù)。
3、kernel_size為卷積核大小。
4、stride為步數(shù)。
5、padding為padding情況。
6、dilation表示空洞卷積情況。
2、nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d在pytorch中用于實(shí)現(xiàn)最大池化。
具體使用方式如下:
MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
1、kernel_size為池化核的大小
2、stride為步長
3、padding為填充情況
3、nn.ReLU()
nn.ReLU()用來實(shí)現(xiàn)Relu函數(shù),實(shí)現(xiàn)非線性。
4、x.view()
x.view用于reshape特征層的形狀。
全部代碼
這是一個簡單的CNN模型,用于預(yù)測mnist手寫體。
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循環(huán)世代
EPOCH = 20
BATCH_SIZE = 50
# 下載mnist數(shù)據(jù)集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())
# (60000)
print(train_data.train_labels.size())
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 測試集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 標(biāo)準(zhǔn)化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神經(jīng)網(wǎng)絡(luò)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
#----------------------------#
# 第一部分卷積
#----------------------------#
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=32,
kernel_size=5,
stride=1,
padding=2,
dilation=1
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
#----------------------------#
# 第二部分卷積
#----------------------------#
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
dilation=1
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
#----------------------------#
# 全連接+池化+全連接
#----------------------------#
self.ful1 = nn.Linear(64 * 7 * 7, 512)
self.drop = nn.Dropout(0.5)
self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
#----------------------------#
# 前向傳播
#----------------------------#
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.ful1(x)
x = self.drop(x)
output = self.ful2(x)
return output
cnn = CNN()
# 指定優(yōu)化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
# 指定loss函數(shù)
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader):
#----------------------------#
# 計(jì)算loss并修正權(quán)值
#----------------------------#
output = cnn(b_x)
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#----------------------------#
# 打印
#----------------------------#
if step % 50 == 0:
test_output = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)
以上就是PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解的詳細(xì)內(nèi)容,更多關(guān)于PyTorch搭建卷積神經(jīng)網(wǎng)絡(luò)的資料請關(guān)注腳本之家其它相關(guān)文章!
- PyTorch中的神經(jīng)網(wǎng)絡(luò) Mnist 分類任務(wù)
- 使用Pytorch構(gòu)建第一個神經(jīng)網(wǎng)絡(luò)模型?附案例實(shí)戰(zhàn)
- pytorch簡單實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)功能
- pytorch深度神經(jīng)網(wǎng)絡(luò)入門準(zhǔn)備自己的圖片數(shù)據(jù)
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處
- Pytorch深度學(xué)習(xí)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)resnet模塊訓(xùn)練
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)resent網(wǎng)絡(luò)實(shí)踐
- Pytorch神經(jīng)網(wǎng)絡(luò)參數(shù)管理方法詳細(xì)講解
相關(guān)文章
Python+Selenium實(shí)現(xiàn)一鍵摸魚&采集數(shù)據(jù)
將Selenium程序編寫為 .bat 可執(zhí)行文件,從此一鍵啟動封裝好的Selenium程序,省時省力還可以復(fù)用,豈不美哉。所以本文將利用Selenium實(shí)現(xiàn)一鍵摸魚&一鍵采集數(shù)據(jù),需要的可以參考一下2022-08-08
selenium 安裝與chromedriver安裝的方法步驟
這篇文章主要介紹了selenium 安裝與chromedriver安裝的方法步驟,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-06-06
Python+Opencv實(shí)現(xiàn)把圖片、視頻互轉(zhuǎn)的示例
這篇文章主要介紹了Python+Opencv實(shí)現(xiàn)把圖片、視頻互轉(zhuǎn)的示例,幫助大家更好的理解和實(shí)用python,感興趣的朋友可以了解下2020-12-12
Python使用Selenium自動進(jìn)行百度搜索的實(shí)現(xiàn)
我們今天介紹一個非常適合新手的python自動化小項(xiàng)目,這個例子非常適合新手學(xué)習(xí)Python網(wǎng)絡(luò)自動化,不僅能夠了解如何使用Selenium,而且還能知道一些超級好用的小工具。感興趣的可以了解一下2021-07-07
Python使用pymysql從MySQL數(shù)據(jù)庫中讀出數(shù)據(jù)的方法
今天小編就為大家分享一篇Python使用pymysql從MySQL數(shù)據(jù)庫中讀出數(shù)據(jù)的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07

