# Compare a reference Gaussian P against several Gaussians Q, one subplot per
# parameter set, and annotate each panel with KL(P||Q).
# Relies on `np`, `plt`, `probability_density_function` and
# `kl_divergence_continuous` being defined earlier in the file.
fig, axes = plt.subplots(2, 2, figsize=(15, 11))

# One parameter dict per subplot (2x2 grid -> 4 entries).
params = [
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': 0.5, 'q_sigma': 1.0},   # Q: shifted mean
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': 0.0, 'q_sigma': 1.0},   # Q identical to P
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': 1.0, 'q_sigma': 0.5},   # Q: shifted, narrower
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': -1.0, 'q_sigma': 2.0},  # Q: shifted, wider
]

# Evaluation grid on [x_min, x_max] with spacing dx. linspace is used instead
# of a float-step arange so that both endpoints are included and the number
# of points is exact (here 1001 points, spacing 0.01).
dx = 1e-2
x_min, x_max = -5.0, 5.0
x = np.linspace(x_min, x_max, int(round((x_max - x_min) / dx)) + 1)
# NOTE(review): in the last panel, if q_sigma is a standard deviation, Q keeps
# roughly 2% of its mass below x_min, so its curve is visibly cut off at the
# left edge; widen the grid if that matters for the plot.

for ax, param in zip(axes.flatten(), params):
    # Densities on the grid. Assumes the helper's positional order is
    # (x, sigma, mu) -- confirm against probability_density_function.
    p = probability_density_function(x, param['p_sigma'], param['p_mu'])
    q = probability_density_function(x, param['q_sigma'], param['q_mu'])

    # dx is passed so the helper can integrate over the grid
    # (presumably a Riemann sum of p * log(p / q) -- confirm).
    kl_div = kl_divergence_continuous(p, q, dx)

    # Name the parameters explicitly: "N(a, b)" conventionally means
    # variance b, which would misread the sigma values shown here.
    (line_p,) = ax.plot(
        x, p, label=f'P ~ N(μ={param["p_mu"]}, σ={param["p_sigma"]})', linewidth=2
    )
    (line_q,) = ax.plot(
        x, q, label=f'Q ~ N(μ={param["q_mu"]}, σ={param["q_sigma"]})', linewidth=2
    )
    # Reuse the curves' own colours so fills and mean markers always match.
    p_color = line_p.get_color()
    q_color = line_q.get_color()

    # Shade the region where each density dominates, in that density's colour.
    # interpolate=True closes the gaps that otherwise appear where the curves
    # cross between grid samples.
    ax.fill_between(x, p, q, where=(p > q), interpolate=True, color=p_color, alpha=0.3)
    ax.fill_between(x, p, q, where=(p < q), interpolate=True, color=q_color, alpha=0.3)

    # Dashed vertical line at each mean, coloured like its curve.
    ax.axvline(x=param['p_mu'], color=p_color, linestyle='--', alpha=0.7, linewidth=0.5)
    ax.axvline(x=param['q_mu'], color=q_color, linestyle='--', alpha=0.7, linewidth=0.5)

    ax.set_xlabel('x')
    ax.set_ylabel('density')
    ax.set_title(f'\n KL(P||Q) = {kl_div:.4f} \n')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()