import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
print("torch.cuda.is_available():", torch.cuda.is_available())
torch.cuda.is_available(): True
class Classifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.main = torch.nn.Sequential(
            torch.nn.Linear(784, 784),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(784, 10),
            torch.nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.main(x)
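# A quick sanity check (illustrative only, not part of the training flow below):
# the model maps a flattened 28x28 image (784 values) to ten class probabilities
# that sum to 1 because of the final Softmax layer.
with torch.no_grad():
    _probs = Classifier()(torch.zeros(1, 784))
print(_probs.shape, round(_probs.sum().item(), 4))  # torch.Size([1, 10]) 1.0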
class Trainer:
    def __init__(
        self,
        model,
        dataset,
        epochs,
        batch_size,
        optimizer,
        loss_function,
        accumulation_steps=1
    ):
        self.model = model
        self.dataset = dataset
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.accumulation_steps = accumulation_steps
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dataloader = DataLoader(
            self.dataset, self.batch_size, drop_last=True
        )
        self.train_losses = []
        self.train_accuracies = []

    def fit(self):
        self.model.to(self.device)
        for _ in tqdm(range(1, self.epochs + 1)):
            train_loss = 0
            train_correct = 0
            for batch_index, (x, y) in enumerate(self.dataloader):
                x = x.to(self.device)
                y = y.to(self.device)
                predictions = self.model(x.reshape(-1, 784))
                loss = self.loss_function(predictions, y)
                # Scale the loss so the gradients accumulated over
                # accumulation_steps mini-batches match those of one large batch
                loss /= self.accumulation_steps
                loss.backward()
                train_loss += loss.item() * self.accumulation_steps  # undo the scaling for logging
                train_correct += (predictions.argmax(dim=1) == y).sum().item()
                # Step the optimizer and reset gradients once every accumulation_steps batches
                if (batch_index + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.model.zero_grad()
            self.train_losses.append(train_loss / len(self.dataloader))
            self.train_accuracies.append(train_correct / (self.batch_size * len(self.dataloader)))
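# The two comments inside fit() carry the whole idea: each mini-batch loss is
# divided by accumulation_steps before backward(), and the optimizer only steps
# once every accumulation_steps batches. A minimal standalone sketch on toy
# tensors (separate from the MNIST model above) showing why this reproduces the
# gradient of one large batch:
_layer = torch.nn.Linear(4, 2)
_x = torch.randn(8, 4)                     # one "large" batch of 8 samples
_y = torch.randint(0, 2, (8,))
_loss_fn = torch.nn.CrossEntropyLoss()

# gradient from the full batch
_layer.zero_grad()
_loss_fn(_layer(_x), _y).backward()
_full_grad = _layer.weight.grad.clone()

# same gradient accumulated over 4 mini-batches of 2 samples, each loss divided by 4
_layer.zero_grad()
for _xb, _yb in zip(_x.split(2), _y.split(2)):
    (_loss_fn(_layer(_xb), _yb) / 4).backward()
print(torch.allclose(_full_grad, _layer.weight.grad, atol=1e-6))  # expected: True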
# Set seed
seed = 777
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
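# Note: with cudnn.deterministic = False and cudnn.benchmark = True the run
# trades bit-exact reproducibility for cuDNN autotuning speed, even though the
# seeds above are fixed. If exact repeatability between runs mattered more, the
# usual alternative would be to flip both flags:
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False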
# Load dataset
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
# Training without Gradient Accumulation
model = Classifier()
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()
trainer_without_ga = Trainer(
    model=model,
    dataset=dataset,
    epochs=50,
    batch_size=32 * 8,  # 32 x 8 = 256 samples per optimizer step, no accumulation
    accumulation_steps=1,
    optimizer=optimizer,
    loss_function=loss_function,
)
trainer_without_ga.fit()
100%|██████████| 50/50 [06:46<00:00, 8.13s/it]
# Training with Gradient Accumulation
model = Classifier()
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()
trainer_with_ga = Trainer(
    model=model,
    dataset=dataset,
    epochs=50,
    batch_size=32,  # effective batch of 32 x 8 = 256 per optimizer step via accumulation
    accumulation_steps=8,
    optimizer=optimizer,
    loss_function=loss_function,
)
trainer_with_ga.fit()
100%|██████████| 50/50 [10:38<00:00, 12.76s/it]
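# Both runs perform the same number of optimizer steps per epoch with an
# effective batch of 256 (32 x 8); the accumulated run only ever keeps 32
# samples' activations on the GPU at a time, which is the usual reason to use
# gradient accumulation. A rough sketch (CUDA only; the helper below is
# illustrative, not part of the experiment above) for comparing peak GPU memory
# of the two configurations:
def _peak_memory_mb(batch_size, accumulation_steps):
    torch.cuda.reset_peak_memory_stats()
    _m = Classifier()
    Trainer(
        model=_m,
        dataset=dataset,
        epochs=1,
        batch_size=batch_size,
        accumulation_steps=accumulation_steps,
        optimizer=torch.optim.Adam(_m.parameters()),
        loss_function=torch.nn.CrossEntropyLoss(),
    ).fit()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# _peak_memory_mb(batch_size=256, accumulation_steps=1)
# _peak_memory_mb(batch_size=32, accumulation_steps=8)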
# Visualize training losses and accuracies
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(trainer_without_ga.train_losses, label="without GA", alpha=0.5)
plt.plot(trainer_with_ga.train_losses, label="with GA", alpha=0.5)
plt.title("Training Loss\n")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True, alpha=0.5)
plt.subplot(1, 2, 2)
plt.plot(trainer_without_ga.train_accuracies, label="without GA", alpha=0.5)
plt.plot(trainer_with_ga.train_accuracies, label="with GA", alpha=0.5)
plt.title("Training Accuracy\n")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.grid(True, alpha=0.5)
plt.tight_layout()
plt.show()