A Simple, Easy-to-Follow PyTorch Example: The VGG Deep Network
The model is VGG and the dataset is CIFAR-10. Walk through this code once and you will have a good picture of how the whole PyTorch pipeline works.
Defining the model:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable
cfg = {
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
# A model must subclass nn.Module
class VGG(nn.Module):
    # Initialization:
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    # The forward pass: this is the sequence of computations the model performs
    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())
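As a quick sanity check of the architecture, the sketch below runs a random CIFAR-sized batch through the network and prints the output shape. It assumes the class above is saved as part of a models package (as the training script below expects) and a PyTorch version of 0.4 or later, where plain tensors can be passed directly without wrapping them in Variable:

import torch
from models import VGG  # hypothetical import path; adjust to wherever the class above lives

net = VGG('VGG16')
x = torch.randn(2, 3, 32, 32)                     # a fake batch of two 32x32 RGB images
print(net(x).size())                              # expected: torch.Size([2, 10]), one logit per CIFAR-10 class
print(sum(p.numel() for p in net.parameters()))   # total number of trainable parameters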
Defining the training procedure:
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
from utils import progress_bar
from torch.autograd import Variable
# Parse command-line arguments
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# Load the dataset and preprocess it first
print('==> Preparing data..')
# Image preprocessing and augmentation
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
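# A quick way to inspect what the loaders produce (a hedged sketch; uncomment to try):
# images, labels = next(iter(trainloader))
# print(images.size())    # torch.Size([128, 3, 32, 32]) with the batch size chosen above
# print(labels[:10])      # integer class indices into the `classes` tuple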
# Resume a previously trained model, or build a new one
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = VGG('VGG16')
    # net = ResNet18()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()
# Use the GPU if one is available
if use_cuda:
    # move parameters and buffers to the GPU
    net.cuda()
    # wrap the model so all available GPUs are used in parallel
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    # let cuDNN pick the fastest convolution algorithms (slight speed-up)
    cudnn.benchmark = True
# Define the loss criterion and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# Training phase
def train(epoch):
    print('\nEpoch: %d' % epoch)
    # switch to training mode
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # iterate over mini-batches
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # move the data to the GPU
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # zero the optimizer's gradients first
        optimizer.zero_grad()
        # wrapping in Variable makes the tensors part of the computation graph;
        # this is where the graph starts (they are its leaf variables)
        inputs, targets = Variable(inputs), Variable(targets)
        # model output
        outputs = net(inputs)
        # compute the loss; this is the end point of the graph
        loss = criterion(outputs, targets)
        # backpropagate to compute the gradients
        loss.backward()
        # update the parameters
        optimizer.step()
        # Note: when accumulating the loss, never add the loss Variable itself; use loss.data[0].
        # The loss is part of the computation graph, so adding it directly would make the running
        # total part of the graph as well, and the graph would keep growing.
        train_loss += loss.data[0]
        # accumulate statistics
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
# Test phase
def test(epoch):
    global best_acc
    # switch to evaluation mode first
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # loss is a Variable; adding it directly (+= loss) would make the graph grow bigger and bigger
        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # Save the checkpoint if the accuracy improved
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc
# Run the model
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)
    # release cached GPU memory that is no longer needed
    torch.cuda.empty_cache()
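The training script above is written against the pre-0.4 PyTorch API, which is why it wraps tensors in Variable, marks test inputs as volatile=True, and reads the scalar loss with loss.data[0]. On PyTorch 0.4 and later, Variable and Tensor are merged, so the same training step can be written roughly as in the minimal sketch below (net, criterion, optimizer and the loaders as defined above; this is a sketch, not a drop-in replacement for the full script):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

def train_step(inputs, targets):
    inputs, targets = inputs.to(device), targets.to(device)  # replaces .cuda() + Variable(...)
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()  # replaces loss.data[0]; returns a plain float detached from the graph

# In test(), wrap the loop in torch.no_grad() instead of using volatile=True:
# with torch.no_grad():
#     outputs = net(inputs)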
Running the script:
Train a new model from scratch:
python main.py --lr=0.01
Resume training an old model:
python main.py --resume --lr=0.01
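Because test() pickles the whole network object into the checkpoint (under the keys 'net', 'acc' and 'epoch'), a saved model can later be reloaded for inference roughly as in the sketch below. It assumes ./checkpoint/ckpt.t7 exists and that the models package is importable, so the pickled module can be reconstructed:

import torch

checkpoint = torch.load('./checkpoint/ckpt.t7')
net = checkpoint['net']     # the saved nn.Module (net.module was stored when DataParallel was used)
net.eval()                  # switch to evaluation mode before running inference
print('best acc: %s, saved at epoch %s' % (checkpoint['acc'], checkpoint['epoch']))

Saving net.state_dict() instead of the module itself is the more common pattern today, but the sketch above matches the checkpoint format used by this script.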
Some utility functions:
'''Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset.
- msr_init: net parameter initialization.
- progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math
import torch
import torch.nn as nn
import torch.nn.init as init
def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant(m.bias, 0)
_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)
TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.
    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')
    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time
    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)
    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')
    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)
    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
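The normalization constants hard-coded in transform_train and transform_test, (0.4914, 0.4822, 0.4465) and (0.2023, 0.1994, 0.2010), are exactly the kind of statistics get_mean_and_std computes. A minimal usage sketch, assuming torchvision and the CIFAR-10 data downloaded as in the training script:

import torchvision
import torchvision.transforms as transforms

dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                       transform=transforms.ToTensor())
mean, std = get_mean_and_std(dataset)
print(mean, std)  # should come out close to the constants used in the transforms above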
That is all for this article. I hope it is helpful for your studies, and I hope you will continue to support 腳本之家.