PyTorch K-fold cross-validation: process explanation and implementation
The K-fold cross-validation process is as follows.
Take 200 samples and 10-fold cross-validation as an example. Ten folds means the data is split into 10 groups and 10 training runs are performed. Each run uses (total samples / number of folds) samples for validation, i.e. 20 samples for valid and 180 for train, and a different slice is used for valid each time.
(1) Split the 200 samples into folds of size (total samples / number of folds). Take the i-th fold as the validation set of the i-th run, and the remaining folds as the training set.
(2) Wrap each run's training data with Dataset and DataLoader.
(3) Train on the training data (the number of epochs is up to you), then evaluate on the validation data to obtain that run's train_loss and valid_loss.
(4) Repeat steps (2) and (3) for all k folds to obtain the final average_train_loss and average_valid_loss.
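As a quick sanity check of this fold arithmetic, here is a minimal sketch using sklearn's KFold (an alternative to the manual split implemented below; it assumes scikit-learn is available):

from sklearn.model_selection import KFold
import numpy as np

data = np.arange(200)  # 200 samples, as in the example above
kf = KFold(n_splits=10)
for train_idx, valid_idx in kf.split(data):
    print(len(train_idx), len(valid_idx))  # prints "180 20" for each of the 10 folds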
The process above is shown in the figure below:
[figure: K-fold cross-validation split diagram]
The corresponding code is as follows:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

##### Construct a toy training set #####
x = torch.rand(100, 28, 28)
y = torch.randn(100, 28, 28)
x = torch.cat((x, y), dim=0)  # 200 samples in total
label = [1] * 100 + [0] * 100  # first half labeled 1, second half labeled 0
label = torch.tensor(label, dtype=torch.long)
########## Network structure ##########
class Net(nn.Module):
    # A small fully connected classifier
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        x = x.view(-1, self.num_flat_features(x))  # flatten each sample
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
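# Illustrative shape check (not in the original): forward() flattens each 28x28
# sample into 784 features, so a (5, 28, 28) batch yields (5, 2) class scores:
#   out = Net()(torch.rand(5, 28, 28))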
########## Define the Dataset ##########
class TraindataSet(Dataset):
    def __init__(self, train_features, train_labels):
        self.x_data = train_features
        self.y_data = train_labels
        self.len = len(train_labels)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
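# Example usage (illustrative, not in the original), showing step (2): wrap the
# tensors in TraindataSet and hand them to a DataLoader:
#   ds = TraindataSet(x, label)
#   loader = DataLoader(ds, batch_size=5, shuffle=True)
#   xb, yb = next(iter(loader))  # xb: (5, 28, 28), yb: (5,)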
######## k-fold split ############
def get_k_fold_data(k, i, X, y):  # this implements step (1)
    # Return the data needed for the i-th fold of cross-validation:
    # X_train/y_train for training, X_valid/y_valid for validation.
    assert k > 1
    fold_size = X.shape[0] // k  # samples per fold: total samples // number of folds
    X_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size)  # slice(start, end, step)
        # idx selects the j-th fold
        X_part, y_part = X[idx, :], y[idx]
        if j == i:  # the i-th fold becomes the validation set
            X_valid, y_valid = X_part, y_part
        elif X_train is None:
            X_train, y_train = X_part, y_part
        else:
            X_train = torch.cat((X_train, X_part), dim=0)  # dim=0: concatenate along rows
            y_train = torch.cat((y_train, y_part), dim=0)
    # print(X_train.size(), X_valid.size())
    return X_train, y_train, X_valid, y_valid
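# Quick shape check (illustrative, not in the original): with k=10 on the 200
# samples above, every fold yields 180 training and 20 validation samples:
#   X_tr, y_tr, X_va, y_va = get_k_fold_data(10, 0, x, label)
#   print(X_tr.shape, X_va.shape)  # torch.Size([180, 28, 28]) torch.Size([20, 28, 28])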
def k_fold(k, X_train, y_train, num_epochs=3, learning_rate=0.001, weight_decay=0.1, batch_size=5):
    train_loss_sum, valid_loss_sum = 0, 0
    train_acc_sum, valid_acc_sum = 0, 0
    for i in range(k):
        data = get_k_fold_data(k, i, X_train, y_train)  # get the i-th fold's train/valid split
        net = Net()  # re-instantiate the model for every fold
        ### train on this fold's data: step (3) ###
        train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                   weight_decay, batch_size)
        print('*' * 25, 'fold', i + 1, '*' * 25)
        print('train_loss:%.6f' % train_ls[-1][0], 'train_acc:%.4f\n' % train_ls[-1][1],
              'valid loss:%.6f' % valid_ls[-1][0], 'valid_acc:%.4f' % valid_ls[-1][1])
        train_loss_sum += train_ls[-1][0]
        valid_loss_sum += valid_ls[-1][0]
        train_acc_sum += train_ls[-1][1]
        valid_acc_sum += valid_ls[-1][1]
    print('#' * 10, 'final k-fold cross-validation result', '#' * 10)
    #### average over folds: step (4) ####
    print('average train loss:%.4f' % (train_loss_sum / k), 'average train acc:%.4f\n' % (train_acc_sum / k),
          'average valid loss:%.4f' % (valid_loss_sum / k), 'average valid acc:%.4f' % (valid_acc_sum / k))
######### Training function ##########
def train(net, train_features, train_labels, test_features, test_labels,
          num_epochs, learning_rate, weight_decay, batch_size):
    train_ls, test_ls = [], []  # store per-epoch (loss, accuracy) for train and valid
    dataset = TraindataSet(train_features, train_labels)
    train_iter = DataLoader(dataset, batch_size, shuffle=True)
    # wrap the data in Dataset/DataLoader: step (2)
    # Adam is used as the optimizer here
    optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate, weight_decay=weight_decay)
    for epoch in range(num_epochs):
        for X, y in train_iter:  # mini-batch training
            output = net(X)
            loss = loss_func(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # record loss and accuracy for this epoch
        train_ls.append(log_rmse(0, net, train_features, train_labels))
        if test_labels is not None:
            test_ls.append(log_rmse(1, net, test_features, test_labels))
    # print(train_ls, test_ls)
    return train_ls, test_ls
def log_rmse(flag, net, x, y):
    # Despite its name, this returns (cross-entropy loss, accuracy) on (x, y)
    if flag == 1:  # validation set: switch to eval mode
        net.eval()
    output = net(x)
    result = torch.max(output, 1)[1].view(y.size())
    corrects = (result.data == y.data).sum().item()
    accuracy = corrects * 100.0 / len(y)  # len(y) is the number of samples evaluated
    loss = loss_func(output, y)
    net.train()  # switch back to training mode
    return (loss.data.item(), accuracy)

loss_func = nn.CrossEntropyLoss()  # declare the loss function
k_fold(10, x, label)  # k=10: 10-fold cross-validation

In the code above, each 20-sample validation set is sliced from x in sequential order. You can also shuffle the data first and then slice, which should give better results. As shown below:
import random
import torch

x = torch.rand(100, 28, 28)
y = torch.randn(100, 28, 28)
x = torch.cat((x, y), dim=0)
label = [1] * 100 + [0] * 100
label = torch.tensor(label, dtype=torch.long)
index = [i for i in range(len(x))]
random.shuffle(index)
x = x[index]
label = label[index]
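An equivalent and slightly more idiomatic shuffle, sketched here with torch.randperm in place of a Python-side index list:

perm = torch.randperm(len(x))  # random permutation of 0 .. len(x)-1
x = x[perm]
label = label[perm]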
Analysis of the k-fold splitting code used in cross-validation
import numpy as np
from sklearn.model_selection import GroupKFold

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
z = np.array(['hello1', 'hello2', 'hello3', 'hello4', 'hello5',
              'hello6', 'hello7', 'hello8', 'hello9', 'hello10'])
gkf = GroupKFold(n_splits=5)
for i, (train_idx, valid_idx) in enumerate(list(gkf.split(x, y, z))):
    # For GroupKFold the third argument (groups) determines the split:
    # samples that share a group never appear in both train and valid.
    print('train_idx = ')
    print(train_idx)
    print('valid_idx = ')
    print(valid_idx)
You can see that, first, the train_idx and valid_idx values are drawn in non-sequential order, and second, each index is drawn into a validation set exactly once and never repeated.
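To make the role of the groups argument concrete, here is a small illustrative sketch (the group labels 'a' through 'e' are made up for this example); samples that share a group label always land on the same side of the split:

import numpy as np
from sklearn.model_selection import GroupKFold

x = np.arange(10)
y = np.arange(10)
groups = np.array(['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd', 'e', 'e'])
gkf = GroupKFold(n_splits=5)
for train_idx, valid_idx in gkf.split(x, y, groups):
    print(valid_idx)  # each pair sharing a group, e.g. [0 1], stays together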
Note the cross-validation workflow, illustrated by the figure below:
[figure: cross-validation workflow diagram]
Note that the training scheme here trains a freshly initialized model on each of the n folds and then computes the corresponding weights.
In other words, every time the weights for fold 1 through n are computed, the model's weights must be re-initialized before training. After all folds have finished training, the final model weights are taken as the average over the trained models, which is similar in spirit to model fusion (ensembling).
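A minimal sketch of this weight-averaging step (one possible interpretation; average_state_dicts is a hypothetical helper written for this example, not part of PyTorch, and it assumes all k models share the Net architecture above):

def average_state_dicts(state_dicts):
    # Element-wise average of the parameters of k models trained on different folds
    avg = {}
    for key in state_dicts[0]:
        avg[key] = sum(sd[key].float() for sd in state_dicts) / len(state_dicts)
    return avg

# Usage: collect net.state_dict() after each fold into a list `collected`, then:
#   fused = Net()
#   fused.load_state_dict(average_state_dicts(collected))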
Summary
The above is my personal experience; I hope it can serve as a reference, and I hope you will continue to support 腳本之家.