In [1]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Report whether a CUDA device is visible; Trainer falls back to CPU otherwise.
cuda_available = torch.cuda.is_available()
print("torch.cuda.is_available():", cuda_available)
torch.cuda.is_available(): True
In [2]:
class Classifier(torch.nn.Module):
    """Two-layer MLP classifier over flattened inputs.

    By default ``forward`` returns raw, unnormalized logits. This is what
    ``torch.nn.CrossEntropyLoss`` expects, because it applies ``log_softmax``
    internally. Putting a Softmax layer in front of it would normalize twice
    and flatten the gradients. ``argmax`` over the output is the same whether
    or not the softmax is applied.

    Args:
        in_features: size of each flattened input row (784 = 28 * 28).
        hidden_features: width of the hidden layer.
        num_classes: number of output classes.
        softmax_output: if True, append ``Softmax(dim=1)`` so ``forward``
            returns per-row probabilities. Do not combine this with
            ``CrossEntropyLoss``.
    """

    def __init__(
        self,
        in_features=784,
        hidden_features=784,
        num_classes=10,
        softmax_output=False,
    ):
        super().__init__()

        layers = [
            torch.nn.Linear(in_features, hidden_features),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(hidden_features, num_classes),
        ]
        if softmax_output:
            layers.append(torch.nn.Softmax(dim=1))

        # Softmax has no parameters, so state_dict keys are the same either way.
        self.main = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(batch, in_features)`` inputs to ``(batch, num_classes)`` scores."""
        return self.main(x)
    
    
class Trainer:
    """Minimal supervised training loop with optional gradient accumulation.

    Each optimizer step uses gradients accumulated over ``accumulation_steps``
    consecutive mini-batches of ``batch_size`` samples. The effective batch
    size is therefore ``batch_size * accumulation_steps``.

    Args:
        model: ``torch.nn.Module`` to train. Inputs are reshaped to
            ``(-1, 784)`` before the forward pass.
        dataset: dataset yielding ``(x, y)`` pairs. It must support ``len()``
            through the DataLoader.
        epochs: number of passes over ``dataset``.
        batch_size: mini-batch size. Incomplete final batches are dropped.
        optimizer: optimizer over ``model``'s parameters.
        loss_function: callable ``(predictions, targets) -> scalar loss``.
        accumulation_steps: mini-batches per optimizer step (must be >= 1).
        shuffle: reshuffle the data every epoch. The default ``False`` keeps
            the sequential order.

    Raises:
        ValueError: if ``accumulation_steps`` < 1.

    After ``fit()``, ``train_losses`` and ``train_accuracies`` hold one value
    per epoch: the mean batch loss and the fraction of samples classified
    correctly.
    """

    def __init__(
        self,
        model,
        dataset,
        epochs,
        batch_size,
        optimizer,
        loss_function,
        accumulation_steps=1,
        shuffle=False,
    ):
        if accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )

        self.model = model
        self.dataset = dataset
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.accumulation_steps = accumulation_steps

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.dataloader = DataLoader(
            self.dataset, self.batch_size, shuffle=shuffle, drop_last=True
        )

        self.train_losses = []
        self.train_accuracies = []

    def fit(self):
        """Train for ``self.epochs`` epochs, appending per-epoch loss and accuracy.

        Raises:
            ValueError: if the dataloader yields no batches. This happens when
                the dataset is smaller than ``batch_size``, because
                ``drop_last=True``.
        """
        num_batches = len(self.dataloader)
        if num_batches == 0:
            raise ValueError(
                "dataloader yields no batches: dataset is smaller than "
                f"batch_size={self.batch_size} and drop_last=True"
            )

        self.model.to(self.device)
        self.model.train()

        for _ in tqdm(range(1, self.epochs + 1)):
            epoch_loss, epoch_accuracy = self._train_one_epoch(num_batches)
            self.train_losses.append(epoch_loss)
            self.train_accuracies.append(epoch_accuracy)

    def _train_one_epoch(self, num_batches):
        """Run one pass over the dataloader and return (mean batch loss, accuracy)."""
        loss_sum = 0.0
        correct = 0
        seen = 0

        # Start every epoch from clean gradients so nothing leaks in from a
        # previous epoch or from earlier use of the model.
        self.model.zero_grad()

        for batch_index, (x, y) in enumerate(self.dataloader):
            x = x.to(self.device)
            y = y.to(self.device)

            predictions = self.model(x.reshape(-1, 784))
            loss = self.loss_function(predictions, y)

            # Find the size of the accumulation window this batch belongs to.
            # The last window of an epoch is shorter when num_batches is not a
            # multiple of accumulation_steps.
            window_start = batch_index - batch_index % self.accumulation_steps
            window_size = min(self.accumulation_steps, num_batches - window_start)

            # Average over the window. For a mean-reduced loss this gives the
            # same gradient scale as one batch of window_size * batch_size
            # samples. Divide out-of-place rather than in-place on the
            # autograd output.
            (loss / window_size).backward()

            loss_sum += loss.item()
            correct += (predictions.argmax(dim=1) == y).sum().item()
            seen += y.size(0)

            # Step at the end of every window, including a trailing partial
            # one, so its gradients are neither lost nor carried into the next
            # epoch's first step.
            if batch_index + 1 == window_start + window_size:
                self.optimizer.step()
                self.model.zero_grad()

        return loss_sum / num_batches, correct / seen
In [3]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and all CUDA devices).
seed = 777
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Favour reproducibility over autotuning. With benchmark=True and
# deterministic=False, cuDNN may choose algorithms non-deterministically,
# which undermines the seeding above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load the MNIST training split into ./data, downloading it if it is not already there.
# ToTensor converts each image to a float tensor; Trainer later flattens it to 784 values.
mnist_to_tensor = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_to_tensor)
In [4]:
# Training without Gradient Accumulation: one optimizer step per 256-sample batch.
# Re-seed right before building the model so its initial weights do not depend
# on which cells ran before this one. This keeps the cell idempotent on re-run.
torch.manual_seed(seed)
model = Classifier()
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

trainer_without_ga = Trainer(
    model=model,
    dataset=dataset,
    epochs=50,
    batch_size=32 * 8,  # 32x8=256, the same effective batch as the GA run
    accumulation_steps=1,
    optimizer=optimizer,
    loss_function=loss_function,
)

trainer_without_ga.fit()
100%|██████████| 50/50 [06:46<00:00,  8.13s/it]
In [5]:
# Training with Gradient Accumulation: 8 mini-batches of 32 per optimizer step.
# Re-seed before building the model. Otherwise its initial weights come from an
# RNG state already consumed by the previous cell's Classifier(), so the two
# runs would start from different weights and the comparison would be confounded.
torch.manual_seed(seed)
model = Classifier()
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

trainer_with_ga = Trainer(
    model=model,
    dataset=dataset,
    epochs=50,
    batch_size=32,  # 32x8=256 effective batch via accumulation_steps below
    accumulation_steps=8,
    optimizer=optimizer,
    loss_function=loss_function,
)

trainer_with_ga.fit()
100%|██████████| 50/50 [10:38<00:00, 12.76s/it]
In [6]:
# Visualize training losses and accuracies of both runs side by side
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))

runs = [("without GA", trainer_without_ga), ("with GA", trainer_with_ga)]
panels = [
    (ax_loss, "train_losses", "Training Loss\n", "Loss"),
    (ax_acc, "train_accuracies", "Training Accuracy\n", "Accuracy"),
]

for ax, history_attr, title, ylabel in panels:
    for run_label, run_trainer in runs:
        history = getattr(run_trainer, history_attr)
        # Trainer.fit counts epochs from 1, so label the x-axis the same way
        ax.plot(range(1, len(history) + 1), history, label=run_label, alpha=0.5)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.5)

fig.tight_layout()
plt.show()