3. Train from energy#

This notebook walks you through training a normalizing flow by gradient descent when data is unavailable, but an energy function \(U(x)\) proportional to the density \(p(x)\) is available.

import matplotlib.pyplot as plt
import torch
import zuko

3.1. Energy#

We consider a simple multi-modal energy function.

\[ \log U(x) = \sin(\pi x_1) - 2 \big( x_1^2 + x_2^2 - 2 \big)^2 \]
def log_energy(x):
    x1, x2 = x[..., 0], x[..., 1]
    return torch.sin(torch.pi * x1) - 2 * (x1 ** 2 + x2 ** 2 - 2) ** 2
x1 = torch.linspace(-3, 3, 64)
x2 = torch.linspace(-3, 3, 64)

x = torch.stack(torch.meshgrid(x1, x2, indexing='xy'), dim=-1)

energy = log_energy(x).exp()
plt.figure(figsize=(4.8, 4.8))
plt.imshow(energy)
plt.show()
[Figure: heatmap of the energy \(U(x)\) evaluated on the \([-3, 3]^2\) grid]

3.2. Flow#

We use a neural spline flow (NSF) as the density estimator \(q_\phi(x)\). However, we invert the transformations, which makes sampling more efficient: the inverse call of an autoregressive transformation is \(D\) times slower than its forward call, where \(D\) is the number of features.

flow = zuko.flows.NSF(features=2, transforms=3, hidden_features=(64, 64))
flow = zuko.flows.Flow(flow.transform.inv, flow.base)
flow
Flow(
  (transform): LazyInverse(
    (transform): LazyComposedTransform(
      (0): MaskedAutoregressiveTransform(
        (base): MonotonicRQSTransform(bins=8)
        (order): [0, 1]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=2, out_features=64, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=64, out_features=64, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=64, out_features=46, bias=True)
        )
      )
      (1): MaskedAutoregressiveTransform(
        (base): MonotonicRQSTransform(bins=8)
        (order): [1, 0]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=2, out_features=64, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=64, out_features=64, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=64, out_features=46, bias=True)
        )
      )
      (2): MaskedAutoregressiveTransform(
        (base): MonotonicRQSTransform(bins=8)
        (order): [0, 1]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=2, out_features=64, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=64, out_features=64, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=64, out_features=46, bias=True)
        )
      )
    )
  )
  (base): Unconditional(DiagNormal(loc: torch.Size([2]), scale: torch.Size([2])))
)
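Because the transform is wrapped in LazyInverse, sampling now only evaluates the autoregressive transforms in their fast (parallel) direction, while a standalone log_prob call uses the slower sequential direction. As a quick sanity check (a minimal sketch, not part of the training recipe), we can draw a few samples from the untrained flow and evaluate their log-density:

# sampling uses the fast parallel pass; log_prob alone uses the slower sequential pass
with torch.no_grad():
    x = flow().sample((4,))        # shape (4, 2)
    log_prob = flow().log_prob(x)  # log q_phi(x), shape (4,)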

The objective is to minimize the Kullback-Leibler (KL) divergence between the modeled distribution \(q_\phi(x)\) and the true data distribution \(p(x)\).

\[\begin{split} \begin{align} \arg \min_\phi & ~ \mathrm{KL} \big( q_\phi(x) || p(x) \big) \\ = \arg \min_\phi & ~ \mathbb{E}_{q_\phi(x)} \left[ \log \frac{q_\phi(x)}{p(x)} \right] \\ = \arg \min_\phi & ~ \mathbb{E}_{q_\phi(x)} \big[ \log q_\phi(x) - \log U(x) \big] \end{align} \end{split}\]
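Since \(U(x)\) is only proportional to \(p(x)\), the unknown normalizing constant \(Z = \int U(x) \, \mathrm{d}x\) shifts the objective by a constant and can therefore be ignored during optimization.

\[ \log p(x) = \log U(x) - \log Z \quad \Longrightarrow \quad \mathbb{E}_{q_\phi(x)} \big[ \log q_\phi(x) - \log p(x) \big] = \mathbb{E}_{q_\phi(x)} \big[ \log q_\phi(x) - \log U(x) \big] + \log Z \]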

Note that this “reverse KL” objective is prone to mode collapse, especially for high-dimensional data: because the expectation is taken under \(q_\phi(x)\), modes of \(p(x)\) that the flow fails to cover contribute nothing to the loss.

optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for epoch in range(8):
    losses = []

    for _ in range(256):
        x, log_prob = flow().rsample_and_log_prob((256,))  # faster than rsample + log_prob

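        # Monte Carlo estimate of the reverse KL, up to the constant log Z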
        loss = log_prob.mean() - log_energy(x).mean()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.detach())

    losses = torch.stack(losses)

    print(f'({epoch})', losses.mean().item(), '±', losses.std().item())
(0) -1.012157678604126 ± 1.0205215215682983
(1) -1.5622574090957642 ± 0.03264421969652176
(2) -1.5753192901611328 ± 0.033491406589746475
(3) -1.5814640522003174 ± 0.025743382051587105
(4) -1.5768922567367554 ± 0.04906836897134781
(5) -1.5749255418777466 ± 0.13962876796722412
(6) -1.5877153873443604 ± 0.015589614398777485
(7) -1.5886530876159668 ± 0.029878195375204086
samples = flow().sample((16384,))

plt.figure(figsize=(4.8, 4.8))
plt.hist2d(*samples.T, bins=64, range=((-3, 3), (-3, 3)))
plt.show()
[Figure: 2D histogram of 16,384 samples drawn from the trained flow]
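As a final check (a minimal sketch reusing only objects defined above), we can evaluate the learned density \(q_\phi(x)\) on the same grid as the energy heatmap and compare the two plots visually.

# evaluate the learned density on the same grid as the energy heatmap
x1 = torch.linspace(-3, 3, 64)
x2 = torch.linspace(-3, 3, 64)
grid = torch.stack(torch.meshgrid(x1, x2, indexing='xy'), dim=-1)

with torch.no_grad():
    density = flow().log_prob(grid).exp()

plt.figure(figsize=(4.8, 4.8))
plt.imshow(density)
plt.show()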