Gradient Penalty
----------------
The gradient penalty is a regularization technique designed to enforce a 1-Lipschitz constraint on the discriminator (or critic) by keeping the norm of its gradient close to 1. The penalty term is:
\[
\lambda \cdot \mathbb{E}_{\hat{x} \, \sim \, \mathbb{P}_{\hat{x}}} \left[ \left( \| \nabla_{\hat{x}} \, D(\hat{x}) \|_2 - 1 \right)^2 \right]
\]
where:
- \(\hat{x}\) are samples interpolated between real data \(x\) and generated data \(g(z)\), defined as:
\[
\hat{x} = \epsilon x + (1 - \epsilon)\, g(z), \quad \text{where } \epsilon \sim U[0, 1]
\]
- \(\| \nabla_{\hat{x}} \, D(\hat{x}) \|_2\) is the L2 norm of the critic's gradient with respect to \(\hat{x}\).
- The penalty enforces the Lipschitz constraint by keeping this gradient norm close to 1.
- \(\lambda\) is a hyperparameter controlling the strength of the penalty.
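As a quick sanity check of the two formulas above, here is a minimal, self-contained sketch. The batch shape and the toy critic \(f(\hat{x}) = 3\hat{x}\) are illustrative assumptions, not the document's model; since this critic's gradient norm is 3 everywhere, the unscaled penalty (\(\lambda = 1\)) should come out as \((3 - 1)^2 = 4\):

```python
import torch

# Interpolate between placeholder "real" and "fake" batches (shape (8, 1)),
# with one epsilon per sample, exactly as in the definition of x_hat above.
real = torch.randn(8, 1)
fake = torch.randn(8, 1)
eps = torch.rand(8, 1)
x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)

# Toy critic f(x) = 3x: its gradient w.r.t. x is 3 everywhere,
# so the unscaled penalty should equal (3 - 1)^2 = 4.
critic_out = 3.0 * x_hat
grads = torch.autograd.grad(
    outputs=critic_out,
    inputs=x_hat,
    grad_outputs=torch.ones_like(critic_out),
    create_graph=True,
)[0]
penalty = ((grads.view(8, -1).norm(2, dim=1) - 1) ** 2).mean()
print(penalty.item())  # 4.0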
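```

The implementation below is written as a method of a GAN training class and relies on three of its attributes: `self.discriminator` (the critic), `self.DEVICE`, and `self.LAMBDA_1` (the \(\lambda\) above). The five-dimensional epsilon reflects volumetric (voxel) data of shape `(n, c, d, h, w)`: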
```python
def compute_gradient_penalty(self, real_data: torch.Tensor, fake_data: torch.Tensor) -> torch.Tensor:
    """Compute the gradient penalty to enforce the Lipschitz constraint.

    Args:
        real_data (torch.Tensor): A batch of real data.
        fake_data (torch.Tensor): A batch of generated data.

    Returns:
        torch.Tensor: The computed gradient penalty.
    """
    batch_size, *_ = real_data.size()
    # One epsilon per sample, broadcast over the voxel dimensions (n, c, d, h, w)
    e = torch.rand((batch_size, 1, 1, 1, 1), device=self.DEVICE)
    # Interpolate between real and fake samples, then enable gradient tracking
    interpolated = (e * real_data + (1 - e) * fake_data).to(self.DEVICE).requires_grad_(True)
    # Get the discriminator output for the interpolated data
    d_interpolated = self.discriminator(interpolated)
    # Get the gradients w.r.t. the interpolated data; create_graph=True keeps the
    # penalty itself differentiable during the critic's backward pass
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Flatten the per-sample gradients and compute their L2 norms
    # (the small constant keeps sqrt differentiable at zero)
    gradients = gradients.view(batch_size, -1)
    gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)
    # Penalize the squared deviation of each norm from 1, scaled by lambda
    gradient_penalty = self.LAMBDA_1 * ((gradients_norm - 1) ** 2).mean()
    return gradient_penalty
```
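
For context, here is a hedged sketch of how this penalty typically enters a WGAN-GP critic update; `self.generator`, `z`, and `critic_optimizer` are illustrative assumptions, not names shown in the original method:

```python
# Hypothetical critic step (self.generator, z, and critic_optimizer are
# assumed names for illustration). The WGAN-GP critic loss is
#   E[D(fake)] - E[D(real)] + gradient penalty,
# where the penalty already carries the LAMBDA_1 scaling internally.
fake_data = self.generator(z).detach()
d_real = self.discriminator(real_data)
d_fake = self.discriminator(fake_data)
gp = self.compute_gradient_penalty(real_data, fake_data)
d_loss = d_fake.mean() - d_real.mean() + gp

critic_optimizer.zero_grad()
d_loss.backward()
critic_optimizer.step()
```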