關(guān)于pytorch處理類別不平衡的問題
當(dāng)訓(xùn)練樣本不均勻時(shí),我們可以采用過采樣、欠采樣、數(shù)據(jù)增強(qiáng)等手段來避免過擬合。今天遇到一個(gè)3d點(diǎn)云數(shù)據(jù)集合,樣本分布極不均勻,正例與負(fù)例相差4-5個(gè)數(shù)量級(jí)。數(shù)據(jù)增強(qiáng)效果就不會(huì)太好了,另外過采樣也不太合適,因?yàn)槭强臻g數(shù)據(jù),新增的點(diǎn)有可能會(huì)對(duì)真實(shí)分布產(chǎn)生未知影響。所以采用欠采樣來緩解類別不平衡的問題。
下面的代碼展示了如何使用WeightedRandomSampler來完成抽樣。
numDataPoints = 1000
data_dim = 5
bs = 100
# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
np.ones(int(numDataPoints * 0.1), dtype=np.int32)))
print 'target train 0/1: {}/{}'.format(
len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))
class_sample_count = np.array(
[len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
train_dataset, batch_size=bs, num_workers=1, sampler=sampler)
for i, (data, target) in enumerate(train_loader):
print "batch index {}, 0/1: {}/{}".format(
i,
len(np.where(target.numpy() == 0)[0]),
len(np.where(target.numpy() == 1)[0]))
核心部分為實(shí)際使用時(shí)替換下變量把sampler傳遞給DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采樣點(diǎn)個(gè)數(shù):
class_sample_count = np.array( [len(np.where(target == t)[0]) for t in np.unique(target)]) weight = 1. / class_sample_count samples_weight = np.array([weight[t] for t in target]) samples_weight = torch.from_numpy(samples_weight) samples_weight = samples_weight.double() sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
參考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
以上這篇關(guān)于pytorch處理類別不平衡的問題就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
基于Python實(shí)現(xiàn)PDF動(dòng)畫翻頁效果的閱讀器
在這篇博客中,我們將深入分析一個(gè)基于 wxPython 實(shí)現(xiàn)的 PDF 閱讀器程序,該程序支持加載 PDF 文件并顯示頁面內(nèi)容,同時(shí)支持頁面切換動(dòng)畫效果,文中有詳細(xì)的代碼示例,需要的朋友可以參考下2025-01-01
Python輕松實(shí)現(xiàn)2位小數(shù)隨機(jī)生成
在Python中,我們經(jīng)常需要生成隨機(jī)數(shù),特別是2位小數(shù)的隨機(jī)數(shù),這在模擬實(shí)驗(yàn)、密碼學(xué)、游戲開發(fā)等領(lǐng)域都很有用,下面是如何在Python中生成2位小數(shù)的隨機(jī)數(shù)的代碼示例,需要的朋友可以參考下2023-11-11
windows下安裝python paramiko模塊的代碼
windows下安裝python paramiko模塊,有需要的朋友可以參考下2013-02-02
詳解matplotlib中pyplot和面向?qū)ο髢煞N繪圖模式之間的關(guān)系
這篇文章主要介紹了詳解matplotlib中pyplot和面向?qū)ο髢煞N繪圖模式之間的關(guān)系,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-01-01
利用Python實(shí)現(xiàn)定時(shí)程序的方法
在 Python 中,如何定義一個(gè)定時(shí)器函數(shù)呢?本文主要介紹了2種方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-07-07
Boston數(shù)據(jù)集預(yù)測(cè)放假及應(yīng)用優(yōu)缺點(diǎn)評(píng)估
這篇文章主要為大家介紹了Boston數(shù)據(jù)集預(yù)測(cè)放假及應(yīng)用優(yōu)缺點(diǎn)評(píng)估,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-10-10

