Pytorch+PyG實(shí)現(xiàn)GraphSAGE過程示例詳解
GraphSAGE簡(jiǎn)介
GraphSAGE(Graph Sampling and Aggregation)是一種常見的圖神經(jīng)網(wǎng)絡(luò)模型,主要用于結(jié)點(diǎn)級(jí)別的表征學(xué)習(xí)。該模型基于采樣和聚合策略,將一個(gè)結(jié)點(diǎn)及其鄰居節(jié)點(diǎn)信息融合在一起,得到其表征表示,并通過多輪迭代更新來提高表征的精度。
實(shí)現(xiàn)步驟
數(shù)據(jù)準(zhǔn)備
在本次實(shí)現(xiàn)中,我們?nèi)匀皇褂肅ora數(shù)據(jù)集作為示例進(jìn)行測(cè)試,由于GraphSage主要聚焦于單一節(jié)點(diǎn)特征的更新,因此這里不需要對(duì)數(shù)據(jù)集做特別處理,只需要將數(shù)據(jù)轉(zhuǎn)化成PyG格式即可。
import torch.nn.functional as F from torch_geometric.datasets import Planetoid from torch_geometric.utils import from_networkx, to_networkx # 加載cora數(shù)據(jù)集 dataset = Planetoid(root='./cora', name='Cora') data = dataset[0] # 將nx.Graph形式的圖轉(zhuǎn)換成PyG需要的格式 graph = to_networkx(data) data = from_networkx(graph) # 獲取節(jié)點(diǎn)數(shù)量和特征向量維度 num_nodes = data.num_nodes num_features = dataset.num_features num_classes = dataset.num_classes # 建立需要訓(xùn)練的節(jié)點(diǎn)分割數(shù)據(jù)集 data.train_mask = torch.zeros(num_nodes, dtype=torch.bool) data.val_mask = torch.zeros(num_nodes, dtype=torch.bool) data.test_mask = torch.zeros(num_nodes, dtype=torch.bool) data.train_mask[:num_nodes - 1000] = True data.test_mask[-1000:] = True data.val_mask[num_nodes - 2000: num_nodes - 1000] = True
實(shí)現(xiàn)模型
接下來,我們需要定義GraphSAGE模型。與傳統(tǒng)的GCN中只需要一層卷積操作不同,GraphSAGE包含兩層卷積和采樣(也稱“聚合”)操作。
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
def __init__(self, hidden_channels, num_layers):
super(GraphSAGE, self).__init__()
self.convs = nn.ModuleList()
for i in range(num_layers):
in_channels = hidden_channels if i != 0 else num_features
out_channels = num_classes if i == num_layers - 1 else hidden_channels
self.convs.append(SAGEConv(in_channels, out_channels))
def forward(self, x, edge_index):
for _, conv in enumerate(self.convs[:-1]):
x = F.relu(conv(x, edge_index))
# 最后一層不用激活函數(shù)
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=-1)
在上述代碼中,我們實(shí)現(xiàn)了多層GraphSAGE卷積和相應(yīng)的聚合函數(shù),并使用ReLU和softmax函數(shù)來進(jìn)行特征提取和分類分?jǐn)?shù)的輸出。
模型訓(xùn)練
定義好模型之后,就可以開始針對(duì)Cora數(shù)據(jù)集進(jìn)行模型訓(xùn)練。首先還是需要先指定優(yōu)化器和損失函數(shù),并設(shè)定一些參數(shù)用于記錄訓(xùn)練過程中的信息,如Epochs、Batch size、學(xué)習(xí)率等。
# 初始化GraphSage并指定參數(shù)
num_layers = 2
hidden_channels = 256
model = GraphSAGE(hidden_channels, num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.CrossEntropyLoss()
# 訓(xùn)練過程
for epoch in range(500):
model.train()
optimizer.zero_grad()
out = model(data.x.to(device), data.edge_index.to(device))
loss = loss_func(out[data.train_mask], data.y.to(device)[data.train_mask])
loss.backward()
optimizer.step()
# 在各個(gè)測(cè)試階段檢測(cè)一下準(zhǔn)確率
if epoch % 10 == 0:
with torch.no_grad():
_, pred = model(data.x.to(device), data.edge_index.to(device)).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y.to(device)[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print("Epoch {:03d}, Train Loss {:.4f}, Test Acc {:.4f}".format(
epoch, loss.item(), acc))
在上述代碼中,我們使用有標(biāo)記的訓(xùn)練數(shù)據(jù)擬合GraphSAGE模型,在各個(gè)驗(yàn)證階段測(cè)試準(zhǔn)確率,并通過梯度下降法優(yōu)化損失函數(shù)。
以上就是Pytorch+PyG實(shí)現(xiàn)GraphSAGE過程示例詳解的詳細(xì)內(nèi)容,更多關(guān)于Pytorch PyG實(shí)現(xiàn)GraphSAGE的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python里使用正則表達(dá)式的組嵌套實(shí)例詳解
這篇文章主要介紹了python里使用正則表達(dá)式的組嵌套實(shí)例詳解的相關(guān)資料,希望通過本文能幫助到大家,需要的朋友可以參考下2017-10-10
使用python將csv數(shù)據(jù)導(dǎo)入mysql數(shù)據(jù)庫
這篇文章主要為大家詳細(xì)介紹了如何使用python將csv數(shù)據(jù)導(dǎo)入mysql數(shù)據(jù)庫,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-05-05
python?spotlight庫簡(jiǎn)化交互式方法探索數(shù)據(jù)分析
這篇文章主要為大家介紹了python?spotlight庫簡(jiǎn)化的交互式方法探索數(shù)據(jù),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2024-01-01
基于Python制作一個(gè)端午節(jié)相關(guān)的小游戲
端午節(jié)快樂,今天我將為大家?guī)硪黄嘘P(guān)端午節(jié)的編程文章,希望能夠?yàn)榇蠹耀I(xiàn)上一份小小的驚喜,我們將會(huì)使用Python來實(shí)現(xiàn)一個(gè)與端午粽子相關(guān)的小應(yīng)用程序,在本文中,我將會(huì)介紹如何用Python代碼制做一個(gè)“粽子拆解器”,感興趣的小伙伴歡迎閱讀2023-06-06
python實(shí)現(xiàn)簡(jiǎn)易通訊錄修改版
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)簡(jiǎn)易通訊錄的修改版,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03
Python實(shí)現(xiàn)將16進(jìn)制字符串轉(zhuǎn)化為ascii字符的方法分析
這篇文章主要介紹了Python實(shí)現(xiàn)將16進(jìn)制字符串轉(zhuǎn)化為ascii字符的方法,結(jié)合實(shí)例形式分析了Python 16進(jìn)制字符串轉(zhuǎn)換為ascii字符的實(shí)現(xiàn)方法與相關(guān)注意事項(xiàng),需要的朋友可以參考下2017-07-07
Python實(shí)現(xiàn)完全數(shù)的示例詳解
完全數(shù),又稱完美數(shù),定義為:這個(gè)數(shù)的所有因數(shù)(不包括這個(gè)數(shù)本身)加起來剛好等于這個(gè)數(shù)。本文就來用Python實(shí)現(xiàn)計(jì)算完全數(shù),需要的可以參考一下2023-01-01
Python基于xlutils修改表格內(nèi)容過程解析
這篇文章主要介紹了Python基于xlutils修改表格內(nèi)容過程解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-07-07
python操作excel的方法(xlsxwriter包的使用)
這篇文章主要為大家詳細(xì)介紹了python操作excel的方法,xlsxwriter包的使用方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-06-06

