In [1]:
import numpy as np
import matplotlib.pyplot as plt
In [2]:
def kl_divergence_continuous(p: np.ndarray, q: np.ndarray, dx: float) -> float:
    """Approximate KL(P || Q) for densities sampled on a uniform grid.

    The integral of ``p * log(p / q)`` is approximated by a Riemann sum:
    the integrand is evaluated element-wise, summed, and multiplied by ``dx``.

    Grid points where ``p == 0`` contribute exactly 0 (the usual
    ``0 * log 0 = 0`` convention), so an underflowed tail of ``p`` does not
    turn the result into ``nan``. Points where ``p > 0`` and ``q == 0``
    yield ``inf``.

    Parameters
    ----------
    p, q : array-like
        Density values of P and Q on the same grid (broadcastable shapes).
    dx : float
        Grid spacing used as the weight of each sample.

    Returns
    -------
    float
        The Riemann-sum estimate of the divergence.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)

    # np.where evaluates both branches, so the log/divide still runs on the
    # masked-out points; silence the resulting warnings and discard the values.
    with np.errstate(divide='ignore', invalid='ignore'):
        integrand = np.where(p > 0, p * np.log(p / q), 0.0)

    return np.sum(integrand) * dx

def probability_density_function(x: np.ndarray, sigma: float, mu: float):
    """Evaluate the Gaussian density with mean ``mu`` and std ``sigma`` at ``x``.

    Computes ``1 / (sigma * sqrt(2*pi)) * exp(-(x - mu)**2 / (2 * sigma**2))``
    element-wise and returns the result with the same shape as ``x``.
    """
    # Constant prefactor in front of the exponential.
    normalisation = 1 / (sigma * np.sqrt(2*np.pi))

    # Argument of the exponential: negative squared deviation over 2*sigma^2.
    deviation = x - mu
    exponent = -(deviation ** 2) / (2 * sigma ** 2)

    return normalisation * np.exp(exponent)
In [3]:
# One panel per (P, Q) pair of Gaussians, laid out on a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(15, 11))

# Mean/std of P and Q for each panel; P is the same in all four cases.
params = [
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': 0.5, 'q_sigma': 1.0},
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': 0.0, 'q_sigma': 1.0},
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': 1.0, 'q_sigma': 0.5},
    {'p_mu': 0.0, 'p_sigma': 1.0, 'q_mu': -1.0, 'q_sigma': 2.0},
]

# Uniform grid on [-5, 5] with spacing dx. np.linspace with an explicit point
# count is used instead of np.arange with a float step, whose length (and
# therefore whether the last sample overshoots 5) depends on rounding.
dx = 1e-2
x_min, x_max = -5.0, 5.0
n_points = int(round((x_max - x_min) / dx)) + 1
x = np.linspace(x_min, x_max, n_points)

for ax, param in zip(axes.flatten(), params):

    # Calculate PDFs
    p = probability_density_function(x, param['p_sigma'], param['p_mu'])
    q = probability_density_function(x, param['q_sigma'], param['q_mu'])

    # Calculate KL divergence
    kl_div = kl_divergence_continuous(p, q, dx)

    # Keep the line handles so the mean markers below can reuse their colours.
    (p_line,) = ax.plot(x, p, label=f'P ~ N({param["p_mu"]},{param["p_sigma"]})', linewidth=2)
    (q_line,) = ax.plot(x, q, label=f'Q ~ N({param["q_mu"]},{param["q_sigma"]})', linewidth=2)

    # Shade the gap between the curves, separately where P is above Q and
    # where Q is above P.
    ax.fill_between(x, p, q, alpha=0.3, where=(p>q))
    ax.fill_between(x, p, q, alpha=0.3, where=(p<q))

    # Dashed vertical markers at each mean, coloured like the matching curve.
    ax.axvline(x=param['p_mu'], color=p_line.get_color(), linestyle='--', alpha=0.5, linewidth=0.5)
    ax.axvline(x=param['q_mu'], color=q_line.get_color(), linestyle='--', alpha=0.9, linewidth=0.5)

    ax.set_xlabel('x')
    ax.set_ylabel('density')
    ax.set_title(f'\n KL(P||Q) = {kl_div:.4f} \n')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()