1. Learn the basics
This notebook walks you through the basics of PyTorch/Zuko distributions and transformations, how to parametrize probabilistic models, how to instantiate pre-built normalizing flows and finally how to create custom flow architectures. Training is covered in other tutorials.
import torch
import zuko
1.1. Distributions and transformations
PyTorch defines two components for probabilistic modeling: the Distribution and the Transform. A distribution object represents the probability distribution \(p(X)\) of a random variable \(X\). A distribution must implement the sample and log_prob methods, meaning that we can draw realizations \(x \sim p(X)\) from the distribution and evaluate the log-likelihood \(\log p(X = x)\) of realizations.
distribution = torch.distributions.Normal(torch.tensor(0.0), torch.tensor(1.0))
x = distribution.sample() # x ~ p(X)
log_p = distribution.log_prob(x) # log p(X = x)
x, log_p
(tensor(0.2122), tensor(-0.9415))
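As a sanity check, the value returned by log_prob can be compared against the closed-form log-density of the standard normal distribution. The snippet below is a minimal sketch; the variable name log_p_manual is introduced here for illustration only.

```python
import math

import torch

# Standard normal distribution, as in the cell above.
distribution = torch.distributions.Normal(torch.tensor(0.0), torch.tensor(1.0))

x = distribution.sample((4,))  # four independent realizations
log_p = distribution.log_prob(x)  # log p(X = x), evaluated element-wise

# Closed-form log-density of N(0, 1): -x^2 / 2 - log(2 * pi) / 2
log_p_manual = -0.5 * x**2 - 0.5 * math.log(2 * math.pi)

print(torch.allclose(log_p, log_p_manual, atol=1e-6))
```

Passing a shape to sample draws a batch of independent realizations, and log_prob is then evaluated for each of them.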
A transform object represents a bijective transformation \(f: X \mapsto Y\) from a domain to a co-domain. A transformation must implement a forward call \(y = f(x)\), an inverse call \(x = f^{-1}(y)\) and the log_abs_det_jacobian method to compute the log-absolute-determinant of the transformation's Jacobian \(\log \left| \det \frac{\partial f(x)}{\partial x} \right|\).
transform = torch.distributions.AffineTransform(torch.tensor(2.0), torch.tensor(3.0))
y = transform(x) # f(x)
xx = transform.inv(y) # f^{-1}(f(x))
ladj = transform.log_abs_det_jacobian(x, y) # log |det df(x)/dx|
y, xx, ladj
(tensor(2.6367), tensor(0.2122), tensor(1.0986))
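For an affine transformation \(y = \text{loc} + \text{scale} \times x\), the derivative is the scale, so the log-absolute-determinant should equal \(\log |\text{scale}|\), which is \(\log 3 \approx 1.0986\) above. The following sketch checks this, and also checks that the inverse transformation carries the opposite sign. The names ladj_manual and ladj_inv are illustrative.

```python
import math

import torch

loc, scale = torch.tensor(2.0), torch.tensor(3.0)
transform = torch.distributions.AffineTransform(loc, scale)

x = torch.tensor(0.5)
y = transform(x)  # 2 + 3 * 0.5

# dy/dx = scale, hence log |dy/dx| = log |scale|
ladj = transform.log_abs_det_jacobian(x, y)
ladj_manual = math.log(abs(scale.item()))

# The inverse transformation maps y back to x with the opposite log-determinant.
ladj_inv = transform.inv.log_abs_det_jacobian(y, x)

print(ladj.item(), ladj_manual, ladj_inv.item())
```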
Combining a base distribution \(p(Z)\) and a transformation \(f: X \mapsto Z\) defines a new distribution \(p(X)\). The likelihood is given by the change of random variables formula
\[ p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right| \]
and sampling from \(p(X)\) can be performed by first drawing realizations \(z \sim p(Z)\) and then applying the inverse transformation \(x = f^{-1}(z)\). Such a combination of a base distribution and a bijective transformation is sometimes called a normalizing flow, as the base distribution is often standard normal.
flow = zuko.distributions.NormalizingFlow(transform, distribution)
x = flow.sample()
log_p = flow.log_prob(x)
x, log_p
(tensor(-0.6321), tensor(0.1743))
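To make the change of variables concrete, the sketch below re-implements the likelihood and the sampling procedure by hand with plain PyTorch objects, for the same affine transformation and standard normal base. It is only an illustration of the formula, not how Zuko implements NormalizingFlow; the helper names log_prob, sample and reference are assumptions made for this example. Since \(x = (z - 2) / 3\) with \(z \sim \mathcal{N}(0, 1)\), the resulting distribution should be \(\mathcal{N}(-2/3, (1/3)^2)\).

```python
import torch

base = torch.distributions.Normal(torch.tensor(0.0), torch.tensor(1.0))  # p(Z)
f = torch.distributions.AffineTransform(torch.tensor(2.0), torch.tensor(3.0))  # z = 2 + 3 x


def log_prob(x: torch.Tensor) -> torch.Tensor:
    # log p(X = x) = log p(Z = f(x)) + log |det df(x)/dx|
    z = f(x)
    return base.log_prob(z) + f.log_abs_det_jacobian(x, z)


def sample(shape: tuple = ()) -> torch.Tensor:
    # Draw z ~ p(Z), then apply the inverse transformation x = f^{-1}(z).
    z = base.sample(shape)
    return f.inv(z)


x = sample((1000,))

# Closed-form reference: x = (z - 2) / 3 is Gaussian with mean -2/3 and std 1/3.
reference = torch.distributions.Normal(torch.tensor(-2.0 / 3.0), torch.tensor(1.0 / 3.0))

print(torch.allclose(log_prob(x), reference.log_prob(x), atol=1e-4))
```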
1.2. Parametrization
When designing the distributions module, the PyTorch team decided that distributions and transformations should be lightweight objects that are used as part of computations but destroyed afterwards. Consequently, the Distribution and Transform classes are not sub-classes of torch.nn.Module, which means that we cannot retrieve their parameters with .parameters(), send their internal tensors to GPU with .to('cuda') or train them as regular neural networks. In addition, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express with the current interface.
To solve these problems, zuko defines two concepts: the LazyDistribution and the LazyTransform, which are modules whose forward pass returns a distribution or transformation, respectively. These components hold the parameters of the distributions/transformations as well as the recipe to build them, such that the actual distribution/transformation objects are lazily built and destroyed when necessary. Importantly, because the creation of the distribution/transformation object is delayed, a condition, if there is one, can easily be taken into account. This design enables lazy distributions to act like distributions while retaining features inherent to modules, such as trainable parameters.
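The idea can be illustrated without Zuko. The sketch below is a hypothetical, stripped-down module (the class name LazyNormal is invented for this example and is not part of Zuko) that stores trainable parameters and only builds the actual Normal object when called, optionally taking a condition into account.

```python
import torch


class LazyNormal(torch.nn.Module):  # hypothetical name, not a Zuko class
    """Holds trainable parameters and builds a Normal distribution on demand."""

    def __init__(self):
        super().__init__()
        self.loc = torch.nn.Parameter(torch.zeros(()))
        self.log_scale = torch.nn.Parameter(torch.zeros(()))

    def forward(self, c=None):
        # In this toy example, the optional condition simply shifts the mean.
        loc = self.loc if c is None else self.loc + c
        return torch.distributions.Normal(loc, self.log_scale.exp())


lazy = LazyNormal()
params = list(lazy.parameters())  # the module exposes its trainable tensors

dist = lazy(torch.tensor(1.5))  # the distribution is built here, given c = 1.5
log_p = dist.log_prob(torch.tensor(1.5))  # differentiable w.r.t. the parameters

print(len(params), dist, log_p)
```

Because the module, and not the distribution, owns the parameters, the usual .parameters() and .to(...) mechanics apply, while the returned object remains an ordinary PyTorch distribution.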
1.2.1. Variational inference
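Let’s say we have a dataset of pairs \((x, c) \sim p(X, C)\) and want to model the distribution of \(X\) given \(c\), that is \(p(X | c)\). The goal of variational inference is to find the model \(q_{\phi^\star}(X | c)\) that is most similar to \(p(X | c)\) among a family of (conditional) distributions \(q_\phi(X | c)\) distinguished by their parameters \(\phi\). Expressing the dissimilarity between two distributions as their Kullback-Leibler (KL) divergence, the variational inference objective becomes
\[
\begin{aligned}
\phi^\star & = \arg \min_\phi \mathrm{KL} \big( p(x, c) \, \| \, q_\phi(x | c) \, p(c) \big) \\
& = \arg \min_\phi \mathbb{E}_{p(x, c)} \left[ \log \frac{p(x, c)}{q_\phi(x | c) \, p(c)} \right] \\
& = \arg \min_\phi \mathbb{E}_{p(x, c)} \big[ -\log q_\phi(x | c) \big]
\end{aligned}
\]
where the last step drops the term \(\log p(x | c)\), which does not depend on \(\phi\).
For example, let \(X\) be a standard Gaussian variable and \(C\) be a vector of three unit-variance Gaussian variables \(C_i\) centered at \(X\).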
x = torch.distributions.Normal(0, 1).sample((1024,))
c = torch.distributions.Normal(x, 1).sample((3,)).T
for i in range(3):
    print(x[i], c[i])
tensor(1.5043) tensor([0.3570, 1.6565, 1.0535])
tensor(1.0068) tensor([2.3783, 0.3633, 2.3330])
tensor(-0.5320) tensor([ 0.5544, -0.7699, -0.2988])
We choose a Gaussian model of the form \(\mathcal{N}(x | \mu_\phi(c), \sigma_\phi^2(c))\) as our distribution family, which we implement as a LazyDistribution.
class GaussianModel(zuko.flows.LazyDistribution):
    def __init__(self):
        super().__init__()
        self.hyper = torch.nn.Sequential(
            torch.nn.Linear(3, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),  # mu, log(sigma)
        )

    def forward(self, c: torch.Tensor):
        mu, log_sigma = self.hyper(c).unbind(dim=-1)
        return torch.distributions.Normal(mu, log_sigma.exp())
model = GaussianModel()
model
GaussianModel(
  (hyper): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
  )
)
Calling the forward method of the model with a context \(c\) returns a distribution object, which we can use to draw realizations or evaluate the likelihood of realizations.
distribution = model(c=c[0])
distribution
Normal(loc: -0.24979770183563232, scale: 1.1048245429992676)
distribution.sample()
tensor(0.4013)
distribution.log_prob(x[0])
tensor(-2.2790, grad_fn=<SubBackward0>)
The result of log_prob is part of a computation graph (it has a grad_fn) and therefore it can be used to train the parameters of the model by variational inference. Importantly, when the parameters of the model are modified, for example due to a gradient descent step, you must remember to call the forward method again to re-build the distribution with the new parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(64):
    loss = -model(c).log_prob(x).mean()  # E[-log q(x | c)]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
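For this toy problem the true posterior is available in closed form, which gives a reference against which the training loss can be judged. With a standard normal prior on \(X\) and three unit-variance observations \(C_i\), the standard conjugate Gaussian update gives \(p(X | c) = \mathcal{N}(\frac{1}{4} \sum_i c_i, \frac{1}{4})\). The sketch below evaluates the average negative log-likelihood under this posterior and under the prior; the variable names are illustrative and the seed is arbitrary.

```python
import torch

torch.manual_seed(0)

# Same generative process as above, with more samples for a stable estimate.
x = torch.distributions.Normal(0, 1).sample((4096,))
c = torch.distributions.Normal(x, 1).sample((3,)).T

# Conjugate update: precision 1 (prior) + 3 (observations) = 4, mean = sum(c) / 4
posterior = torch.distributions.Normal(c.sum(dim=-1) / 4, 0.5)
prior = torch.distributions.Normal(0.0, 1.0)

nll_posterior = -posterior.log_prob(x).mean()  # E[-log p(x | c)]
nll_prior = -prior.log_prob(x).mean()  # E[-log p(x)], ignores the context

print(nll_posterior, nll_prior)
```

The first value estimates the entropy of the posterior, \(\frac{1}{2} \log(2 \pi e / 4) \approx 0.726\), which is the lowest value the loss \(\mathbb{E}[-\log q_\phi(x | c)]\) can reach in expectation. A well-trained model should approach it.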
1.3. Normalizing flows
Following the same spirit, a parametric normalizing flow in Zuko is a special LazyDistribution that contains a LazyTransform and a base LazyDistribution. To increase expressivity, the transformation is usually the composition of a sequence of “simple” transformations
\[ f(x) = f_n \circ \dots \circ f_2 \circ f_1(x) \]
for which we can compute the determinant of the Jacobian as
\[ \det \frac{\partial f(x)}{\partial x} = \prod_{i = 1}^{n} \det \frac{\partial f_i(x_{i-1})}{\partial x_{i-1}} \]
where \(x_{0} = x\) and \(x_i = f_i(x_{i-1})\). In the univariate case, finding a bijective transformation whose determinant of the Jacobian is tractable is easy: any differentiable monotonic function works. In the multivariate case, the most common way to make the determinant easy to compute is to enforce a triangular Jacobian. This is achieved by a transformation \(y = f(x)\) where each element \(y_i\) is a monotonic function of \(x_i\), conditioned on the preceding elements \(x_{<i}\), that is
\[ y_i = f(x_i | x_{<i}). \]
Autoregressive and coupling transformations [1-2] are notable examples of this class of transformations.
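In log-space, the product of determinants becomes a sum of log-absolute-determinants, one per transformation. The sketch below checks this with three univariate PyTorch transforms chained by ComposeTransform; the particular choice of transforms is arbitrary and only serves as an illustration.

```python
import torch
from torch.distributions.transforms import (
    AffineTransform,
    ComposeTransform,
    ExpTransform,
    SigmoidTransform,
)

f1 = AffineTransform(torch.tensor(1.0), torch.tensor(2.0))
f2 = SigmoidTransform()
f3 = ExpTransform()
f = ComposeTransform([f1, f2, f3])  # f = f3 ∘ f2 ∘ f1

# Intermediate values x_i = f_i(x_{i-1})
x0 = torch.tensor(0.3)
x1 = f1(x0)
x2 = f2(x1)
x3 = f3(x2)

# Sum of the per-transformation log-absolute-determinants
ladj_sum = (
    f1.log_abs_det_jacobian(x0, x1)
    + f2.log_abs_det_jacobian(x1, x2)
    + f3.log_abs_det_jacobian(x2, x3)
)

# Log-absolute-determinant of the composition, and the derivative by autograd
ladj = f.log_abs_det_jacobian(x0, x3)
dydx = torch.autograd.functional.jacobian(f, x0)

print(ladj_sum, ladj, dydx.abs().log())
```

All three printed values should agree up to floating-point error.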
transform = zuko.flows.MaskedAutoregressiveTransform(
    features=5,
    context=0,  # no context
    univariate=zuko.transforms.MonotonicRQSTransform,  # rational-quadratic spline
    shapes=([8], [8], [7]),  # shapes of the spline parameters (8 bins)
    hidden_features=(64, 128, 256),  # size of the hyper-network
)
transform
MaskedAutoregressiveTransform(
  (base): MonotonicRQSTransform(bins=8)
  (order): [0, 1, 2, 3, 4]
  (hyper): MaskedMLP(
    (0): MaskedLinear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): MaskedLinear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): MaskedLinear(in_features=128, out_features=256, bias=True)
    (5): ReLU()
    (6): MaskedLinear(in_features=256, out_features=115, bias=True)
  )
)
f = transform()
x = torch.randn(5)
y = f(x)
xx = f.inv(y)
print(x, xx, sep='\n')
tensor([ 0.6927, 0.6688, 0.2485, -1.8932, -1.5444])
tensor([ 0.6927, 0.6688, 0.2485, -1.8932, -1.5444], grad_fn=<WhereBackward0>)
torch.autograd.functional.jacobian(f, x).round(decimals=3)
tensor([[ 0.9950, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.0020, 0.8900, 0.0000, 0.0000, 0.0000],
[-0.0060, 0.0010, 0.9690, 0.0000, 0.0000],
[ 0.0390, -0.0050, 0.0000, 0.9540, 0.0000],
[-0.0190, -0.0010, 0.0030, -0.0010, 0.8670]])
We can see that the Jacobian of the autoregressive transformation is indeed triangular.
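A triangular Jacobian is useful because its determinant is the product of its diagonal entries, so the log-absolute-determinant reduces to a sum of logarithms. The sketch below illustrates this with a hand-written autoregressive affine map in plain PyTorch. The function f is a hypothetical toy example, not a Zuko transformation.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Toy autoregressive map: y_i = x_i * exp(s_i) + t_i,
    # where s_i and t_i only depend on the preceding elements x_{<i}.
    shift = torch.cumsum(x, dim=-1) - x  # t_i = sum of x_{<i}
    log_scale = torch.tanh(shift)  # s_i = tanh(t_i)
    return x * log_scale.exp() + shift


x = torch.randn(5)
jac = torch.autograd.functional.jacobian(f, x)

# The Jacobian is lower triangular: y_i does not depend on x_{>i}.
is_lower_triangular = torch.allclose(jac, jac.tril())

# Sum of log |diagonal| versus a generic log-determinant routine
ladj_diag = jac.diagonal().abs().log().sum()
_, ladj_full = torch.linalg.slogdet(jac)

print(is_lower_triangular, ladj_diag, ladj_full)
```

The diagonal route costs a number of operations linear in the number of features, whereas a generic determinant is cubic. This is why autoregressive and coupling transformations scale to many features.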
1.3.1. Pre-built architecture
Zuko provides many pre-built flow architectures including NICE, MAF, NSF, CNF and many others [1-4]. We recommend trying MAF and NSF first as they are efficient baselines. In the following cell, we instantiate a conditional flow (5 sample features and 8 context features) with 3 affine autoregressive transformations.
flow = zuko.flows.MAF(features=5, context=8, transforms=3)
flow
MAF(
  (transform): LazyComposedTransform(
    (0): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2, 3, 4]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
    (1): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [4, 3, 2, 1, 0]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2, 3, 4]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
  )
  (base): Unconditional(DiagNormal(loc: torch.Size([5]), scale: torch.Size([5])))
)
1.3.2. Custom architecture
Alternatively, a flow can be built as a custom Flow object given a sequence of lazy transformations and a base lazy distribution. The following is a condensed example of many things that are possible in Zuko. But remember, with great power comes great responsibility (and great bugs).
from zuko.flows import (
    Flow,
    GeneralCouplingTransform,
    MaskedAutoregressiveTransform,
    NeuralAutoregressiveTransform,
    Unconditional,
)
from zuko.distributions import BoxUniform
from zuko.transforms import (
    AffineTransform,
    MonotonicRQSTransform,
    RotationTransform,
    SigmoidTransform,
)
flow = Flow(
    transform=[
        Unconditional(  # [0, 255] to ]0, 1[
            AffineTransform,  # y = loc + scale * x
            torch.tensor(1 / 512),  # loc
            torch.tensor(1 / 256),  # scale
            buffer=True,  # not trainable
        ),
        Unconditional(lambda: SigmoidTransform().inv),  # y = logit(x)
        MaskedAutoregressiveTransform(  # autoregressive transform (affine by default)
            features=5,
            context=8,
            passes=5,  # fully-autoregressive
        ),
        Unconditional(RotationTransform, torch.randn(5, 5)),  # trainable rotation
        GeneralCouplingTransform(  # coupling transform
            features=5,
            context=8,
            univariate=MonotonicRQSTransform,  # rational-quadratic spline
            shapes=([8], [8], [7]),  # shapes of the spline parameters (8 bins)
            hidden_features=(256, 256),  # size of the hyper-network
            activation=torch.nn.ELU,  # ELU activation in hyper-network
        ).inv,  # inverse
        Unconditional(  # ignore context
            NeuralAutoregressiveTransform(
                features=5,
                order=[5, 2, 0, 3, 1],  # autoregressive order
                passes=2,  # 2-pass autoregressive (equivalent to coupling)
            )
        ),
    ],
    base=Unconditional(  # ignore context
        BoxUniform,
        torch.full([5], -3.0),  # lower bound
        torch.full([5], +3.0),  # upper bound
        buffer=True,  # not trainable
    ),
)
flow
Flow(
  (transform): LazyComposedTransform(
    (0): Unconditional(AffineTransform())
    (1): Unconditional(Inverse(SigmoidTransform()))
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2, 3, 4]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=13, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=10, bias=True)
      )
    )
    (3): Unconditional(RotationTransform())
    (4): LazyInverse(
      (transform): GeneralCouplingTransform(
        (base): MonotonicRQSTransform(bins=8)
        (mask): [0, 1, 0, 1, 0]
        (hyper): MLP(
          (0): Linear(in_features=10, out_features=256, bias=True)
          (1): ELU(alpha=1.0)
          (2): Linear(in_features=256, out_features=256, bias=True)
          (3): ELU(alpha=1.0)
          (4): Linear(in_features=256, out_features=69, bias=True)
        )
      )
    )
    (5): Unconditional(
      (meta): NeuralAutoregressiveTransform(
        (base): MonotonicTransform()
        (order): [1, 0, 0, 1, 0]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=5, out_features=64, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=64, out_features=64, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=64, out_features=80, bias=True)
        )
        (network): MonotonicMLP(
          (0): MonotonicLinear(in_features=17, out_features=64, bias=True, stack=5)
          (1): TwoWayELU(alpha=1.0)
          (2): MonotonicLinear(in_features=64, out_features=64, bias=True, stack=5)
          (3): TwoWayELU(alpha=1.0)
          (4): MonotonicLinear(in_features=64, out_features=1, bias=True, stack=5)
        )
      )
    )
  )
  (base): Unconditional(BoxUniform(low: torch.Size([5]), high: torch.Size([5])))
)
1.4. References
[1] Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)
https://arxiv.org/abs/1705.07057
[2] NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)
https://arxiv.org/abs/1410.8516
[3] Neural Spline Flows (Durkan et al., 2019)
https://arxiv.org/abs/1906.04032
[4] Neural Ordinary Differential Equations (Chen et al., 2018)
https://arxiv.org/abs/1806.07366