CoAtNet實(shí)戰(zhàn)之對植物幼苗圖像進(jìn)行分類(pytorch)
前言
雖然Transformer在CV任務(wù)上有非常強(qiáng)的學(xué)習(xí)建模能力,但是由于缺少了像CNN那樣的歸納偏置,所以相比于CNN,Transformer的泛化能力就比較差。因此,如果只有Transformer進(jìn)行全局信息的建模,在沒有預(yù)訓(xùn)練(JFT-300M)的情況下,Transformer在性能上很難超過CNN(VOLO在沒有預(yù)訓(xùn)練的情況下,一定程度上也是因?yàn)閂OLO的Outlook Attention對特征信息進(jìn)行了局部感知,相當(dāng)于引入了歸納偏置)。既然CNN有更強(qiáng)的泛化能力,Transformer具有更強(qiáng)的學(xué)習(xí)能力,那么,為什么不能將Transformer和CNN進(jìn)行一個結(jié)合呢?
谷歌的最新模型CoAtNet做了卷積 + Transformer的融合,在ImageNet-1K數(shù)據(jù)集上取得88.56%的成績。今天我們就用CoAtNet實(shí)現(xiàn)植物幼苗的分類。

項(xiàng)目結(jié)構(gòu)

數(shù)據(jù)集
數(shù)據(jù)集選用植物幼苗分類,總共12類。數(shù)據(jù)集連接如下:
鏈接 提取碼:q060
在工程的根目錄新建data文件夾,獲取數(shù)據(jù)集后,將trian和test解壓放到data文件夾下面,如下圖:

安裝庫,并導(dǎo)入需要的庫
安裝完成后,導(dǎo)入到項(xiàng)目中。
import torch.optim as optim import torch import torch.nn as nn import torch.nn.parallel import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms from dataset.dataset import SeedlingData from torch.autograd import Variable from models.coatnet import coatnet_0
設(shè)置全局參數(shù)
設(shè)置使用GPU,設(shè)置學(xué)習(xí)率、BatchSize、epoch等參數(shù)
# 設(shè)置全局參數(shù)
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 50
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
數(shù)據(jù)預(yù)處理
數(shù)據(jù)處理比較簡單,沒有做復(fù)雜的嘗試,有興趣的可以加入一些處理。
# 數(shù)據(jù)預(yù)處理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
數(shù)據(jù)讀取
然后我們在dataset文件夾下面新建 init.py和dataset.py,在mydatasets.py文件夾寫入下面的代碼:
說一下代碼的核心邏輯。
第一步 建立字典,定義類別對應(yīng)的ID,用數(shù)字代替類別。
第二步 在__init__里面編寫獲取圖片路徑的方法。測試集只有一層路徑直接讀取,訓(xùn)練集在train文件夾下面是類別文件夾,先獲取到類別,再獲取到具體的圖片路徑。然后使用sklearn中切分?jǐn)?shù)據(jù)集的方法,按照7:3的比例切分訓(xùn)練集和驗(yàn)證集。
第三步 在__getitem__方法中定義讀取單個圖片和類別的方法,由于圖像中有位深度32位的,所以我在讀取圖像的時候做了轉(zhuǎn)換。
代碼如下:
# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split
Labels = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,
'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,
'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
class SeedlingData (data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目標(biāo): 獲取所有圖片的地址,并根據(jù)訓(xùn)練,驗(yàn)證,測試劃分?jǐn)?shù)據(jù)
"""
self.test = test
self.transforms = transforms
if self.test:
imgs = [os.path.join(root, img) for img in os.listdir(root)]
self.imgs = imgs
else:
imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]
imgs = []
for imglable in imgs_labels:
for imgname in os.listdir(imglable):
imgpath = os.path.join(imglable, imgname)
imgs.append(imgpath)
trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
if train:
self.imgs = trainval_files
else:
self.imgs = val_files
def __getitem__(self, index):
"""
一次返回一張圖片的數(shù)據(jù)
"""
img_path = self.imgs[index]
img_path=img_path.replace("\\",'/')
if self.test:
label = -1
else:
labelname = img_path.split('/')[-2]
label = Labels[labelname]
data = Image.open(img_path).convert('RGB')
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
然后我們在train.py調(diào)用SeedlingData讀取數(shù)據(jù) ,記著導(dǎo)入剛才寫的dataset.py(from mydatasets import SeedlingData)
# 讀取數(shù)據(jù)
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 導(dǎo)入數(shù)據(jù)
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
設(shè)置模型
- 設(shè)置loss函數(shù)為nn.CrossEntropyLoss()。
- 設(shè)置模型為coatnet_0,修改最后一層全連接輸出改為12。
- 優(yōu)化器設(shè)置為adam。
- 學(xué)習(xí)率調(diào)整策略改為余弦退火
# 實(shí)例化模型并且移動到GPU criterion = nn.CrossEntropyLoss() model_ft = coatnet_0() num_ftrs = model_ft.fc.in_features model_ft.fc = nn.Linear(num_ftrs, 12) model_ft.to(DEVICE) # 選擇簡單暴力的Adam優(yōu)化器,學(xué)習(xí)率調(diào)低 optimizer = optim.Adam(model_ft.parameters(), lr=modellr) cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)
# 定義訓(xùn)練過程
def train(model, device, train_loader, optimizer, epoch):
model.train()
sum_loss = 0
total_num = len(train_loader.dataset)
print(total_num, len(train_loader))
for batch_idx, (data, target) in enumerate(train_loader):
data, target = Variable(data).to(device), Variable(target).to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print_loss = loss.data.item()
sum_loss += print_loss
if (batch_idx + 1) % 10 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
100. * (batch_idx + 1) / len(train_loader), loss.item()))
ave_loss = sum_loss / len(train_loader)
print('epoch:{},loss:{}'.format(epoch, ave_loss))
# 驗(yàn)證過程
def val(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
total_num = len(test_loader.dataset)
print(total_num, len(test_loader))
with torch.no_grad():
for data, target in test_loader:
data, target = Variable(data).to(device), Variable(target).to(device)
output = model(data)
loss = criterion(output, target)
_, pred = torch.max(output.data, 1)
correct += torch.sum(pred == target)
print_loss = loss.data.item()
test_loss += print_loss
correct = correct.data.item()
acc = correct / total_num
avgloss = test_loss / len(test_loader)
print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
avgloss, correct, len(test_loader.dataset), 100 * acc))
# 訓(xùn)練
for epoch in range(1, EPOCHS + 1):
train(model_ft, DEVICE, train_loader, optimizer, epoch)
cosine_schedule.step()
val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')
測試
測試集存放的目錄如下圖:

第一步 定義類別,這個類別的順序和訓(xùn)練時的類別順序?qū)?yīng),一定不要改變順序!?。?!
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
'Common wheat', 'Fat Hen', 'Loose Silky-bent',
'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
第二步 定義transforms,transforms和驗(yàn)證集的transforms一樣即可,別做數(shù)據(jù)增強(qiáng)。
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
第三步 加載model,并將模型放在DEVICE里。
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)
第四步 讀取圖片并預(yù)測圖片的類別,在這里注意,讀取圖片用PIL庫的Image。不要用cv2,transforms不支持。
path = 'data/test/'
testList = os.listdir(path)
for file in testList:
img = Image.open(path + file)
img = transform_test(img)
img.unsqueeze_(0)
img = Variable(img).to(DEVICE)
out = model(img)
# Predict
_, pred = torch.max(out.data, 1)
print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))
測試完整代碼:
import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
'Common wheat', 'Fat Hen', 'Loose Silky-bent',
'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)
path = 'data/test/'
testList = os.listdir(path)
for file in testList:
img = Image.open(path + file)
img = transform_test(img)
img.unsqueeze_(0)
img = Variable(img).to(DEVICE)
out = model(img)
# Predict
_, pred = torch.max(out.data, 1)
print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))
運(yùn)行結(jié)果:

以上就是CoAtNet實(shí)戰(zhàn)之對植物幼苗圖像進(jìn)行分類(pytorch)的詳細(xì)內(nèi)容,更多關(guān)于CoAtNet 植物幼苗圖像分類的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python grequests模塊使用場景及代碼實(shí)例
這篇文章主要介紹了Python grequests模塊使用場景及代碼實(shí)例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-08-08
Python數(shù)據(jù)可視化編程通過Matplotlib創(chuàng)建散點(diǎn)圖代碼示例
這篇文章主要介紹了Python數(shù)據(jù)可視化編程通過Matplotlib創(chuàng)建散點(diǎn)圖實(shí)例,具有一定借鑒價值,需要的朋友可以參考下。2017-12-12
Python鍵鼠操作自動化庫PyAutoGUI簡介(小結(jié))
這篇文章主要介紹了Python鍵鼠操作自動化庫PyAutoGUI簡介,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-05-05
吳恩達(dá)機(jī)器學(xué)習(xí)練習(xí):神經(jīng)網(wǎng)絡(luò)(反向傳播)
這篇文章主要介紹了學(xué)習(xí)吳恩達(dá)機(jī)器學(xué)習(xí)中的一個練習(xí):神經(jīng)網(wǎng)絡(luò)(反向傳播),在這個練習(xí)中,你將實(shí)現(xiàn)反向傳播算法來學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)的參數(shù),需要的朋友可以參考下2021-04-04
Python實(shí)現(xiàn)批量下載ts文件并合并為mp4
這篇文章主要為大家詳細(xì)介紹了如何通過Python語言實(shí)現(xiàn)批量下載ts文件并合并為mp4視頻的功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2023-06-06
Python利用LyScript插件實(shí)現(xiàn)批量打開關(guān)閉進(jìn)程
LyScript是一款x64dbg主動化操控插件,經(jīng)過Python操控X64dbg,完成了遠(yuǎn)程動態(tài)調(diào)試,解決了逆向工作者剖析漏洞,尋覓指令片段,原生腳本不行強(qiáng)壯的問題。本文將利用LyScript插件實(shí)現(xiàn)批量打開關(guān)閉進(jìn)程,感興趣的可以了解一下2022-07-07
python實(shí)戰(zhàn)項(xiàng)目scrapy管道學(xué)習(xí)爬取在行高手?jǐn)?shù)據(jù)
這篇文章主要為介紹了python實(shí)戰(zhàn)項(xiàng)目scrapy管道學(xué)習(xí)拿在行練手爬蟲項(xiàng)目,爬取在行高手?jǐn)?shù)據(jù),本篇博客的重點(diǎn)為scrapy管道pipelines的應(yīng)用,學(xué)習(xí)時請重點(diǎn)關(guān)注2021-11-11

