In [1]:
# Import libs

import torch
import torch.nn as nn
In [2]:
# Assume each input sample is a square 28x28 image.
# The constants are listed in the order they appear in an (N, C, H, W) tensor.

BATCH_SIZE: int = 1    # N: number of images per batch
CHANNEL_SIZE: int = 1  # C: channels per image
IMAGE_SIZE: int = 28   # H and W: side length of the square image
In [3]:
# Define simple MLP with skip connection

class ResidualMLP(nn.Module):
    """Three-layer MLP whose input is re-injected after the linear layers.

    The stack is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` followed by a
    sigmoid. The original input ``x`` is fed back in one of two ways,
    selected by ``mode`` (sizes shown for ``in_features = F``):

    * ``"concatenation"``: ``x`` is concatenated (``dim=1``) to the output of
      every ``Linear`` except the last one, so the 2nd and 3rd ``Linear``
      layers take ``2 * F`` inputs::

            x ──────────────┬─────────────────────┐
                          (cat)                 (cat)
            F ──Linear──> F ──> 2F ──Linear──> F ──> 2F ──Linear──> F

    * ``"summation"``: ``x`` is added to the output of every ``Linear``,
      including the last one (i.e. right before the sigmoid)::

            x ──────────────┬──────────────┬──────────────┐
                          (sum)          (sum)          (sum)
            F ──Linear──> F ──Linear──> F ──Linear──> F

    Args:
        in_features: Size of the flattened input (and of every layer output).
        mode: ``ResidualMLP.MODE_CONCATENTATION`` or
            ``ResidualMLP.MODE_SUMMATION``.
        output_shape: Optional per-sample shape the result is reshaped to.
            When ``None`` (default), ``forward`` uses the module-level
            ``(CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE)`` constants, looked up
            at call time.

    Raises:
        AssertionError: If ``mode`` is not one of the two supported modes.
    """

    # The misspelled name is kept because existing callers reference it;
    # MODE_CONCATENATION is a correctly spelled alias for new code.
    MODE_CONCATENTATION = "concatenation"
    MODE_CONCATENATION = MODE_CONCATENTATION
    MODE_SUMMATION = "summation"

    def __init__(self, in_features: int, mode: str, output_shape=None):
        super().__init__()

        assert mode in (ResidualMLP.MODE_CONCATENTATION, ResidualMLP.MODE_SUMMATION), (
            f"unsupported mode {mode!r}; expected "
            f"{ResidualMLP.MODE_CONCATENTATION!r} or {ResidualMLP.MODE_SUMMATION!r}"
        )

        self.in_features = in_features
        self.mode = mode
        self.output_shape = tuple(output_shape) if output_shape is not None else None

        # In concatenation mode the input is appended to the hidden activations,
        # so every Linear after the first one sees twice as many features.
        if mode == ResidualMLP.MODE_CONCATENTATION:
            hidden_in = in_features * 2
        else:
            hidden_in = in_features

        # nn.ModuleList (not a plain list) so the layers are registered as
        # sub-modules: their weights then show up in parameters() and
        # state_dict() and follow .to()/.cuda() calls.
        self.layers = nn.ModuleList([
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.Linear(hidden_in, in_features),
            nn.ReLU(),
            nn.Linear(hidden_in, in_features),
        ])

        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor):
        """Run the network on a batch of flattened inputs.

        Args:
            x: Tensor of shape ``(batch, in_features)``.

        Returns:
            Sigmoid activations reshaped to ``(-1, *output_shape)``; with the
            default ``output_shape=None`` this is
            ``(-1, CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE)``.
        """
        # `x` is only ever rebound below (never modified in place), so the
        # original input can be kept as the skip branch without copying it.
        skip = x
        last_index = len(self.layers) - 1

        for index, layer in enumerate(self.layers):
            x = layer(x)

            # The skip connection is applied only right after Linear layers.
            if not isinstance(layer, nn.Linear):
                continue

            if self.mode == ResidualMLP.MODE_SUMMATION:
                # Out-of-place add: avoids mutating a tensor autograd may need.
                x = x + skip
            elif index != last_index:
                # Concatenation mode: the final Linear output is left as-is so
                # the network ends with `in_features` values per sample.
                x = torch.cat([x, skip], dim=1)

        x = self.sigmoid(x)

        if self.output_shape is not None:
            sample_shape = self.output_shape
        else:
            # Backward-compatible default: the notebook-level image constants.
            sample_shape = (CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE)

        return x.reshape(-1, *sample_shape)
In [4]:
# Input & Output shapes checking

noise = torch.randn(size=(BATCH_SIZE, CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE))
print(noise.shape)

# Number of values per flattened row; derived from IMAGE_SIZE so the model width
# and the reshape below cannot drift apart (previously a hard-coded 784).
flat_features = IMAGE_SIZE * IMAGE_SIZE

# Run the same check for both skip-connection modes. After the loop `mlp`
# refers to the summation model, as it did before.
for mode in (ResidualMLP.MODE_CONCATENTATION, ResidualMLP.MODE_SUMMATION):
    mlp = ResidualMLP(flat_features, mode)
    output = mlp(noise.reshape(-1, flat_features))
    print(output.shape)
    # The model must map a flattened batch back to the original image shape.
    assert output.shape == noise.shape, f"{mode}: got {tuple(output.shape)}"
torch.Size([1, 1, 28, 28])
torch.Size([1, 1, 28, 28])
torch.Size([1, 1, 28, 28])