pytorch中關(guān)于distributedsampler函數(shù)的使用
關(guān)于distributedsampler函數(shù)的使用
1.如何使用這個分布式采樣器
在使用distributedsampler函數(shù)時,觀察loss發(fā)現(xiàn)loss收斂有規(guī)律,發(fā)現(xiàn)是按順序讀取數(shù)據(jù),未進(jìn)行shuffle。
問題的解決方式就是懷疑 seed 有問題,參考源碼 DistributedSampler,發(fā)現(xiàn) shuffle 的結(jié)果依賴 g.manual_seed(self.epoch) 中的 self.epoch。
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)而 self.epoch 初始默認(rèn)是 0
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle但是 DistributedSampler 也提供了一個 set 函數(shù)來改變 self.epoch
def set_epoch(self, epoch):
self.epoch = epoch所以在運(yùn)行的時候要不斷調(diào)用這個 set_epoch 函數(shù)。只要把我的代碼中的
# sampler.set_epoch(e)
全部代碼如下:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
torch.distributed.init_process_group(backend="nccl")
input_size = 5
output_size = 2
batch_size = 2
data_size = 16
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
def __init__(self, size, length, local_rank):
self.len = length
self.data = torch.stack([torch.ones(5), torch.ones(5)*2,
torch.ones(5)*3,torch.ones(5)*4,
torch.ones(5)*5,torch.ones(5)*6,
torch.ones(5)*7,torch.ones(5)*8,
torch.ones(5)*9, torch.ones(5)*10,
torch.ones(5)*11,torch.ones(5)*12,
torch.ones(5)*13,torch.ones(5)*14,
torch.ones(5)*15,torch.ones(5)*16]).to('cuda')
self.local_rank = local_rank
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
dataset = RandomDataset(input_size, data_size, local_rank)
sampler = DistributedSampler(dataset)
rand_loader = DataLoader(dataset=dataset,
batch_size=batch_size,
sampler=sampler)
e = 0
while e < 2:
t = 0
# sampler.set_epoch(e)
for data in rand_loader:
print(data)
e+=1運(yùn)行:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 test.py
2.關(guān)于用不用這個采樣器的區(qū)別
多卡去訓(xùn)模型,嘗試著用DDP模式,而不是DP模式去加速訓(xùn)練(很容易出現(xiàn)負(fù)載不均衡的情況)。
遇到了一點(diǎn)關(guān)于DistributedSampler這個采樣器的一點(diǎn)疑惑,想試驗(yàn)下在DDP模式下,使用這個采樣器和不使用這個采樣器有什么區(qū)別。
實(shí)驗(yàn)代碼:
整個數(shù)據(jù)集大小為8,batch_size 為4,總共跑2個epoch。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
torch.distributed.init_process_group(backend="nccl")
batch_size = 4
data_size = 8
local_rank = torch.distributed.get_rank()
print(local_rank)
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
def __init__(self, length, local_rank):
self.len = length
self.data = torch.stack([torch.ones(1), torch.ones(1)*2,torch.ones(1)*3,torch.ones(1)*4,torch.ones(1)*5,torch.ones(1)*6,torch.ones(1)*7,torch.ones(1)*8]).to('cuda')
self.local_rank = local_rank
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
dataset = RandomDataset(data_size, local_rank)
sampler = DistributedSampler(dataset)
#rand_loader =DataLoader(dataset=dataset,batch_size=batch_size,sampler=None,shuffle=True)
rand_loader = DataLoader(dataset=dataset,batch_size=batch_size,sampler=sampler)
epoch = 0
while epoch < 2:
sampler.set_epoch(epoch)
for data in rand_loader:
print(data)
epoch+=1運(yùn)行命令:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 test.py
實(shí)驗(yàn)結(jié)果:

結(jié)論分析:上面的運(yùn)行結(jié)果來看,在一個epoch中,sampler相當(dāng)于把整個數(shù)據(jù)集 劃分成了nproc_per_node份,每個GPU每次得到batch_size的數(shù)量,也就是nproc_per_node 個GPU分一整份數(shù)據(jù)集,總數(shù)據(jù)量大小就為1個dataset。
如果不用它里面自帶的sampler,單純的還是按照我們一般的形式。Sampler=None,shuffle=True這種,那么結(jié)果將會是下面這樣的:
結(jié)果分析:沒用sampler的話,在一個epoch中,每個GPU各自維護(hù)著一份數(shù)據(jù),每個GPU每次得到的batch_size的數(shù)據(jù),總的數(shù)據(jù)量為2個dataset,

總結(jié)
一般的形式的dataset只能在同進(jìn)程中進(jìn)行采樣分發(fā),也就是為什么圖2只能單GPU維護(hù)自己的dataset,DDP中的sampler可以對不同進(jìn)程進(jìn)行分發(fā)數(shù)據(jù),圖1,可以夸不同進(jìn)程(GPU)進(jìn)行分發(fā)。
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python同義詞替換的實(shí)現(xiàn)(jieba分詞)
這篇文章主要介紹了python同義詞替換的實(shí)現(xiàn)(jieba分詞),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01
Python數(shù)據(jù)分析之如何利用pandas查詢數(shù)據(jù)示例代碼
查詢和分析數(shù)據(jù)是pandas的重要功能,也是我們學(xué)習(xí)pandas的基礎(chǔ),下面這篇文章主要給大家介紹了關(guān)于Python數(shù)據(jù)分析之如何利用pandas查詢數(shù)據(jù)的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考借鑒,下面來一起看看吧。2017-09-09
Pandas 稀疏數(shù)據(jù)結(jié)構(gòu)的實(shí)現(xiàn)
如果數(shù)據(jù)中有很多NaN的值,存儲起來就會浪費(fèi)空間。為了解決這個問題,Pandas引入了一種叫做Sparse data的結(jié)構(gòu),來有效的存儲這些NaN的值,本文就來詳細(xì)的介紹了一下,感興趣的可以了解一下2021-07-07
python 普通克里金(Kriging)法的實(shí)現(xiàn)
這篇文章主要介紹了python 普通克里金(Kriging)法的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12
淺談python print(xx, flush = True) 全網(wǎng)最清晰的解釋
今天小編就為大家分享一篇淺談python print(xx, flush = True) 全網(wǎng)最清晰的解釋,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02
python GUI實(shí)現(xiàn)小球滿屏亂跑效果
這篇文章主要為大家詳細(xì)介紹了python GUI實(shí)現(xiàn)小球滿屏亂跑效果,具有一定的參考價值,感興趣的小伙伴們可以參考一下2019-05-05
Python通過pytesseract庫實(shí)現(xiàn)識別圖片中的文字
Pytesseract是一個Python的OCR庫,它可以識別圖片中的文本并將其轉(zhuǎn)換成文本形式。本文就來用pytesseract庫實(shí)現(xiàn)識別圖片中的文字,感興趣的可以了解一下2023-05-05

