zuko.flows.autoregressive

Autoregressive flows and transformations.

Classes

MAF

Creates a masked autoregressive flow (MAF).

MaskedAutoregressiveTransform

Creates a lazy masked autoregressive transformation.

Descriptions

class zuko.flows.autoregressive.MAF(features, context=0, transforms=3, randperm=False, **kwargs)

Creates a masked autoregressive flow (MAF).

References

Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)

Parameters:
  • features (int) – The number of features.

  • context (int) – The number of context features.

  • transforms (int) – The number of autoregressive transformations.

  • randperm (bool) – Whether features are randomly permuted between transformations. If False, the feature order alternates: ascending for even-indexed transformations (0, 2, ...) and descending for odd-indexed ones (1, 3, ...).

  • kwargs – Keyword arguments passed to MaskedAutoregressiveTransform.
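The alternating order used when randperm=False can be sketched in plain Python. This is only an illustration of the rule stated above; the helper name feature_orders is hypothetical and is not part of zuko.

```python
def feature_orders(features, transforms):
    """Illustrative only: ascending order for even-indexed
    transformations, descending order for odd-indexed ones."""
    orders = []
    for i in range(transforms):
        order = list(range(features))
        if i % 2 == 1:
            order.reverse()
        orders.append(order)
    return orders


orders = feature_orders(features=3, transforms=3)
print(orders)  # [[0, 1, 2], [2, 1, 0], [0, 1, 2]]
```

These orders match the (order) entries printed for MAF(3, 4, transforms=3) in the example of this section.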

Example

>>> import torch
>>> from zuko.flows.autoregressive import MAF
>>> flow = MAF(3, 4, transforms=3)
>>> flow
MAF(
  (transform): LazyComposedTransform(
    (0): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
    (1): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [2, 1, 0]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
  )
  (base): UnconditionalDistribution(DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])))
)
>>> c = torch.randn(4)
>>> x = flow(c).sample()
>>> x
tensor([-0.5012, -1.6298,  0.3803])
>>> flow(c).log_prob(x)
tensor(-3.7514, grad_fn=<AddBackward0>)
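The layer sizes printed in the example above can be read as simple arithmetic. The sketch below assumes that each feature of the affine base transformation needs two parameters (a shift and a scale); that count is an assumption inferred from the printed out_features=6, not something stated in this section.

```python
features, context = 3, 4

# Assumption: the affine transformation uses two parameters per feature
# (a shift and a scale).
params_per_feature = 2

# The hyper network sees the features concatenated with the context.
in_features = features + context
# It outputs every parameter of every feature at once.
out_features = features * params_per_feature

print(in_features, out_features)  # 7 6
```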

class zuko.flows.autoregressive.MaskedAutoregressiveTransform(features=None, context=0, passes=None, order=None, adjacency=None, *args, **kwargs)

Creates a lazy masked autoregressive transformation.

References

Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)

Parameters:
  • features (int) – The number of features.

  • context (int) – The number of context features.

  • passes (int) – The number of sequential passes for the inverse transformation. If None, use the number of features instead, making the transformation fully autoregressive. Coupling corresponds to passes=2.

  • order (LongTensor) – A feature ordering. If None, use range(features) instead.

  • adjacency (BoolTensor) – An adjacency matrix describing the transformation graph. If adjacency is provided, order is ignored and passes is replaced by the diameter of the graph. Its shape must be either (features, features) or (features, features + context). If the shape includes context, the rightmost context columns define connections to the conditioning variables.

  • univariate – The univariate transformation constructor.

  • shapes – The shapes of the univariate transformation parameters.

  • kwargs – Keyword arguments passed to zuko.nn.MaskedMLP.
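To make the relation between order, passes and adjacency more concrete, here is a plain-Python sketch of one way an adjacency matrix could be derived from an ordering and a number of passes. The function adjacency_from_order is hypothetical, and the grouping rule is an assumption made for illustration; it is not claimed to be the exact construction used by zuko. Entry A[i][j] is True when feature i is allowed to depend on feature j.

```python
from math import ceil


def adjacency_from_order(order, passes):
    """Illustrative only: split the ordering into `passes` groups and
    let each feature depend on the features of strictly earlier groups."""
    features = len(order)
    passes = min(max(passes, 1), features)
    group_size = ceil(features / passes)
    groups = [rank // group_size for rank in order]
    return [[gi > gj for gj in groups] for gi in groups]


# passes == features: fully autoregressive (strictly lower triangular).
full = adjacency_from_order([0, 1, 2], passes=3)

# passes == 2: coupling, the second half depends on the first half only.
coupling = adjacency_from_order([0, 1, 2, 3], passes=2)
```

In the coupling case the first two features depend on no other feature, while the last two depend on the first two, which is the usual two-block structure of a coupling layer.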

Example

>>> import torch
>>> from zuko.flows.autoregressive import MaskedAutoregressiveTransform
>>> t = MaskedAutoregressiveTransform(3, 4)
>>> t
MaskedAutoregressiveTransform(
  (base): MonotonicAffineTransform()
  (order): [0, 1, 2]
  (hyper): MaskedMLP(
    (0): MaskedLinear(in_features=7, out_features=64, bias=True)
    (1): ReLU()
    (2): MaskedLinear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): MaskedLinear(in_features=64, out_features=6, bias=True)
  )
)
>>> x = torch.randn(3)
>>> x
tensor([ 1.7428, -1.6483, -0.9920])
>>> c = torch.randn(4)
>>> y = t(c)(x)
>>> t(c).inv(y)
tensor([ 1.7428, -1.6483, -0.9920], grad_fn=<DivBackward0>)
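The round trip above recovers x because the inverse of an autoregressive transformation can be computed in a number of sequential passes, as described for the passes parameter. The toy below illustrates this idea in plain Python with a hand-written conditioner instead of a masked network. The functions params, forward and inverse are hypothetical, and the affine form with an exponential scale is an assumption chosen for simplicity, not zuko's implementation.

```python
import math


def params(x):
    """Toy conditioner: the parameters of feature i depend only on x[:i]."""
    shifts, log_scales = [], []
    for i in range(len(x)):
        prefix = sum(x[:i])
        shifts.append(0.5 * prefix)
        log_scales.append(0.1 * prefix)
    return shifts, log_scales


def forward(x):
    """One parallel pass: y_i = x_i * exp(s_i) + m_i."""
    shifts, log_scales = params(x)
    return [xi * math.exp(s) + m for xi, m, s in zip(x, shifts, log_scales)]


def inverse(y, passes):
    """Each pass makes at least one more feature exact, in order."""
    x = [0.0] * len(y)
    for _ in range(passes):
        shifts, log_scales = params(x)
        x = [(yi - m) * math.exp(-s) for yi, m, s in zip(y, shifts, log_scales)]
    return x


x = [1.7428, -1.6483, -0.9920]
y = forward(x)
x_rec = inverse(y, passes=len(x))   # fully autoregressive: 3 passes
x_bad = inverse(y, passes=1)        # too few passes: only x[0] is exact
```

With three features and a fully autoregressive dependency structure, three passes are enough to recover x up to floating-point error, while a single pass only recovers the first feature.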