# Import libs
import torch
import torch.nn as nn
# Demo data layout: a batch of single-channel 28x28 images
# (each sample flattens to 1 * 28 * 28 = 784 values).
IMAGE_SIZE, BATCH_SIZE, CHANNEL_SIZE = 28, 1, 1
# Simple MLP with input skip connections (concatenation or summation)
class ResidualMLP(nn.Module):
    """Three-Linear-layer MLP whose layers receive a skip connection from the input.

    Layer stack: ``Linear -> ReLU -> Linear -> ReLU -> Linear``, then a sigmoid.
    Two skip styles are supported:

    * ``"concatenation"``: after every ``nn.Linear`` except the last, the
      original input is concatenated onto the features along ``dim=1``, so the
      2nd and 3rd Linear layers take ``2 * in_features`` inputs.
    * ``"summation"``: after every ``nn.Linear`` (including the last), the
      original input is added element-wise to the features.

    The sigmoid output is reshaped to ``(-1, *output_shape)``.
    """

    MODE_CONCATENTATION = "concatenation"  # original (misspelled) name, kept for existing callers
    MODE_CONCATENATION = MODE_CONCATENTATION  # correctly spelled alias
    MODE_SUMMATION = "summation"

    def __init__(self, in_features: int, mode: str, *, output_shape=None):
        """Build the layer stack for the requested skip mode.

        Args:
            in_features: Width of the flattened input; every Linear layer
                outputs this many features.
            mode: ``ResidualMLP.MODE_CONCATENTATION`` or
                ``ResidualMLP.MODE_SUMMATION``.
            output_shape: Optional per-sample shape (sequence of ints) the
                output is reshaped to; its element count should equal
                ``in_features``. When ``None`` (default), the module-level
                ``(CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE)`` is used, as before.

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        super().__init__()
        supported = (ResidualMLP.MODE_CONCATENTATION, ResidualMLP.MODE_SUMMATION)
        if mode not in supported:
            # Explicit raise rather than ``assert``: asserts vanish under ``python -O``.
            raise ValueError(f"unsupported mode {mode!r}; expected one of {supported}")
        self.in_features = in_features
        self.mode = mode
        self.output_shape = None if output_shape is None else tuple(output_shape)

        # In concatenation mode the 2nd and 3rd Linear layers see the previous
        # features with the input concatenated on, hence the doubled fan-in.
        if mode == ResidualMLP.MODE_CONCATENTATION:
            hidden_in = in_features * 2
        else:
            hidden_in = in_features
        # nn.ModuleList (not a plain list) so the layers' parameters are
        # registered with this module: visible to .parameters(), state_dict(),
        # .to(device), train()/eval(), etc.
        self.layers = nn.ModuleList([
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.Linear(hidden_in, in_features),
            nn.ReLU(),
            nn.Linear(hidden_in, in_features),
        ])
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP with skip connections and reshape the result.

        Args:
            x: Input expected as ``(batch, in_features)``. Concatenation mode
                requires exactly this 2-D layout because it concatenates on
                ``dim=1``.

        Returns:
            Sigmoid activations reshaped to ``(-1, *output_shape)``.
        """
        # Every step below rebinds ``x`` instead of mutating it, so the caller's
        # tensor is never modified and no clone is needed.
        skip = x
        if self.mode == ResidualMLP.MODE_CONCATENTATION:
            # With n = in_features:
            #   n -Linear-> n -cat-> 2n -ReLU-> -Linear-> n -cat-> 2n -ReLU-> -Linear-> n
            # (ReLU runs after the concatenation, so it also sees the raw input.)
            last = len(self.layers) - 1
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i != last and isinstance(layer, nn.Linear):
                    x = torch.cat([x, skip], dim=1)
        elif self.mode == ResidualMLP.MODE_SUMMATION:
            # With n = in_features, the input is added after each of the 3 Linear layers:
            #   n -Linear-> (+x) -ReLU-> -Linear-> (+x) -ReLU-> -Linear-> (+x)
            for layer in self.layers:
                x = layer(x)
                if isinstance(layer, nn.Linear):
                    x = x + skip  # out-of-place add keeps the autograd graph simple
        x = self.sigmoid(x)
        if self.output_shape is not None:
            shape = self.output_shape
        else:
            shape = (CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE)
        return x.reshape(-1, *shape)
# Input & output shape check: push one random "image" through both variants.
if __name__ == "__main__":
    noise = torch.randn(size=(BATCH_SIZE, CHANNEL_SIZE, IMAGE_SIZE, IMAGE_SIZE))
    print(noise.shape)
    # Flatten every sample to a vector; its width must match in_features.
    n_features = CHANNEL_SIZE * IMAGE_SIZE * IMAGE_SIZE
    flat = noise.reshape(-1, n_features)
    with torch.no_grad():  # shape check only; no autograd graph needed
        for mode in (ResidualMLP.MODE_CONCATENTATION, ResidualMLP.MODE_SUMMATION):
            mlp = ResidualMLP(n_features, mode)
            print(mlp(flat).shape)
    # Expected output:
    # torch.Size([1, 1, 28, 28])
    # torch.Size([1, 1, 28, 28])
    # torch.Size([1, 1, 28, 28])