Pytorch訓(xùn)練模型得到輸出后計(jì)算F1-Score 和AUC的操作
1、計(jì)算F1-Score
對(duì)于二分類來說,假設(shè)batch size 大小為64的話,那么模型一個(gè)batch的輸出應(yīng)該是torch.size([64,2]),所以首先做的是得到這個(gè)二維矩陣的每一行的最大索引值,然后添加到一個(gè)列表中,同時(shí)把標(biāo)簽也添加到一個(gè)列表中,最后使用sklearn中計(jì)算F1的工具包進(jìn)行計(jì)算,代碼如下
import numpy as np
import sklearn.metrics import f1_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
prob = model(data) #表示模型的預(yù)測(cè)輸出
prob = prob.cpu().numpy() #先把prob轉(zhuǎn)到CPU上,然后再轉(zhuǎn)成numpy,如果本身在CPU上訓(xùn)練的話就不用先轉(zhuǎn)成CPU了
prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引
label_all.extend(label)
print("F1-Score:{:.4f}".format(f1_score(label_all,prob_all)))
2、計(jì)算AUC
計(jì)算AUC的時(shí)候,本次使用的是sklearn中的roc_auc_score () 方法
輸入?yún)?shù):
y_true:真實(shí)的標(biāo)簽。形狀 (n_samples,) 或 (n_samples, n_classes)。二分類的形狀 (n_samples,1),而多標(biāo)簽情況的形狀 (n_samples, n_classes)。
y_score:目標(biāo)分?jǐn)?shù)。形狀 (n_samples,) 或 (n_samples, n_classes)。二分類情況形狀 (n_samples,1),“分?jǐn)?shù)必須是具有較大標(biāo)簽的類的分?jǐn)?shù)”,通俗點(diǎn)理解:模型打分的第二列。舉個(gè)例子:模型輸入的得分是一個(gè)數(shù)組 [0.98361117 0.01638886],索引是其類別,這里 “較大標(biāo)簽類的分?jǐn)?shù)”,指的是索引為 1 的分?jǐn)?shù):0.01638886,也就是正例的預(yù)測(cè)得分。
average='macro':二分類時(shí),該參數(shù)可以忽略。用于多分類,' micro ':將標(biāo)簽指標(biāo)矩陣的每個(gè)元素看作一個(gè)標(biāo)簽,計(jì)算全局的指標(biāo)。' macro ':計(jì)算每個(gè)標(biāo)簽的指標(biāo),并找到它們的未加權(quán)平均值。這并沒有考慮標(biāo)簽的不平衡。' weighted ':計(jì)算每個(gè)標(biāo)簽的指標(biāo),并找到它們的平均值,根據(jù)支持度 (每個(gè)標(biāo)簽的真實(shí)實(shí)例的數(shù)量) 進(jìn)行加權(quán)。
sample_weight=None:樣本權(quán)重。形狀 (n_samples,),默認(rèn) = 無。
max_fpr=None:
multi_class='raise':(多分類的問題在下一篇文章中解釋)
labels=None:
輸出:
auc:是一個(gè) float 的值。
import numpy as np
import sklearn.metrics import roc_auc_score
prob_all = []
lable_all = []
for i, (data,label) in tqdm(train_data_loader):
prob = model(data) #表示模型的預(yù)測(cè)輸出
prob_all.extend(prob[:,1].cpu().numpy()) #prob[:,1]返回每一行第二列的數(shù),根據(jù)該函數(shù)的參數(shù)可知,y_score表示的較大標(biāo)簽類的分?jǐn)?shù),因此就是最大索引對(duì)應(yīng)的那個(gè)值,而不是最大索引值
label_all.extend(label)
print("AUC:{:.4f}".format(roc_auc_score(label_all,prob_all)))
補(bǔ)充:pytorch訓(xùn)練模型的一些坑
1. 圖像讀取
opencv的python和c++讀取的圖像結(jié)果不一致,是因?yàn)閜ython和c++采用的opencv版本不一樣,從而使用的解碼庫(kù)不同,導(dǎo)致讀取的結(jié)果不同。
2. 圖像變換
PIL和pytorch的圖像resize操作,與opencv的resize結(jié)果不一樣,這樣會(huì)導(dǎo)致訓(xùn)練采用PIL,預(yù)測(cè)時(shí)采用opencv,結(jié)果差別很大,尤其是在檢測(cè)和分割任務(wù)中比較明顯。
3. 數(shù)值計(jì)算
pytorch的torch.exp與c++的exp計(jì)算,10e-6的數(shù)值時(shí)候會(huì)有10e-3的誤差,對(duì)于高精度計(jì)算需要特別注意,比如
兩個(gè)輸入5.601597, 5.601601, 經(jīng)過exp計(jì)算后變成270.85862343143174, 270.85970686809225
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Pytorch隨機(jī)數(shù)生成常用的4種方法匯總
隨機(jī)數(shù)廣泛應(yīng)用在科學(xué)研究,但是計(jì)算機(jī)無法產(chǎn)生真正的隨機(jī)數(shù),一般成為偽隨機(jī)數(shù),下面這篇文章主要給大家介紹了關(guān)于Pytorch隨機(jī)數(shù)生成常用的4種方法,需要的朋友可以參考下2023-05-05
利用Opencv實(shí)現(xiàn)圖片的油畫特效實(shí)例
這篇文章主要給大家介紹了關(guān)于利用Opencv實(shí)現(xiàn)圖片的油畫特效的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-02-02
Python寫的Socks5協(xié)議代理服務(wù)器
這篇文章主要介紹了Python寫的Socks5協(xié)議代理服務(wù)器,代碼來自網(wǎng)上,需要的朋友可以參考下2014-08-08
Python使用enumerate獲取迭代元素下標(biāo)
這篇文章主要介紹了python使用enumerate獲取迭代元素下標(biāo),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-02-02
基于Python實(shí)現(xiàn)將列表數(shù)據(jù)生成折線圖
這篇文章主要介紹了如何利用Python中的pandas庫(kù)和matplotlib庫(kù),實(shí)現(xiàn)將列表數(shù)據(jù)生成折線圖,文中的示例代碼簡(jiǎn)潔易懂,需要的可以參考一下2022-03-03
Flask與FastAPI對(duì)比選擇最佳Python?Web框架的超詳細(xì)指南
Flask和FastAPI都是流行的Python?Web框架,各有特點(diǎn),Flask輕量級(jí)、靈活,適合小型項(xiàng)目和原型開發(fā)但不支持異步操作,FastAPI高性能、支持異步,內(nèi)置數(shù)據(jù)驗(yàn)證和自動(dòng)生成API文檔,適合高并發(fā)和API開發(fā),需要的朋友可以參考下2025-02-02
Python循環(huán)語(yǔ)句之while循環(huán)和for循環(huán)詳解
在Python中,循環(huán)語(yǔ)句用于重復(fù)執(zhí)行一段代碼,直到滿足某個(gè)條件為止,在Python中,有兩種主要的循環(huán)語(yǔ)句:for循環(huán)和while循環(huán),本文就來給大家介紹一下這兩個(gè)循環(huán)的用法,需要的朋友可以參考下2023-08-08
Python+django實(shí)現(xiàn)文件下載
本文是python+django系列的第二篇文章,主要是講述是先文件下載的方法和代碼,有需要的小伙伴可以參考下。2016-01-01

