import torch
import matplotlib.pyplot as plt
# Diffusion noise schedule, precomputed once for every timestep index 0..time_steps-1.
time_steps = 1000
beta_start = 0.0001
beta_end = 0.02
# Linear schedule: beta_t increases linearly from beta_start to beta_end.
betas = torch.linspace(beta_start, beta_end, time_steps)
alphas = 1 - betas
# alpha_bar_t = alpha_0 * alpha_1 * ... * alpha_t (running product over timesteps).
# `dim` is the documented torch keyword; `axis` is only a NumPy-compatibility alias.
alphas_bar = torch.cumprod(alphas, dim=0)
# Coefficients of the closed-form forward process used by q_sample:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
alphas_bar_sqrt = torch.sqrt(alphas_bar)
one_minus_alphas_bar = 1 - alphas_bar
one_minus_alphas_bar_sqrt = torch.sqrt(one_minus_alphas_bar)
def q_sample(x0, t, noise=None):
    """Draw x_t ~ q(x_t | x_0) using the closed-form forward (noising) process.

        x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise

    Reads the module-level ``alphas_bar_sqrt`` and ``one_minus_alphas_bar_sqrt``
    tables, indexed with the 0-based timestep ``t``.

    Args:
        x0: Clean data tensor. When ``t`` is a 1-D tensor, the first dimension
            of ``x0`` is treated as the batch dimension.
        t: Timestep index. Either a Python int / 0-dim tensor (one timestep for
            the whole input) or a 1-D integer tensor with one entry per row of
            ``x0`` (per-sample timesteps).
        noise: Optional noise tensor with the same shape as ``x0``. When omitted
            it is drawn with ``torch.randn_like(x0)``; passing it lets the caller
            know exactly which noise was added.

    Returns:
        Tensor with the same shape as ``x0``.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    coef_x0 = alphas_bar_sqrt[t]
    coef_noise = one_minus_alphas_bar_sqrt[t]
    if coef_x0.dim() > 0:
        # Per-sample timesteps: reshape (B,) -> (B, 1, ..., 1) so each row of
        # x0 gets its own coefficient instead of broadcasting along the last axis.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        coef_x0 = coef_x0.reshape(shape)
        coef_noise = coef_noise.reshape(shape)
    return coef_x0 * x0 + coef_noise * noise
# Reference: https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
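# Closed form of the forward (noising) process, where \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s:
#
# \begin{aligned}
# \mathbf{x}_t &= \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\boldsymbol{\epsilon}_{t-1}
#     & \text{ ;where } \boldsymbol{\epsilon}_{t-1}, \boldsymbol{\epsilon}_{t-2}, \dots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
# &= \sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \bar{\boldsymbol{\epsilon}}_{t-2}
#     & \text{ ;where } \bar{\boldsymbol{\epsilon}}_{t-2} \text{ merges two Gaussians (*).} \\
# &= \dots \\
# &= \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon} \\
# q(\mathbf{x}_t \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I})
# \end{aligned}
#
# (*) The sum of independent zero-mean Gaussians \mathcal{N}(\mathbf{0}, \sigma_1^2\mathbf{I}) and
# \mathcal{N}(\mathbf{0}, \sigma_2^2\mathbf{I}) is \mathcal{N}(\mathbf{0}, (\sigma_1^2 + \sigma_2^2)\mathbf{I}).
# Here \sigma_1^2 = \alpha_t(1 - \alpha_{t-1}) and \sigma_2^2 = 1 - \alpha_t, so the merged
# variance is \alpha_t(1 - \alpha_{t-1}) + (1 - \alpha_t) = 1 - \alpha_t \alpha_{t-1}.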
def noisy_arc(theta_start, theta_end, offset, num_points=300, noise_std=0.07):
    """Return ``num_points`` jittered (x, y) samples on a shifted unit-circle arc.

    Points are evenly spaced in angle from ``theta_start`` to ``theta_end``,
    translated by ``offset`` (a ``(dx, dy)`` pair), then perturbed with
    isotropic Gaussian noise of standard deviation ``noise_std``. The noise is a
    single ``torch.randn_like`` draw from the global torch RNG.
    """
    theta = torch.linspace(theta_start, theta_end, num_points)
    xy = torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)
    xy = xy + torch.tensor(offset)
    return xy + torch.randn_like(xy) * noise_std


torch.manual_seed(777)
# Two interleaving half-moons: the upper arc shifted left, the lower arc
# shifted right and up. Call order matters: each call draws from the seeded RNG.
theta_pos_xy = noisy_arc(0, torch.pi, (-0.3, 0.0))
theta_neg_xy = noisy_arc(-torch.pi, 0, (0.3, 0.3))
fig, axes = plt.subplots(1, 3, figsize=(17, 6))
x0 = torch.cat([theta_pos_xy, theta_neg_xy], dim=0)
# Noise the data at increasing timesteps; keep this call order so the seeded
# RNG stream, and therefore the figure, stays reproducible.
x_49 = q_sample(x0, t=49)
x_999 = q_sample(x0, t=999)

panels = [
    ("x0 (clean data)", x0),
    ("q_sample(x0, t=49)", x_49),
    ("q_sample(x0, t=999)", x_999),
]
for ax, (title, points) in zip(axes, panels):
    ax.scatter(points[:, 0], points[:, 1], color="b", s=3)
    ax.set_title(title)
    ax.set_aspect("equal")
    ax.set_xlim(-2.0, 2.0)
    ax.set_ylim(-2.0, 2.0)
    ax.axis("off")
plt.tight_layout()
plt.show()