詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型
正文
GraphSAGE是一種用于圖神經(jīng)網(wǎng)絡(luò)中的節(jié)點(diǎn)嵌入學(xué)習(xí)方法。它通過聚合節(jié)點(diǎn)鄰居的信息來生成節(jié)點(diǎn)的低維表示,使節(jié)點(diǎn)表示能夠更好地應(yīng)用于各種下游任務(wù),如節(jié)點(diǎn)分類、鏈路預(yù)測等。
圖構(gòu)建
在使用GraphSAGE對節(jié)點(diǎn)進(jìn)行嵌入學(xué)習(xí)之前,我們需要先將原始數(shù)據(jù)轉(zhuǎn)換為圖結(jié)構(gòu),并將其存儲為Pytorch Tensor格式。例如,我們可以使用networkx庫來構(gòu)建一個簡單的圖:
import networkx as nx G = nx.karate_club_graph()
然后,我們可以使用Pytorch Geometric庫將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式。首先,我們需要安裝Pytorch Geometric并導(dǎo)入所需的類:
!pip install torch-geometric from torch_geometric.datasets import Planetoid from torch_geometric.transforms import NormalizeFeatures from torch_geometric.utils.convert import from_networkx
接著,我們可以使用from_networkx函數(shù)將NetworkX圖轉(zhuǎn)換為Pytorch Tensor格式:
data = from_networkx(G)
此時,data對象包含了關(guān)于節(jié)點(diǎn)、邊及其屬性的信息,例如:
data.edge_index: 2x(#edges)的長整型張量,表示邊的起點(diǎn)和終點(diǎn)
data.x: n×dn \times dn×d 的浮點(diǎn)型張量,表示每個節(jié)點(diǎn)的特征向量(其中nnn是節(jié)點(diǎn)數(shù)量,ddd是特征維度)
注意,此時的data對象并未包含鄰居信息。接下來,我們將介紹如何使用Sampler方法采樣節(jié)點(diǎn)鄰居。
Sampler方法
GraphSAGE使用Sampler方法來聚合鄰居信息。在Pytorch Geometric中,可以使用Various Sampling方法來實(shí)現(xiàn)Sampler。例如,使用ClusterData方法將圖分成多個子圖,然后對每個子圖進(jìn)行采樣操作。
以下是ClusterData的使用示例:
from torch_geometric.utils import degree, to_undirected
from torch_geometric.transforms import ClusterData
# Convert the graph to an undirected graph, so we can aggregate neighbors in both directions.
G = to_undirected(G)
# Compute the degree of each node.
deg = degree(data.edge_index[0], num_nodes=data.num_nodes)
# Use METIS algorithm to partition the graph into multiple subgraphs.
cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(),
degree=deg)
這里我們將原始圖分成兩個子圖,并對每個子圖進(jìn)行規(guī)范化特征轉(zhuǎn)換。注意,在使用ClusterData方法之前,需要將原始圖轉(zhuǎn)換為無向圖。
另一個常用的Sampler方法是在隨機(jī)游動時對鄰居進(jìn)行采樣,這種方法被稱為隨機(jī)游走采樣(Random Walk Sampling)。以下是隨機(jī)游走采樣的示例代碼:
from torch_geometric.utils import random_walk # Perform random walk sampling to obtain node neighbor samples. walk_length = 20 # The length of random walk trail. num_steps = 4 # The number of nodes to sample from each step. data.batch = None data.edge_index = to_undirected(data.edge_index) # Use undirected edge for random walk. rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)
這里我們將使用一個長度為20、每個步驟采樣4個鄰居的隨機(jī)游走方法。注意,在使用隨機(jī)游走方法進(jìn)行采樣之前,需要使用無向邊。
GraphSAGE模型定義
GraphSAGE模型包含3個部分:1)圖卷積層;2)聚合器(Aggregator);3)輸出層。我們將在本節(jié)中介紹如何使用Pytorch實(shí)現(xiàn)這些組件。
首先,讓我們定義一個圖卷積層。圖卷積層的輸入是節(jié)點(diǎn)特征矩陣、鄰接矩陣和聚合器,輸出是新的節(jié)點(diǎn)特征矩陣。以下是圖卷積層的代碼實(shí)現(xiàn):
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool
class GraphSageConv(MessagePassing):
def __init__(self, in_channels, out_channels, aggr='mean'):
super(GraphSageConv, self).__init__(aggr=aggr)
self.lin = nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_j):
return x_j
def update(self, aggr_out, x):
return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))
這里我們繼承了MessagePassing類,并在__init__函數(shù)中定義了一個全連接層,用于將輸入特征矩陣x從 dind_{in}din? 維映射到 doutd_{out}dout? 維。在forward函數(shù)中,我們使用propagate方法來實(shí)現(xiàn)消息傳遞操作;在message函數(shù)中,我們僅向下游節(jié)點(diǎn)發(fā)送原始特征數(shù)據(jù);在update函數(shù)中,我們首先對聚合結(jié)果進(jìn)行ReLU非線性變換,然后再通過全連接層進(jìn)行節(jié)點(diǎn)特征的更新。
接下來,讓我們定義一個聚合器。聚合器的輸入是采樣得到的鄰居特征矩陣,輸出是新的節(jié)點(diǎn)嵌入向量。以下是聚合器的代碼實(shí)現(xiàn):
class MeanAggregator(nn.Module):
def __init__(self, input_dim, output_dim):
super(MeanAggregator, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.lin = nn.Linear(input_dim, output_dim)
def forward(self, neigh_mean):
out = F.relu(self.lin(neigh_mean))
return out
這里我們定義了一個簡單的均值聚合器,其將鄰居特征矩陣中每列的均值作為節(jié)點(diǎn)嵌入向量,并使用全連接層進(jìn)行維度變換。
最后,讓我們定義整個GraphSage模型。GraphSage模型包含2個圖卷積層和1個輸出層。以下是模型的代碼實(shí)現(xiàn):
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
super(GraphSAGE, self).__init__()
self.conv1 = GraphSageConv(in_channels, hidden_channels)
self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels)
self.conv2 = GraphSageConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = global_mean_pool(x, edge_index) # Compute global mean over nodes.
x = self.aggreg1(x)
x = self.conv2(x, edge_index)
return x
這里我們定義了一個包含2層GraphSAGE Conv層的神經(jīng)網(wǎng)絡(luò)。在最后一層GraphSAGE Conv層之后,我們使用global_mean_pool函數(shù)來計(jì)算節(jié)點(diǎn)嵌入的全局平均值。注意,在本示例中,我們僅保留了一個輸出節(jié)點(diǎn),因此輸出矩陣的大小為1。如果需要輸出多個節(jié)點(diǎn),則需要設(shè)置global_mean_pool函數(shù)中的參數(shù)。
模型訓(xùn)練與測試
在定義好模型后,我們可以使用Pytorch進(jìn)行模型訓(xùn)練和測試。首先,讓我們定義一個損失函數(shù)和優(yōu)化器:
criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
這里我們使用交叉熵作為損失函數(shù),并使用Adam優(yōu)化器來更新模型參數(shù)。
接著,我們可以開始訓(xùn)練模型。以下是訓(xùn)練過程的代碼實(shí)現(xiàn):
num_epochs = 100
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))
這里我們遍歷所有數(shù)據(jù)樣本,計(jì)算預(yù)測結(jié)果和真實(shí)標(biāo)簽之間的交叉熵?fù)p失,并使用反向傳播來更新權(quán)重。我們在每個epoch結(jié)束后打印出當(dāng)前損失值。
最后,我們可以對模型進(jìn)行測試。以下是測試過程的代碼實(shí)現(xiàn):
model.eval()
with torch.no_grad():
pred = model(data.x, data.edge_index)
pred = pred.argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
print('Test accuracy: {:.4f}'.format(acc))
這里我們使用測試集來計(jì)算模型的準(zhǔn)確率。注意,在執(zhí)行model.eval()后,我們需要使用torch.no_grad()包裝代碼塊,以禁止梯度計(jì)算。
總結(jié)
介紹了如何使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型,包括構(gòu)建圖、定義Sampler方法、定義模型、訓(xùn)練和測試模型等步驟。GraphSAGE模型是一種常用的節(jié)點(diǎn)嵌入學(xué)習(xí)方法,可以應(yīng)用于各種下游任務(wù)中。
以上就是詳解使用Pytorch Geometric實(shí)現(xiàn)GraphSAGE模型的詳細(xì)內(nèi)容,更多關(guān)于Pytorch Geometric GraphSAGE的資料請關(guān)注腳本之家其它相關(guān)文章!
- PyTorch模型轉(zhuǎn)換為ONNX格式實(shí)現(xiàn)過程詳解
- 利用Pytorch實(shí)現(xiàn)ResNet網(wǎng)絡(luò)構(gòu)建及模型訓(xùn)練
- 詳解利用Pytorch實(shí)現(xiàn)ResNet網(wǎng)絡(luò)之評估訓(xùn)練模型
- pytorch模型的保存加載與續(xù)訓(xùn)練詳解
- AMP?Tensor?Cores節(jié)省內(nèi)存PyTorch模型詳解
- 詳解?PyTorch?Lightning模型部署到生產(chǎn)服務(wù)中
- Pytorch模型定義與深度學(xué)習(xí)自查手冊
- 一文詳解如何實(shí)現(xiàn)PyTorch模型編譯
相關(guān)文章
openstack中的rpc遠(yuǎn)程調(diào)用的方法
今天通過本文給大家分享openstack中的rpc遠(yuǎn)程調(diào)用的方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友參考下吧2021-07-07
使用Python實(shí)現(xiàn)為PDF文件添加圖章
在日常工作中,我們經(jīng)常需要給PDF文檔添加一些標(biāo)識,比如公司的圖章或水印圖章,所以本文就來為大家詳細(xì)介紹一下如何使用Python實(shí)現(xiàn)為PDF文件添加圖章,需要的可以參考下2023-11-11
Python 中 and, or, &, |, ^ 
這篇文章主要介紹了Python 中 and, or, &, |, ^ 的使用小結(jié),本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友參考下吧2024-01-01
python多進(jìn)程控制學(xué)習(xí)小結(jié)
這篇文章主要介紹了python多進(jìn)程控制學(xué)習(xí)小結(jié),想要充分利用多核CPU資源,Python中大部分情況下都需要使用多進(jìn)程,Python中提供了multiprocessing這個包實(shí)現(xiàn)多進(jìn)程。感興趣的小伙伴們可以參考一下2018-10-10
Python函數(shù)必須先定義,后調(diào)用說明(函數(shù)調(diào)用函數(shù)例外)
這篇文章主要介紹了Python函數(shù)必須先定義,后調(diào)用說明(函數(shù)調(diào)用函數(shù)例外),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06

