1. Learn the basics

This notebook walks you through the basics of PyTorch/Zuko distributions and transformations, how to parametrize probabilistic models, how to instantiate pre-built normalizing flows and finally how to create custom flow architectures. Training is covered in other tutorials.

import torch
import zuko

1.1. Distributions and transformations

PyTorch defines two components for probabilistic modeling: the Distribution and the Transform. A distribution object represents the probability distribution \(p(X)\) of a random variable \(X\). A distribution must implement the sample and log_prob methods, meaning that we can draw realizations \(x \sim p(X)\) from the distribution and evaluate the log-likelihood \(\log p(X = x)\) of realizations.

distribution = torch.distributions.Normal(torch.tensor(0.0), torch.tensor(1.0))

x = distribution.sample()         # x ~ p(X)
log_p = distribution.log_prob(x)  # log p(X = x)

x, log_p
(tensor(0.2122), tensor(-0.9415))

A transform object represents a bijective transformation \(f: X \mapsto Y\) from a domain to a co-domain. A transformation must implement a forward call \(y = f(x)\), an inverse call \(x = f^{-1}(y)\) and the log_abs_det_jacobian method to compute the log-absolute-determinant of the transformation’s Jacobian \(\log \left| \det \frac{\partial f(x)}{\partial x} \right|\).

transform = torch.distributions.AffineTransform(torch.tensor(2.0), torch.tensor(3.0))

y = transform(x)                             # f(x)
xx = transform.inv(y)                        # f^{-1}(f(x))
ladj = transform.log_abs_det_jacobian(x, y)  # log |det df(x)/dx|

y, xx, ladj
(tensor(2.6367), tensor(0.2122), tensor(1.0986))

Combining a base distribution \(p(Z)\) and a transformation \(f: X \mapsto Z\) defines a new distribution \(p(X)\). The likelihood is given by the change of random variables formula

\[ p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right| \]

and sampling from \(p(X)\) can be performed by first drawing realizations \(z \sim p(Z)\) and then applying the inverse transformation \(x = f^{-1}(z)\). Such a combination of a base distribution and a bijective transformation is sometimes called a normalizing flow, as the base distribution is often a standard normal.

flow = zuko.distributions.NormalizingFlow(transform, distribution)

x = flow.sample()
log_p = flow.log_prob(x)

x, log_p
(tensor(-0.6321), tensor(0.1743))
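
To make the connection with the change of random variables formula explicit, we can reproduce the log-likelihood by hand. The sketch below assumes, as described above, that the transformation maps samples \(x\) to the base space; both values should match up to numerical precision.

z = transform(x)                                 # z = f(x)
ladj = transform.log_abs_det_jacobian(x, z)      # log |det df(x)/dx|

log_p_manual = distribution.log_prob(z) + ladj   # log p(Z = f(x)) + log |det df(x)/dx|

log_p, log_p_manual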

1.2. Parametrization

When designing the distributions module, the PyTorch team decided that distributions and transformations should be lightweight objects that are used as part of computations but destroyed afterwards. Consequently, the Distribution and Transform classes are not sub-classes of torch.nn.Module, which means that we cannot retrieve their parameters with .parameters(), send their internal tensors to the GPU with .to('cuda') or train them as regular neural networks. In addition, the concepts of conditional distributions and transformations, which are essential for probabilistic inference, are impossible to express with the current interface.

To solve these problems, Zuko defines two concepts: the LazyDistribution and the LazyTransform, which are modules whose forward pass returns a distribution or transformation, respectively. These components hold the parameters of the distributions/transformations as well as the recipe to build them, such that the actual distribution/transformation objects are lazily built and destroyed when necessary. Importantly, because the creation of the distribution/transformation object is delayed, a conditioning variable, if any, can easily be taken into account. This design enables lazy distributions to act like distributions while retaining features inherent to modules, such as trainable parameters.
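
As a quick illustration (using a pre-built flow that is introduced later in this notebook), a lazy distribution is an ordinary torch.nn.Module with trainable parameters, and calling it builds an ordinary distribution object.

flow = zuko.flows.MAF(features=3, context=0)  # a LazyDistribution is a torch.nn.Module

distribution = flow()       # the forward pass builds a Distribution object
x = distribution.sample()   # which we can sample from as usual

x, sum(p.numel() for p in flow.parameters())  # a sample and the number of trainable parameters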

1.2.1. Variational inference

Let’s say we have a dataset of pairs \((x, c) \sim p(X, C)\) and want to model the distribution of \(X\) given \(c\), that is \(p(X | c)\). The goal of variational inference is to find the model \(q_{\phi^*}(X | c)\) that is most similar to \(p(X | c)\) among a family of (conditional) distributions \(q_\phi(X | c)\) distinguished by their parameters \(\phi\). Expressing the dissimilarity between two distributions as their Kullback-Leibler (KL) divergence, the variational inference objective becomes

\[\begin{split} \begin{align} \phi^* = \arg \min_\phi & ~ \mathrm{KL} \big( p(x, c) || q_\phi(x | c) \, p(c) \big) \\ = \arg \min_\phi & ~ \mathbb{E}_{p(x, c)} \left[ \log \frac{p(x, c)}{q_\phi(x | c) \, p(c)} \right] \\ = \arg \min_\phi & ~ \mathbb{E}_{p(x, c)} \big[ -\log q_\phi(x | c) \big] \end{align} \end{split}\]

For example, let \(X\) be a standard Gaussian variable and \(C\) be a vector of three Gaussian variables \(C_i\) with unit variance, each centered at \(X\).

x = torch.distributions.Normal(0, 1).sample((1024,))
c = torch.distributions.Normal(x, 1).sample((3,)).T

for i in range(3):
    print(x[i], c[i])
tensor(1.5043) tensor([0.3570, 1.6565, 1.0535])
tensor(1.0068) tensor([2.3783, 0.3633, 2.3330])
tensor(-0.5320) tensor([ 0.5544, -0.7699, -0.2988])

We choose a Gaussian model of the form \(\mathcal{N}(x | \mu_\phi(c), \sigma_\phi^2(c))\) as our distribution family, which we implement as a LazyDistribution.

class GaussianModel(zuko.flows.LazyDistribution):
    def __init__(self):
        super().__init__()

        self.hyper = torch.nn.Sequential(
            torch.nn.Linear(3, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),  # mu, log(sigma)
        )

    def forward(self, c: torch.Tensor):
        mu, log_sigma = self.hyper(c).unbind(dim=-1)

        return torch.distributions.Normal(mu, log_sigma.exp())

model = GaussianModel()
model
GaussianModel(
  (hyper): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
  )
)

Calling the forward method of the model with a context \(c\) returns a distribution object, which we can use to draw realizations or evaluate the likelihood of realizations.

distribution = model(c=c[0])
distribution
Normal(loc: -0.24979770183563232, scale: 1.1048245429992676)
distribution.sample()
tensor(0.4013)
distribution.log_prob(x[0])
tensor(-2.2790, grad_fn=<SubBackward0>)

The result of log_prob is part of a computation graph (it has a grad_fn) and therefore it can be used to train the parameters of the model by variational inference. Importantly, when the parameters of the model are modified, for example due to a gradient descent step, you must remember to call the forward method again to re-build the distribution with the new parameters.

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(64):
    loss = -model(c).log_prob(x).mean()  # E[-log q(x | c)]
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
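
After training, calling the model again re-builds the conditional distribution with the updated parameters. As a quick illustrative check, we can inspect the predicted distribution for the first context and compare it with the corresponding realization.

with torch.no_grad():
    trained = model(c[0])  # re-built with the updated parameters

trained.mean, trained.stddev, x[0]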

1.3. Normalizing flows

In the same spirit, a parametric normalizing flow in Zuko is a special LazyDistribution that contains a LazyTransform and a base LazyDistribution. To increase expressivity, the transformation is usually the composition of a sequence of “simple” transformations

\[ f(x) = f_n \circ \dots \circ f_2 \circ f_1(x) \]

for which we can compute the determinant of the Jacobian as

\[ \mathrm{det} \frac{\partial f(x)}{\partial x} = \prod_{i = 1}^{n} \mathrm{det} \frac{\partial f_i(x_{i-1})}{\partial x_{i-1}} \]

where \(x_{0} = x\) and \(x_i = f_i(x_{i-1})\). In the univariate case, finding a bijective transformation whose Jacobian determinant is tractable is easy: any differentiable monotonic function works. In the multivariate case, the most common way to make the determinant easy to compute is to enforce a triangular Jacobian. This is achieved by a transformation \(y = f(x)\) where each element \(y_i\) is a monotonic function of \(x_i\), conditioned on the preceding elements \(x_{<i}\).

\[ y_i = f(x_i | x_{<i}) \]

Autoregressive and coupling transformations [1-2] are notable examples of this class of transformations.
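
Before turning to autoregressive transformations, the composition rule above can be checked with plain PyTorch transforms: the log-absolute-determinant of a composed transformation is the sum of those of its components. Below is a minimal sketch with two arbitrary transformations.

f1 = torch.distributions.AffineTransform(torch.tensor(1.0), torch.tensor(2.0))
f2 = torch.distributions.ExpTransform()
f = torch.distributions.ComposeTransform([f1, f2])  # composition f2 o f1

x = torch.tensor(0.5)
y = f(x)

# the log-absolute-determinant accumulates over the components
f.log_abs_det_jacobian(x, y), f1.log_abs_det_jacobian(x, f1(x)) + f2.log_abs_det_jacobian(f1(x), y)

The following cell instantiates such an autoregressive transformation.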

transform = zuko.flows.MaskedAutoregressiveTransform(
    features=5,
    context=0,                                         # no context
    univariate=zuko.transforms.MonotonicRQSTransform,  # rational-quadratic spline
    shapes=([8], [8], [7]),                            # shapes of the spline parameters (8 bins)
    hidden_features=(64, 128, 256),                    # size of the hyper-network
)

transform
MaskedAutoregressiveTransform(
  (base): MonotonicRQSTransform(bins=8)
  (order): [0, 1, 2, 3, 4]
  (hyper): MaskedMLP(
    (0): MaskedLinear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): MaskedLinear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): MaskedLinear(in_features=128, out_features=256, bias=True)
    (5): ReLU()
    (6): MaskedLinear(in_features=256, out_features=115, bias=True)
  )
)
f = transform()
x = torch.randn(5)
y = f(x)
xx = f.inv(y)

print(x, xx, sep='\n')
tensor([ 0.6927,  0.6688,  0.2485, -1.8932, -1.5444])
tensor([ 0.6927,  0.6688,  0.2485, -1.8932, -1.5444], grad_fn=<WhereBackward0>)
torch.autograd.functional.jacobian(f, x).round(decimals=3)
tensor([[ 0.9950,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0020,  0.8900,  0.0000,  0.0000,  0.0000],
        [-0.0060,  0.0010,  0.9690,  0.0000,  0.0000],
        [ 0.0390, -0.0050,  0.0000,  0.9540,  0.0000],
        [-0.0190, -0.0010,  0.0030, -0.0010,  0.8670]])

We can see that the Jacobian of the autoregressive transformation is indeed triangular.
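
Since the Jacobian is triangular, its log-absolute-determinant reduces to the sum of the log-absolute values of its diagonal entries, which should match the value returned by log_abs_det_jacobian.

J = torch.autograd.functional.jacobian(f, x)

f.log_abs_det_jacobian(x, y), J.diagonal().abs().log().sum()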

1.3.1. Pre-built architecture

Zuko provides many pre-built flow architectures including NICE, MAF, NSF, CNF and many others [1-4]. We recommend trying MAF and NSF first, as they are efficient baselines. In the following cell, we instantiate a conditional flow (5 sample features and 8 context features) with 3 affine autoregressive transformations.

flow = zuko.flows.MAF(features=5, context=8, transforms=3)
flow
MAF(
  (transform): LazyComposedTransform(
    (0): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2, 3, 4]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
    (1): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [4, 3, 2, 1, 0]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2, 3, 4]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
  )
  (base): Unconditional(DiagNormal(loc: torch.Size([5]), scale: torch.Size([5])))
)
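
As with any lazy distribution, calling the flow with a context returns a conditional distribution that can be sampled from and evaluated, and whose log-probabilities are differentiable with respect to the flow’s parameters.

c = torch.randn(8)        # a context with 8 features
x = flow(c).sample((4,))  # 4 realizations of x ~ p(x | c), shape (4, 5)

flow(c).log_prob(x)       # log-likelihoods, one per realization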

1.3.2. Custom architecture

Alternatively, a flow can be built as a custom Flow object given a sequence of lazy transformations and a base lazy distribution. What follows is a condensed example of many of the things that are possible in Zuko. But remember, with great power comes great responsibility (and great bugs).

from zuko.flows import (
    Flow,
    GeneralCouplingTransform,
    MaskedAutoregressiveTransform,
    NeuralAutoregressiveTransform,
    Unconditional,
)
from zuko.distributions import BoxUniform
from zuko.transforms import (
    AffineTransform,
    MonotonicRQSTransform,
    RotationTransform,
    SigmoidTransform,
)

flow = Flow(
    transform=[
        Unconditional(              # [0, 255] to ]0, 1[
            AffineTransform,        # y = loc + scale * x
            torch.tensor(1 / 512),  # loc
            torch.tensor(1 / 256),  # scale
            buffer=True,            # not trainable
        ),
        Unconditional(lambda: SigmoidTransform().inv),  # y = logit(x)
        MaskedAutoregressiveTransform(  # autoregressive transform (affine by default)
            features=5,
            context=8,
            passes=5,  # fully-autoregressive
        ),
        Unconditional(RotationTransform, torch.randn(5, 5)),  # trainable rotation
        GeneralCouplingTransform(  # coupling transform
            features=5,
            context=8,
            univariate=MonotonicRQSTransform,  # rational-quadratic spline
            shapes=([8], [8], [7]),            # shapes of the spline parameters (8 bins)
            hidden_features=(256, 256),        # size of the hyper-network
            activation=torch.nn.ELU,           # ELU activation in hyper-network
        ).inv,  # inverse
        Unconditional(  # ignore context
            NeuralAutoregressiveTransform(
                features=5,
                order=[5, 2, 0, 3, 1],  # autoregressive order
                passes=2,               # 2-pass autoregressive (equivalent to coupling)
            )
        ),
    ],
    base=Unconditional(  # ignore context
        BoxUniform,
        torch.full([5], -3.0),  # lower bound
        torch.full([5], +3.0),  # upper bound
        buffer=True,            # not trainable
    ),
)

flow
Flow(
  (transform): LazyComposedTransform(
    (0): Unconditional(AffineTransform())
    (1): Unconditional(Inverse(SigmoidTransform()))
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2, 3, 4]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
    (3): Unconditional(RotationTransform())
    (4): LazyInverse(
      (transform): GeneralCouplingTransform(
        (base): MonotonicRQSTransform(bins=8)
        (mask): [0, 1, 0, 1, 0]
        (hyper): MLP(
          (0): Linear(in_features=10, out_features=256, bias=True)
          (1): ELU(alpha=1.0)
          (2): Linear(in_features=256, out_features=256, bias=True)
          (3): ELU(alpha=1.0)
          (4): Linear(in_features=256, out_features=69, bias=True)
        )
      )
    )
    (5): Unconditional(
      (meta): NeuralAutoregressiveTransform(
        (base): MonotonicTransform()
        (order): [1, 0, 0, 1, 0]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=5, out_features=64, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=64, out_features=64, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=64, out_features=80, bias=True)
        )
        (network): MonotonicMLP(
          (0): MonotonicLinear(in_features=17, out_features=64, bias=True, stack=5)
          (1): TwoWayELU(alpha=1.0)
          (2): MonotonicLinear(in_features=64, out_features=64, bias=True, stack=5)
          (3): TwoWayELU(alpha=1.0)
          (4): MonotonicLinear(in_features=64, out_features=1, bias=True, stack=5)
        )
      )
    )
  )
  (base): Unconditional(BoxUniform(low: torch.Size([5]), high: torch.Size([5])))
)
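
Despite its heterogeneous components, the resulting object behaves like any other flow: calling it with a context returns a distribution that can be sampled from and evaluated. A short illustrative snippet:

c = torch.randn(16, 8)  # a batch of 16 contexts with 8 features
x = flow(c).sample()    # one realization per context, shape (16, 5)

x.shape, flow(c).log_prob(x).shape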

1.4. References

  1. Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)
    https://arxiv.org/abs/1705.07057

  2. NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)
    https://arxiv.org/abs/1410.8516

  3. Neural Spline Flows (Durkan et al., 2019)
    https://arxiv.org/abs/1906.04032

  4. Neural Ordinary Differential Equations (Chen et al., 2018)
    https://arxiv.org/abs/1806.07366