返回最大值的index pytorch方式
返回最大值的index
import torch a=torch.tensor([[.1,.2,.3], ? ? ? ? ? ? ? ? [1.1,1.2,1.3], ? ? ? ? ? ? ? ? [2.1,2.2,2.3], ? ? ? ? ? ? ? ? [3.1,3.2,3.3]]) print(a.argmax(dim=1)) print(a.argmax())
輸出:
tensor([ 2, 2, 2, 2])
tensor(11)
pytorch 找最大值
題意:使用神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn),從數(shù)組中找出最大值。
提供數(shù)據(jù):兩個(gè) csv 文件,一個(gè)存訓(xùn)練集:n 個(gè) m 維特征自然數(shù)數(shù)據(jù),另一個(gè)存每條數(shù)據(jù)對應(yīng)的 label ,就是每條數(shù)據(jù)中的最大值。
這里將隨機(jī)構(gòu)建訓(xùn)練集:
#%%
import numpy as np
import pandas as pd
import torch
import random
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
def GetData(m, n):
dataset = []
for j in range(m):
max_v = random.randint(0, 9)
data = [random.randint(0, 9) for i in range(n)]
dataset.append(data)
label = [max(dataset[i]) for i in range(len(dataset))]
data_list = np.column_stack((dataset, label))
data_list = data_list.astype(np.float32)
return data_list
#%%
# 數(shù)據(jù)集封裝 重載函數(shù)len, getitem
class GetMaxEle(Data.Dataset):
def __init__(self, trainset):
self.data = trainset
def __getitem__(self, index):
item = self.data[index]
x = item[:-1]
y = item[-1]
return x, y
def __len__(self):
return len(self.data)
# %% 定義網(wǎng)絡(luò)模型
class SingleNN(nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(SingleNN, self).__init__()
self.hidden = nn.Linear(n_feature, n_hidden)
self.relu = nn.ReLU()
self.predict = nn.Linear(n_hidden, n_output)
def forward(self, x):
x = self.hidden(x)
x = self.relu(x)
x = self.predict(x)
return x
def train(m, n, batch_size, PATH):
# 隨機(jī)生成 m 個(gè) n 個(gè)維度的訓(xùn)練樣本
data_list =GetData(m, n)
dataset = GetMaxEle(data_list)
trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True)
net = SingleNN(n_feature=10, n_hidden=100,
n_output=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#
total_epoch = 100
for epoch in range(total_epoch):
for index, data in enumerate(trainset):
input_x, labels = data
labels = labels.long()
optimizer.zero_grad()
output = net(input_x)
# print(output)
# print(labels)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
# scheduled_optimizer.step()
print(f"Epoch {epoch}, loss:{loss.item()}")
# %% 保存參數(shù)
torch.save(net.state_dict(), PATH)
#測試
def test(m, n, batch_size, PATH):
data_list = GetData(m, n)
dataset = GetMaxEle(data_list)
testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
dataiter = iter(testloader)
input_x, labels = dataiter.next()
net = SingleNN(n_feature=10, n_hidden=100,
n_output=10)
net.load_state_dict(torch.load(PATH))
outputs = net(input_x)
_, predicted = torch.max(outputs, 1)
print("Ground_truth:",labels.numpy())
print("predicted:",predicted.numpy())
if __name__ == "__main__":
m = 1000
n = 10
batch_size = 64
PATH = './max_list.pth'
train(m, n, batch_size, PATH)
test(m, n, batch_size, PATH)初始的想法是使用全連接網(wǎng)絡(luò)+分類來實(shí)現(xiàn), 但是結(jié)果不盡人意,主要原因:不同類別之間的樣本量差太大,幾乎90%都是最大值。
比如代碼中隨機(jī)構(gòu)建 10 個(gè) 0~9 的數(shù)字構(gòu)成一個(gè)樣本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 該樣本標(biāo)簽是9。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
使用python把Excel中的數(shù)據(jù)在頁面中可視化
最近學(xué)習(xí)數(shù)據(jù)分析,感覺Python做數(shù)據(jù)分析真的好用,下面這篇文章主要給大家介紹了關(guān)于如何使用python把Excel中的數(shù)據(jù)在頁面中可視化的相關(guān)資料,需要的朋友可以參考下2022-03-03
Python使用ThreadPoolExecutor一次開啟多個(gè)線程
通過使用ThreadPoolExecutor,您可以同時(shí)開啟多個(gè)線程,從而提高程序的并發(fā)性能,本文就來介紹一下Python使用ThreadPoolExecutor一次開啟多個(gè)線程,感興趣的可以了解一下2023-11-11
Python實(shí)現(xiàn)TCP探測目標(biāo)服務(wù)路由軌跡的原理與方法詳解
這篇文章主要介紹了Python實(shí)現(xiàn)TCP探測目標(biāo)服務(wù)路由軌跡的原理與方法,結(jié)合實(shí)例形式分析了Python TCP探測目標(biāo)服務(wù)路由軌跡的原理、實(shí)現(xiàn)方法及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下2019-09-09
Python requests HTTP驗(yàn)證登錄實(shí)現(xiàn)流程
這篇文章主要介紹了Python requests HTTP驗(yàn)證登錄實(shí)現(xiàn)流程,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-11-11
django和vue實(shí)現(xiàn)數(shù)據(jù)交互的方法
今天小編就為大家分享一篇django和vue實(shí)現(xiàn)數(shù)據(jù)交互的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08
Jinja2實(shí)現(xiàn)模板渲染與訪問對象屬性流程詳解
要了解jinja2,那么需要先理解模板的概念。模板在Python的web開發(fā)中廣泛使用,它能夠有效的將業(yè)務(wù)邏輯和頁面邏輯分開,使代碼可讀性增強(qiáng),并且更加容易理解和維護(hù)。模板簡單來說就是一個(gè)其中包含占位變量表示動態(tài)部分的文,模板文件在經(jīng)過動態(tài)賦值后,返回給用戶2023-03-03

