深入理解Pytorch微調(diào)torchvision模型
一、簡(jiǎn)介
在本小節(jié),深入探討如何對(duì)torchvision進(jìn)行微調(diào)和特征提取。所有模型都已經(jīng)預(yù)先在1000類的magenet數(shù)據(jù)集上訓(xùn)練完成。 本節(jié)將深入介紹如何使用幾個(gè)現(xiàn)代的CNN架構(gòu),并將直觀展示如何微調(diào)任意的PyTorch模型。
本節(jié)將執(zhí)行兩種類型的遷移學(xué)習(xí):
- 微調(diào):從預(yù)訓(xùn)練模型開始,更新我們新任務(wù)的所有模型參數(shù),實(shí)質(zhì)上是重新訓(xùn)練整個(gè)模型。
- 特征提?。簭念A(yù)訓(xùn)練模型開始,僅更新從中導(dǎo)出預(yù)測(cè)的最終圖層權(quán)重。它被稱為特征提取,因?yàn)槲覀兪褂妙A(yù)訓(xùn)練的CNN作為固定 的特征提取器,并且僅改變輸出層。
通常這兩種遷移學(xué)習(xí)方法都會(huì)遵循一下步驟:
- 初始化預(yù)訓(xùn)練模型
- 重組最后一層,使其具有與新數(shù)據(jù)集類別數(shù)相同的輸出數(shù)
- 為優(yōu)化算法定義想要的訓(xùn)練期間更新的參數(shù)
- 運(yùn)行訓(xùn)練步驟
二、導(dǎo)入相關(guān)包
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("Pytorch version:",torch.__version__)
print("torchvision version:",torchvision.__version__)
運(yùn)行結(jié)果

三、數(shù)據(jù)輸入
數(shù)據(jù)集——>我在這里
鏈接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取碼:1234
#%%輸入 data_dir="D:\Python\Pytorch\data\hymenoptera_data" # 從[resnet,alexnet,vgg,squeezenet,desenet,inception] model_name='squeezenet' # 數(shù)據(jù)集中類別數(shù)量 num_classes=2 # 訓(xùn)練的批量大小 batch_size=8 # 訓(xùn)練epoch數(shù) num_epochs=15 # 用于特征提取的標(biāo)志。為FALSE,微調(diào)整個(gè)模型,為TRUE只更新圖層參數(shù) feature_extract=True
四、輔助函數(shù)
1、模型訓(xùn)練和驗(yàn)證
- train_model函數(shù)處理給定模型的訓(xùn)練和驗(yàn)證。作為輸入,它需要PyTorch模型、數(shù)據(jù)加載器字典、損失函數(shù)、優(yōu)化器、用于訓(xùn)練和驗(yàn) 證epoch數(shù),以及當(dāng)模型是初始模型時(shí)的布爾標(biāo)志。
- is_inception標(biāo)志用于容納 Inception v3 模型,因?yàn)樵擉w系結(jié)構(gòu)使用輔助輸出, 并且整體模型損失涉及輔助輸出和最終輸出,如此處所述。 這個(gè)函數(shù)訓(xùn)練指定數(shù)量的epoch,并且在每個(gè)epoch之后運(yùn)行完整的驗(yàn)證步驟。它還跟蹤最佳性能的模型(從驗(yàn)證準(zhǔn)確率方面),并在訓(xùn)練 結(jié)束時(shí)返回性能最好的模型。在每個(gè)epoch之后,打印訓(xùn)練和驗(yàn)證正確率。
#%%模型訓(xùn)練和驗(yàn)證
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):
since=time.time()
val_acc_history=[]
best_model_wts=copy.deepcopy(model.state_dict())
best_acc=0.0
for epoch in range(num_epochs):
print('Epoch{}/{}'.format(epoch, num_epochs-1))
print('-'*10)
# 每個(gè)epoch都有一個(gè)訓(xùn)練和驗(yàn)證階段
for phase in['train','val']:
if phase=='train':
model.train()
else:
model.eval()
running_loss=0.0
running_corrects=0
# 迭代數(shù)據(jù)
for inputs,labels in dataloaders[phase]:
inputs=inputs.to(device)
labels=labels.to(device)
# 梯度置零
optimizer.zero_grad()
# 向前傳播
with torch.set_grad_enabled(phase=='train'):
# 獲取模型輸出并計(jì)算損失,開始的特殊情況在訓(xùn)練中他有一個(gè)輔助輸出
# 在訓(xùn)練模式下,通過(guò)將最終輸出和輔助輸出相加來(lái)計(jì)算損耗,在測(cè)試中值考慮最終輸出
if is_inception and phase=='train':
outputs,aux_outputs=model(inputs)
loss1=criterion(outputs,labels)
loss2=criterion(aux_outputs,labels)
loss=loss1+0.4*loss2
else:
outputs=model(inputs)
loss=criterion(outputs,labels)
_,preds=torch.max(outputs,1)
if phase=='train':
loss.backward()
optimizer.step()
# 添加
running_loss+=loss.item()*inputs.size(0)
running_corrects+=torch.sum(preds==labels.data)
epoch_loss=running_loss/len(dataloaders[phase].dataset)
epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)
print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))
if phase=='train' and epoch_acc>best_acc:
best_acc=epoch_acc
best_model_wts=copy.deepcopy(model.state_dict())
if phase=='val':
val_acc_history.append(epoch_acc)
print()
time_elapsed=time.time()-since
print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
print('best val acc:{:.4f}'.format(best_acc))
model.load_state_dict(best_model_wts)
return model,val_acc_history
2、設(shè)置模型參數(shù)的'.requires_grad屬性'
當(dāng)我們進(jìn)行特征提取時(shí),此輔助函數(shù)將模型中參數(shù)的 .requires_grad 屬性設(shè)置為False。
默認(rèn)情況下,當(dāng)我們加載一個(gè)預(yù)訓(xùn)練模型時(shí),所有參數(shù)都是 .requires_grad = True,如果我們從頭開始訓(xùn)練或微調(diào),這種設(shè)置就沒問題。
但是,如果我們要運(yùn)行特征提取并且只想為新初始化的層計(jì)算梯度,那么我們希望所有其他參數(shù)不需要梯度變化。
#%%設(shè)置模型參數(shù)的.require——grad屬性
def set_parameter_requires_grad(model,feature_extracting):
if feature_extracting:
for param in model.parameters():
param.require_grad=False
靚仔今天先去跑步了,再不跑來(lái)不及了,先更這么多,后續(xù)明天繼續(xù)~(感謝有人沒有催更!感謝監(jiān)督!希望繼續(xù)監(jiān)督?。?/p>
以上就是深入理解Pytorch微調(diào)torchvision模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch torchvision模型的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python 列表(List) 的三種遍歷方法實(shí)例 詳解
這篇文章主要介紹了Python 列表(List) 的三種遍歷方法實(shí)例 詳解的相關(guān)資料,需要的朋友可以參考下2017-04-04
pandas庫(kù)中to_datetime()方法的使用解析
這篇文章主要介紹了pandas庫(kù)中to_datetime()方法的使用解析,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-07-07
關(guān)于Gradio中Button用法及事件監(jiān)聽器click方法使用
介紹了在Gradio中使用Button組件和事件監(jiān)聽器的click方法,通過(guò)一個(gè)簡(jiǎn)單的示例展示了如何實(shí)現(xiàn)點(diǎn)擊按鈕輸出一行文字的功能,在實(shí)際項(xiàng)目中遇到了一個(gè)錯(cuò)誤,經(jīng)過(guò)排查和請(qǐng)教室友后,發(fā)現(xiàn)問題出在inputs參數(shù)的傳遞上,需要傳入一個(gè)包含輸入組件的列表2024-11-11
使用python對(duì)pdf文件進(jìn)行加密等操作
這篇文章主要為大家詳細(xì)介紹了使用python對(duì)pdf文件進(jìn)行加密等操作的相關(guān)知識(shí),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-12-12

