import torch
from torch import nn
from torch_geometric import nn as tgnn
class CustomGraphNeuralNetwork(tgnn.MessagePassing):
    """Message-passing layer built from two linear maps.

    A message is computed for every edge from the concatenated features of
    its two endpoint nodes, and each node is then updated from its aggregated
    incoming messages concatenated with its own original features.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels: int, out_channels: int, aggr: str = "add"):
        # The base class stores and normalizes the aggregation scheme itself,
        # so ``self.aggr`` is deliberately not re-assigned here.
        super().__init__(aggr=aggr)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Maps a concatenated (x_i, x_j) pair to a message.
        self.message_nn = nn.Linear(in_channels * 2, out_channels)
        # Maps (aggregated messages, original features) to the new features.
        self.update_nn = nn.Linear(out_channels + in_channels, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run one round of message passing.

        Args:
            x: Original node features; the last dimension must be
                ``in_channels``.
            edge_index: Graph connectivity, handed to ``propagate`` as-is.

        Returns:
            Whatever ``propagate`` returns, i.e. the result of ``update``.
        """
        return self.propagate(edge_index, x=x)

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
        """Compute one message per edge.

        Args:
            x_i: Features of the target node of each edge.
            x_j: Features of the source node of each edge.

        Returns:
            ``message_nn`` applied to the feature-wise concatenation of
            ``x_i`` and ``x_j``.
        """
        # Concatenate along the feature (last) axis; for 2-D inputs this is
        # the same as dim=1 but it also tolerates leading batch dimensions.
        return self.message_nn(torch.cat([x_i, x_j], dim=-1))

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Combine aggregated messages with the original node features.

        Args:
            aggr_out: Aggregated incoming messages of each node.
            x: Original node features.

        Returns:
            ``update_nn`` applied to the feature-wise concatenation of
            ``aggr_out`` and ``x``.
        """
        return self.update_nn(torch.cat([aggr_out, x], dim=-1))
# Demo: build a small random graph and run it through the layer once.
torch.manual_seed(777)

num_nodes = 5
num_edges = 10
in_channels = 3
out_channels = 5

# Sample `num_edges` random node pairs and transpose to a [2, num_edges]
# tensor, so that column k holds the two endpoints of edge k.
edge_index = torch.randint(high=num_nodes, size=(num_edges, 2)).t()

# Drop self-loops. The mask has one entry per edge (column) and is applied
# along dim 1; masking rows instead would compare the first edge against the
# second and could destroy the [2, E] layout.
# NOTE: a previous `edge_index.sort(dim=1)` call was removed: `Tensor.sort`
# is out-of-place, so its discarded result had no effect, and sorting each
# row independently would have scrambled the endpoint pairing anyway.
# NOTE(review): with self-loops actually removed, the printed `x_updated`
# values may differ from outputs recorded with the earlier filter.
not_self_loop = edge_index[0] != edge_index[1]
edge_index = edge_index[:, not_self_loop]

gnn = CustomGraphNeuralNetwork(in_channels, out_channels)

# Small integer-valued features in {0, 1, 2}, cast to float for nn.Linear.
x = torch.randint(high=3, size=(num_nodes, in_channels)).float()
x_updated = gnn(x, edge_index)

print(x)
print(f"x.shape: {x.shape}")
print()
print(x_updated)
print(f"x_updated.shape: {x_updated.shape}")
tensor([[0., 1., 0.],
[2., 0., 0.],
[2., 1., 1.],
[0., 2., 2.],
[0., 1., 2.]])
x.shape: torch.Size([5, 3])
tensor([[ 1.0350, 0.9280, -0.8414, -1.2024, 1.0344],
[ 0.4925, 0.0646, -0.2983, 0.7108, 0.6179],
[ 1.1456, 1.4837, -3.0477, -1.9014, 3.8567],
[-0.0514, 1.9048, -1.2556, -2.9974, 2.6675],
[ 0.9137, 0.0719, -0.4560, -1.0675, 1.7408]],
grad_fn=<AddmmBackward0>)
x_updated.shape: torch.Size([5, 5])