In [1]:
import torch

from torch import nn
from torch_geometric import nn as tgnn

class CustomGraphNeuralNetwork(tgnn.MessagePassing):
    def __init__(self, in_channels, out_channels, aggr="add"):
        super().__init__(aggr=aggr)
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr = aggr
        
        self.message_nn = nn.Linear(in_channels * 2, out_channels)
        self.update_nn = nn.Linear(out_channels + in_channels, out_channels)
        
    def forward(self, x, edge_index):
        """
            x: original node features
            edge_index: connectivity
        """
        return self.propagate(x=x, edge_index=edge_index)
    
    def message(self, x_i, x_j):
        """
            x_i: features of target node
            x_j: features of source node
        """
        
        return self.message_nn(torch.cat([x_i, x_j], dim=1))
    
    def update(self, aggr_out, x):
        """
            aggr_out: node embeddings
            x: original node features
        """
        return self.update_nn(torch.cat([aggr_out, x], dim=1))
    
    
torch.manual_seed(777)
num_nodes = 5
num_edges = 10
in_channels = 3
out_channels = 5

edge_index = torch.randint(high=num_nodes, size=(num_edges, 2)).t()
edge_index.sort(dim=1)
edge_index = edge_index[edge_index[:, 0] != edge_index[:, 1]]

gnn = CustomGraphNeuralNetwork(in_channels, out_channels)

x = torch.randint(high=3, size=(num_nodes, in_channels)).float()
x_updated = gnn(x, edge_index)

print(x)
print(f"x.shape: {x.shape}")

print()
print(x_updated)
print(f"x_updated.shape: {x_updated.shape}")
tensor([[0., 1., 0.],
        [2., 0., 0.],
        [2., 1., 1.],
        [0., 2., 2.],
        [0., 1., 2.]])
x.shape: torch.Size([5, 3])

tensor([[ 1.0350,  0.9280, -0.8414, -1.2024,  1.0344],
        [ 0.4925,  0.0646, -0.2983,  0.7108,  0.6179],
        [ 1.1456,  1.4837, -3.0477, -1.9014,  3.8567],
        [-0.0514,  1.9048, -1.2556, -2.9974,  2.6675],
        [ 0.9137,  0.0719, -0.4560, -1.0675,  1.7408]],
       grad_fn=<AddmmBackward0>)
x_updated.shape: torch.Size([5, 5])