PyTorch 如何將CIFAR100數(shù)據(jù)按類(lèi)標(biāo)歸類(lèi)保存
few-shot learning的采樣
Few-shot learning 基于任務(wù)對(duì)模型進(jìn)行訓(xùn)練,在N-way-K-shot中,一個(gè)任務(wù)中的meta-training中含有N類(lèi),每一類(lèi)抽取K個(gè)樣本構(gòu)成support set, query set則是在剛才抽取的N類(lèi)剩余的樣本中sample一定數(shù)量的樣本(可以是均勻采樣,也可以是不均勻采樣)。
對(duì)數(shù)據(jù)按類(lèi)標(biāo)歸類(lèi)
針對(duì)上述情況,我們需要使用不同類(lèi)別放置在不同文件夾的數(shù)據(jù)集。但有時(shí),數(shù)據(jù)并沒(méi)有按類(lèi)放置,這時(shí)就需要對(duì)數(shù)據(jù)進(jìn)行處理。
下面以CIFAR100為列(不含N-way-k-shot的采樣):
import os
from skimage import io
import torchvision as tv
import numpy as np
import torch
def Cifar100(root):
character = [[] for i in range(100)]
train_set = tv.datasets.CIFAR100(root, train=True, download=True)
test_set = tv.datasets.CIFAR100(root, train=False, download=True)
dataset = []
for (X, Y) in zip(train_set.train_data, train_set.train_labels): # 將train_set的數(shù)據(jù)和label讀入列表
dataset.append(list((X, Y)))
for (X, Y) in zip(test_set.test_data, test_set.test_labels): # 將test_set的數(shù)據(jù)和label讀入列表
dataset.append(list((X, Y)))
for X, Y in dataset:
character[Y].append(X) # 32*32*3
character = np.array(character)
character = torch.from_numpy(character)
# 按類(lèi)打亂
np.random.seed(6)
shuffle_class = np.arange(len(character))
np.random.shuffle(shuffle_class)
character = character[shuffle_class]
# shape = self.character.shape
# self.character = self.character.view(shape[0], shape[1], shape[4], shape[2], shape[3]) # 將數(shù)據(jù)轉(zhuǎn)成channel在前
meta_training, meta_validation, meta_testing = \
character[:64], character[64:80], character[80:] # meta_training : meta_validation : Meta_testing = 64類(lèi):16類(lèi):20類(lèi)
dataset = [] # 釋放內(nèi)存
character = []
os.mkdir(os.path.join(root, 'meta_training'))
for i, per_class in enumerate(meta_training):
character_path = os.path.join(root, 'meta_training', 'character_' + str(i))
os.mkdir(character_path)
for j, img in enumerate(per_class):
img_path = character_path + '/' + str(j) + ".jpg"
io.imsave(img_path, img)
os.mkdir(os.path.join(root, 'meta_validation'))
for i, per_class in enumerate(meta_validation):
character_path = os.path.join(root, 'meta_validation', 'character_' + str(i))
os.mkdir(character_path)
for j, img in enumerate(per_class):
img_path = character_path + '/' + str(j) + ".jpg"
io.imsave(img_path, img)
os.mkdir(os.path.join(root, 'meta_testing'))
for i, per_class in enumerate(meta_testing):
character_path = os.path.join(root, 'meta_testing', 'character_' + str(i))
os.mkdir(character_path)
for j, img in enumerate(per_class):
img_path = character_path + '/' + str(j) + ".jpg"
io.imsave(img_path, img)
if __name__ == '__main__':
root = '/home/xie/文檔/datasets/cifar_100'
Cifar100(root)
print("-----------------")
補(bǔ)充:使用Pytorch對(duì)數(shù)據(jù)集CIFAR-10進(jìn)行分類(lèi)
主要是以下幾個(gè)步驟:
1、下載并預(yù)處理數(shù)據(jù)集
2、定義網(wǎng)絡(luò)結(jié)構(gòu)
3、定義損失函數(shù)和優(yōu)化器
4、訓(xùn)練網(wǎng)絡(luò)并更新參數(shù)
5、測(cè)試網(wǎng)絡(luò)效果
#數(shù)據(jù)加載和預(yù)處理
#使用CIFAR-10數(shù)據(jù)進(jìn)行分類(lèi)實(shí)驗(yàn)
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor轉(zhuǎn)成Image,方便可視化
#定義對(duì)數(shù)據(jù)的預(yù)處理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), #歸一化
])
#訓(xùn)練集
trainset = tv.datasets.CIFAR10(
root = './data/',
train = True,
download = True,
transform = transform
)
trainloader = t.utils.data.DataLoader(
trainset,
batch_size = 4,
shuffle = True,
num_workers = 2,
)
#測(cè)試集
testset = tv.datasets.CIFAR10(
root = './data/',
train = False,
download = True,
transform = transform,
)
testloader = t.utils.data.DataLoader(
testset,
batch_size = 4,
shuffle = False,
num_workers = 2,
)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
初次下載需要一些時(shí)間,運(yùn)行結(jié)束后,顯示如下:
![]()
import torch.nn as nn
import torch.nn.functional as F
import time
start = time.time()#計(jì)時(shí)
#定義網(wǎng)絡(luò)結(jié)構(gòu)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),2)
x = F.max_pool2d(F.relu(self.conv2(x)),2)
x = x.view(x.size()[0],-1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
顯示net結(jié)構(gòu)如下:
#定義優(yōu)化和損失
loss_func = nn.CrossEntropyLoss() #交叉熵?fù)p失函數(shù)
optimizer = t.optim.SGD(net.parameters(),lr = 0.001,momentum = 0.9)
#訓(xùn)練網(wǎng)絡(luò)
for epoch in range(2):
running_loss = 0
for i,data in enumerate(trainloader,0):
inputs,labels = data
outputs = net(inputs)
loss = loss_func(outputs,labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss +=loss.item()
if i%2000 ==1999:
print('epoch:',epoch+1,'|i:',i+1,'|loss:%.3f'%(running_loss/2000))
running_loss = 0.0
end = time.time()
time_using = end - start
print('finish training')
print('time:',time_using)
結(jié)果如下:

下一步進(jìn)行使用測(cè)試集進(jìn)行網(wǎng)絡(luò)測(cè)試:
#測(cè)試網(wǎng)絡(luò)
correct = 0 #定義的預(yù)測(cè)正確的圖片數(shù)
total = 0#總共圖片個(gè)數(shù)
with t.no_grad():
for data in testloader:
images,labels = data
outputs = net(images)
_,predict = t.max(outputs,1)
total += labels.size(0)
correct += (predict == labels).sum()
print('測(cè)試集中的準(zhǔn)確率為:%d%%'%(100*correct/total))
結(jié)果如下:
![]()
簡(jiǎn)單的網(wǎng)絡(luò)訓(xùn)練確實(shí)要比10%的比例高一點(diǎn):)
在GPU中訓(xùn)練:
#在GPU中訓(xùn)練
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
net.to(device)
images = images.to(device)
labels = labels.to(device)
output = net(images)
loss = loss_func(output,labels)
loss
![]()
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。
相關(guān)文章
python實(shí)現(xiàn)多線程的方式及多條命令并發(fā)執(zhí)行
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)多線程的方式及多條命令并發(fā)執(zhí)行,感興趣的小伙伴們可以參考一下2016-06-06
python刪掉重復(fù)行之drop_duplicates()用法示例
Pandas的drop_duplicates()方法用于從DataFrame中刪除重復(fù)的行,這篇文章主要給大家介紹了關(guān)于python刪掉重復(fù)行之drop_duplicates()用法的相關(guān)資料,文中通過(guò)代碼介紹的非常詳細(xì),需要的朋友可以參考下2024-08-08
用python實(shí)現(xiàn)一個(gè)讓人戒不掉的百變款消消樂(lè)
消消樂(lè)的熱門(mén)程度幾乎趕上王者榮耀,你是否也有收到過(guò)好友邀請(qǐng)你幫解鎖關(guān)卡的時(shí)候,今天小編帶你用python編寫(xiě)一個(gè)自己的消消樂(lè)升級(jí)版,同學(xué)請(qǐng)往下看2021-09-09
Python實(shí)現(xiàn)解析yaml配置文件的示例詳解
在開(kāi)發(fā)過(guò)程中,配置文件是少不了的,而且配置文件是有專(zhuān)門(mén)的格式的,比如:ini,yaml,toml等等。本文帶大家來(lái)看看Python如何解析yaml文件,它的表達(dá)能力相比?ini?更加的強(qiáng)大,需要的可以參考一下2022-09-09
Python enumerate() 函數(shù)如何實(shí)現(xiàn)索引功能
這篇文章主要介紹了Python enumerate() 函數(shù)如何實(shí)現(xiàn)索引功能,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06
Python基礎(chǔ)之教你怎么在M1系統(tǒng)上使用pandas
這篇文章主要介紹了Python基礎(chǔ)之教你怎么在M1系統(tǒng)上使用pandas,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python基礎(chǔ)的小伙伴們有很好地幫助,需要的朋友可以參考下2021-05-05
matplotlib教程——強(qiáng)大的python作圖工具庫(kù)
這篇文章主要介紹了python matplotlib的相關(guān)資料,幫助大家更好的利用python matplotlib庫(kù)繪制圖表,感興趣的朋友可以了解下2020-10-10
Win7下Python與Tensorflow-CPU版開(kāi)發(fā)環(huán)境的安裝與配置過(guò)程
這篇文章主要介紹了Win7下Python與Tensorflow-CPU版安裝與配置心得,需要的朋友可以參考下2018-01-01
python多線程實(shí)現(xiàn)TCP服務(wù)端
這篇文章主要為大家詳細(xì)介紹了python多線程實(shí)現(xiàn)TCP服務(wù)端,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-09-09

