詳解PyTorch預(yù)定義數(shù)據(jù)集類datasets.ImageFolder使用方法
datasets.ImageFolder是PyTorch提供的一個預(yù)定義數(shù)據(jù)集類,用于處理圖像數(shù)據(jù)。它可以方便地將一組圖像加載到內(nèi)存中,并為每個圖像分配標簽。
數(shù)據(jù)集準備和目錄結(jié)構(gòu)
要使用datasets.ImageFolder,我們需要準備好一個包含圖像數(shù)據(jù)的目錄,并按照以下方式進行組織:
root/
class1/
img1.jpg
img2.jpg
...
class2/
img1.jpg
img2.jpg
...
...
其中,root代表數(shù)據(jù)集根目錄,class1、class2等代表不同的分類標簽,img1、img2等代表圖像文件名。每個類別(也稱為標簽)應(yīng)該有一個單獨的子目錄,子目錄中包含這個類別的所有圖像文件。同時,每個圖像文件在對應(yīng)的子目錄下,以其文件名作為其類別標簽。這種目錄組織方式可以讓我們輕松獲取圖像和對應(yīng)的標簽信息。
加載數(shù)據(jù)集
完成數(shù)據(jù)集準備之后,我們就可以使用datasets.ImageFolder來加載它了。下面是一個示例代碼:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = "/path/to/data"
transforms = transforms.Compose([
transforms.Resize(size=(224, 224)),
transforms.ToTensor(),
])
dataset = datasets.ImageFolder(root=data_dir, transform=transforms)
在這個例子中,我們首先導(dǎo)入datasets和transforms模塊,然后指定數(shù)據(jù)集的根目錄data_dir。接下來,我們定義一個 transforms 對象,它將圖像轉(zhuǎn)換為PyTorch張量,并調(diào)整大小為(224, 224)。
最后,我們使用datasets.ImageFolder來加載圖像數(shù)據(jù)集。ImageFolder類需要兩個參數(shù):root 和 transform。root是數(shù)據(jù)集根目錄;transform指定對每個圖像應(yīng)該執(zhí)行的預(yù)處理操作,例如調(diào)整大小、裁剪、翻轉(zhuǎn)等。
數(shù)據(jù)集劃分
對于機器學(xué)習(xí)任務(wù),我們通常需要將數(shù)據(jù)集劃分成訓(xùn)練集、驗證集和測試集。在PyTorch中,我們可以使用torch.utils.data.random_split函數(shù)來完成數(shù)據(jù)集的劃分。下面是一個示例代碼:
from torch.utils.data import DataLoader, random_split # Split the dataset into train and test sets train_size = int(0.8 * len(dataset)) test_size = len(dataset) - train_size train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) # Split train dataset into train and validation sets val_size = int(0.2 * len(train_dataset)) train_size = len(train_dataset) - val_size train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
在這個例子中,我們先使用random_split函數(shù)將原始數(shù)據(jù)集劃分為訓(xùn)練集和測試集,在這里80%的數(shù)據(jù)用于訓(xùn)練,20%的數(shù)據(jù)用于測試。然后,我們再次使用random_split函數(shù)將訓(xùn)練集劃分為訓(xùn)練集和驗證集,其中80%的數(shù)據(jù)用于訓(xùn)練,20%的數(shù)據(jù)用于驗證。
數(shù)據(jù)加載器
最后,我們可以使用數(shù)據(jù)加載器(DataLoader)來加載數(shù)據(jù)集。數(shù)據(jù)加載器負責(zé)將圖像數(shù)據(jù)和標簽封裝成批量,并提供多線程方式加載數(shù)據(jù)以加速訓(xùn)練過程。下面是一個示例代碼:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
在這里,我們創(chuàng)建了三個數(shù)據(jù)加載器train_loader、val_loader 和 test_loader,它們分別對應(yīng)訓(xùn)練集、驗證集和測試集。batch_size參數(shù)指定了每個批次的大小,shuffle參數(shù)表示是否隨機化輸入數(shù)據(jù)(在訓(xùn)練集中設(shè)置為True,在驗證集和測試集中設(shè)置為False)。
以上就是詳解PyTorch預(yù)定義數(shù)據(jù)集類datasets.ImageFolder使用方法的詳細內(nèi)容,更多關(guān)于PyTorch datasets.ImageFolder的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python使用pywinauto驅(qū)動微信客戶端實現(xiàn)公眾號爬蟲
這個項目是通過pywinauto控制windows(win10)上的微信PC客戶端來實現(xiàn)公眾號文章的抓取。代碼分成server和client兩部分。server接收client抓取的微信公眾號文章,并且保存到數(shù)據(jù)庫。另外server支持簡單的搜索和導(dǎo)出功能。client通過pywinauto實現(xiàn)微信公眾號文章的抓取。2021-05-05
pycharm配置python 設(shè)置pip安裝源為豆瓣源
這篇文章主要介紹了pycharm配置python 設(shè)置pip安裝源為豆瓣源,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-02-02

