
Zuko

Zuko is a Python package that implements normalizing flows in PyTorch. It relies as much as possible on the distributions and transformations already provided by PyTorch. Unfortunately, the Distribution and Transform classes of torch are not subclasses of torch.nn.Module, which means you cannot send their internal tensors to the GPU with .to('cuda') or retrieve their parameters with .parameters().
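The limitation can be checked directly. The snippet below is a minimal sketch that uses only torch; the variable names are illustrative and not part of zuko.

```python
import torch
from torch import nn
from torch.distributions import Normal

# A plain torch distribution holding two internal tensors
loc = torch.zeros(3)
scale = torch.ones(3)
dist = Normal(loc, scale)

# Distribution does not inherit from nn.Module,
# so the usual Module API is not available on it
is_module = isinstance(dist, nn.Module)
has_parameters = hasattr(dist, "parameters")
has_to = hasattr(dist, "to")

print(is_module, has_parameters, has_to)
```

All three checks are expected to be false, which is why moving such an object between devices or handing it to an optimizer requires extra bookkeeping.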

To solve this problem, zuko defines two abstract classes: zuko.flows.DistributionModule and zuko.flows.TransformModule. The former is any Module whose forward pass returns a Distribution, and the latter is any Module whose forward pass returns a Transform. A normalizing flow is then just a DistributionModule that contains a list of TransformModule objects and a base DistributionModule. This design allows flows to behave like distributions while retaining the benefits of Module. It also makes the implementations easier to understand and extend.
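The idea of a Module whose forward pass returns a Distribution can be sketched with torch alone. GaussianModule below is a hypothetical class written for illustration; it is not zuko's implementation and only mimics the pattern described above.

```python
import torch
from torch import nn
from torch.distributions import Distribution, Independent, Normal


class GaussianModule(nn.Module):
    """Hypothetical module whose forward pass returns a Distribution."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # A regular sub-module, so its weights are tracked by .parameters()
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Distribution:
        # Predict the mean and log-scale of p(x | c) from the context
        loc, log_scale = self.net(c).chunk(2, dim=-1)
        return Independent(Normal(loc, log_scale.exp()), 1)


module = GaussianModule(features=3, context=5)

# The Module API works as usual: weights (5 * 6) plus biases (6)
n_params = sum(p.numel() for p in module.parameters())

c = torch.randn(16, 5)
dist = module(c)          # a torch Distribution conditioned on c
x = dist.sample()         # one sample per context, shape (16, 3)
log_p = dist.log_prob(x)  # one log-density per context, shape (16,)
```

Because the distribution is rebuilt on every forward pass, the object that owns the parameters is always a Module, and calls such as module.to(device) or module.parameters() behave as they do for any other network.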

Installation

The zuko package is available on PyPI, so it can be installed with pip.

pip install zuko

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/zuko

Getting started

Normalizing flows are provided in the zuko.flows module. To build one, supply the number of sample and context features as well as the transformations’ hyperparameters. Then, feeding a context \(c\) to the flow returns a conditional distribution \(p(x | c)\) which can be evaluated and sampled from.

import torch
import zuko

# Neural spline flow (NSF) with 3 sample features and 5 context features
flow = zuko.flows.NSF(3, 5, transforms=3, hidden_features=[128] * 3)

# Train to maximize the log-likelihood
optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)

# trainset is assumed to yield batches (x, c) of samples and contexts
for x, c in trainset:
    loss = -flow(c).log_prob(x)  # -log p(x | c)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample 64 points x ~ p(x | c*), where c_star is a given context with 5 features
x = flow(c_star).sample((64,))

Alternatively, flows can be built as custom zuko.flows.FlowModule objects.

from zuko.flows import FlowModule, MaskedAutoregressiveTransform, Unconditional
from zuko.distributions import DiagNormal
from zuko.transforms import PermutationTransform

flow = FlowModule(
    transforms=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=[128] * 3),
        Unconditional(PermutationTransform, torch.randperm(3), buffer=True),
        MaskedAutoregressiveTransform(3, 5, hidden_features=[128] * 3),
    ],
    base=Unconditional(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)

References

Variational Inference with Normalizing Flows (Rezende et al., 2015)
Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)
Neural Spline Flows (Durkan et al., 2019)
Neural Autoregressive Flows (Huang et al., 2018)