Gumbel-Softmax
-
The sampling process is not differentiable, which makes it hard for a neural network to learn through it.
For example, sampling a one-hot vector involves a probabilistic selection, but gradients cannot backpropagate through that sampling step
(argmax is not differentiable).
\[
\sigma(z_i) = \frac{\exp(z_i)}{\sum_{j=1}^{K} \exp(z_j)}
\]
where \(\sigma(z_i)\) is the standard softmax function.
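As a quick illustration (a minimal PyTorch sketch with arbitrary example values): gradients flow through softmax, but argmax cuts the computation graph.

import torch

logits = torch.randn(4, requires_grad=True)

# Softmax is differentiable: gradients reach the logits.
probs = torch.softmax(logits, dim=-1)
(probs * torch.arange(4.0)).sum().backward()
print(logits.grad)        # non-zero gradients

# Argmax is not: it returns integer indices with no grad_fn.
idx = torch.argmax(logits)
print(idx.requires_grad)  # False -- the graph is cut here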
-
Gumbel-Softmax is a continuous, differentiable approximation to sampling from a categorical distribution.
It lets a neural network draw (approximate) categorical samples while still allowing backpropagation and learning.
\[
y_i = \frac{\exp\left( (\log(\pi_i) + g_i) / \tau \right)}{\sum_{j=1}^{K} \exp\left( (\log(\pi_j) + g_j) / \tau \right)}
\]
where \(g_i = -\log(-\log(u_i)),\ u_i \sim \text{Uniform}(0, 1)\) is standard Gumbel noise.
(The underlying Gumbel-Max trick says that \(\arg\max_i(\log \pi_i + g_i)\) is an exact sample from the categorical distribution \(\pi\); the temperature-\(\tau\) softmax above is its continuous relaxation.)
As \(\tau \rightarrow 0\), \(y\) approaches a one-hot vector.
As \(\tau \rightarrow \infty\), \(y\) approaches the uniform distribution.
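For intuition, PyTorch's built-in torch.nn.functional.gumbel_softmax shows both limits (the temperature values below are just illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(5)
print(F.gumbel_softmax(logits, tau=0.1))   # close to one-hot
print(F.gumbel_softmax(logits, tau=10.0))  # close to uniform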
-
Straight-Through Trick
-
Gumbel-Softmax sampling produces continuous probability vectors,
but in practice you may want the final output to be a discrete one-hot vector.
In this case, the Straight-Through Trick yields a one-hot vector in the forward pass
while still allowing backpropagation.
-
Forward pass: apply argmax to the Gumbel-Softmax probability vector to get a one-hot vector.
-
Backward pass: pass the gradient through the original (soft) probability vector.
import torch

def gumbel_softmax(logits, tau=1.0, hard=False):
    # Sample standard Gumbel noise: g = -log(-log(u)), u ~ Uniform(0, 1)
    g = -torch.log(-torch.log(torch.rand_like(logits)))
    # Temperature-scaled softmax over the noisy logits
    y = torch.softmax((logits + g) / tau, dim=-1)
    if hard:
        # One-hot vector at the argmax of the soft sample
        y_hard = torch.zeros_like(y)
        y_hard.scatter_(-1, y.argmax(dim=-1, keepdim=True), 1)
        # Straight-Through Trick: forward pass returns y_hard,
        # backward pass flows gradients through the soft y
        y = y_hard - y.detach() + y
    return y
logits = torch.randn(7)
y_soft = gumbel_softmax(logits, hard=False)
print(y_soft)
# tensor([0.1083, 0.0069, 0.5550, 0.2059, 0.0542, 0.0047, 0.0651])
y_hard = gumbel_softmax(logits, hard=True)
print(y_hard)
# tensor([0., 0., 1., 0., 0., 0., 0.])
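As a quick sanity check (using the gumbel_softmax defined above, with an arbitrary downstream loss): gradients still reach the logits even when the output is hard.

logits = torch.randn(7, requires_grad=True)
y = gumbel_softmax(logits, hard=True)
loss = (y * torch.arange(7.0)).sum()  # any downstream use of the one-hot sample
loss.backward()
print(logits.grad)  # non-zero: gradients flowed through the soft probabilities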