zuko.flows.autoregressive

Autoregressive flows and transformations.

Classes

  • MaskedAutoregressiveTransform – Creates a lazy masked autoregressive transformation.

  • MAF – Creates a masked autoregressive flow (MAF).

Descriptions

class zuko.flows.autoregressive.MaskedAutoregressiveTransform(features=None, context=0, passes=None, order=None, *args, **kwargs)

Creates a lazy masked autoregressive transformation.
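The example further down calls t(c) to obtain a transformation and then applies it, or its inverse, to x. The following stdlib-only sketch illustrates that lazy pattern: a callable that maps a context vector to a concrete, invertible transformation. The class names ElementwiseAffine and LazyAffine and the toy hyper-network are hypothetical, and for brevity the transformation is elementwise rather than autoregressive; this is not zuko's implementation.

```python
import math
from typing import Callable, List, Tuple


class ElementwiseAffine:
    """Concrete transformation y_i = x_i * exp(log_scale_i) + shift_i."""

    def __init__(self, shift: List[float], log_scale: List[float]):
        self.shift = shift
        self.log_scale = log_scale

    def __call__(self, x: List[float]) -> List[float]:
        return [xi * math.exp(s) + t for xi, s, t in zip(x, self.log_scale, self.shift)]

    def inv(self, y: List[float]) -> List[float]:
        return [(yi - t) * math.exp(-s) for yi, s, t in zip(y, self.log_scale, self.shift)]


class LazyAffine:
    """Hypothetical lazy transformation: calling it with a context builds a transformation."""

    def __init__(self, hyper: Callable[[List[float]], Tuple[List[float], List[float]]]):
        self.hyper = hyper  # maps a context vector to (shift, log_scale)

    def __call__(self, c: List[float]) -> ElementwiseAffine:
        shift, log_scale = self.hyper(c)
        return ElementwiseAffine(shift, log_scale)


# Toy hyper-network standing in for a learned network (assumption for illustration)
lazy = LazyAffine(lambda c: ([sum(c)] * 3, [0.1 * c[0]] * 3))

c = [0.5, -1.0]
x = [1.0, 2.0, 3.0]
y = lazy(c)(x)           # same call shape as t(c)(x)
x_back = lazy(c).inv(y)  # same call shape as t(c).inv(y)
print(y, x_back)
```

Keeping the context out of the transformation object itself means the same module can produce a different transformation for every context, which is what makes conditional density estimation possible.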

References

Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)
Parameters:
  • features (int) – The number of features.

  • context (int) – The number of context features.

  • passes (int) – The number of sequential passes for the inverse transformation. If None, use the number of features instead, making the transformation fully autoregressive. Coupling corresponds to passes=2.

  • order (LongTensor) – The feature ordering. If None, use range(features) instead.

  • univariate – The univariate transformation constructor.

  • shapes – The shapes of the univariate transformation parameters.

  • kwargs – Keyword arguments passed to zuko.nn.MaskedMLP.
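To make the roles of order and passes concrete, here is a stdlib-only sketch of an affine autoregressive transformation y_i = x_i * exp(s_i) + t_i, in which the parameters of feature i may only depend on features in strictly earlier groups. It assumes that rank[i] is the position of feature i in the ordering and that features are split into groups of ceil(features / passes) consecutive positions; pass_groups, forward, inverse and toy_conditioner are hypothetical names, and the toy conditioner stands in for the masked MLP. The forward direction only reads x, so it needs a single pass, while the inverse must recover one group at a time.

```python
import math
from typing import Callable, List, Sequence, Tuple

# A conditioner maps the visible (already known) features and a feature index i
# to the parameters (shift, log_scale) of feature i.
Conditioner = Callable[[List[float], int], Tuple[float, float]]


def pass_groups(rank: Sequence[int], passes: int) -> List[int]:
    """Split features into at most `passes` groups of consecutive positions."""
    size = math.ceil(len(rank) / passes)
    return [r // size for r in rank]


def visible_inputs(x: List[float], groups: List[int], g: int) -> List[float]:
    """Zero out every feature that is not in a group strictly before group g."""
    return [xj if groups[j] < g else 0.0 for j, xj in enumerate(x)]


def forward(x: List[float], cond: Conditioner, groups: List[int]) -> List[float]:
    # All parameters depend on x only, which is fully known: a single pass suffices.
    y = []
    for i, xi in enumerate(x):
        shift, log_scale = cond(visible_inputs(x, groups, groups[i]), i)
        y.append(xi * math.exp(log_scale) + shift)
    return y


def inverse(y: List[float], cond: Conditioner, groups: List[int]) -> List[float]:
    # Group g needs the features of groups < g, so groups are recovered sequentially.
    x = [0.0] * len(y)
    for g in sorted(set(groups)):  # one sequential pass per group
        visible = visible_inputs(x, groups, g)
        for i in range(len(y)):
            if groups[i] == g:
                shift, log_scale = cond(visible, i)
                x[i] = (y[i] - shift) * math.exp(-log_scale)
    return x


def toy_conditioner(visible: List[float], i: int) -> Tuple[float, float]:
    # Hypothetical stand-in for the masked MLP: any function of the visible features works.
    s = sum(visible)
    return 0.5 * s + i, math.tanh(s)


x = [-0.9485, 1.5290, 0.2018]
for passes in (3, 2):  # fully autoregressive, then coupling
    groups = pass_groups([0, 1, 2], passes)
    y = forward(x, toy_conditioner, groups)
    print(passes, groups, inverse(y, toy_conditioner, groups))
```

With passes=3 every feature forms its own group and the inverse takes three sequential steps; with passes=2 the groups are [0, 0, 1], so the first two features are transformed independently of the data and the last one is conditioned on them, which is the coupling case.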

Example

>>> import torch
>>> from zuko.flows.autoregressive import MaskedAutoregressiveTransform
>>> t = MaskedAutoregressiveTransform(3, 4)
>>> t
MaskedAutoregressiveTransform(
  (base): MonotonicAffineTransform()
  (order): [0, 1, 2]
  (hyper): MaskedMLP(
    (0): MaskedLinear(in_features=7, out_features=64, bias=True)
    (1): ReLU()
    (2): MaskedLinear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): MaskedLinear(in_features=64, out_features=6, bias=True)
  )
)
>>> x = torch.randn(3)
>>> x
tensor([-0.9485,  1.5290,  0.2018])
>>> c = torch.randn(4)
>>> y = t(c)(x)
>>> t(c).inv(y)
tensor([-0.9485,  1.5290,  0.2018])
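
In the example above, the first MaskedLinear layer has 7 inputs (3 features plus 4 context features) and the last one has 6 outputs, which is consistent with two parameters (shift and scale) per feature. The sketch below builds the input-to-output connectivity such a masked network has to respect: the parameters of feature i may depend on every context feature and on features in strictly earlier groups, never on feature i itself. The function name adjacency, the assumption that rank[i] is the position of feature i, and the row layout (one block of rows per parameter) are illustrative choices; turning this matrix into masks for the hidden layers, as in MADE, is not shown.

```python
import math
from typing import List, Sequence


def adjacency(rank: Sequence[int], context: int, passes: int, params: int = 2) -> List[List[bool]]:
    """Rows are outputs (params * features), columns are inputs (features + context)."""
    features = len(rank)
    size = math.ceil(features / passes)
    group = [r // size for r in rank]
    # Context inputs get group -1, so every output is allowed to see them.
    in_group = group + [-1] * context
    rows = []
    for _ in range(params):  # one block of rows per parameter, e.g. shift then scale
        for i in range(features):
            rows.append([g < group[i] for g in in_group])
    return rows


mask = adjacency([0, 1, 2], context=4, passes=3)
for row in mask:
    print("".join("1" if allowed else "." for allowed in row))
# prints ...1111 / 1..1111 / 11.1111, then the same three rows again
```
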

class zuko.flows.autoregressive.MAF(features, context=0, transforms=3, randperm=False, **kwargs)

Creates a masked autoregressive flow (MAF).

References

Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)
Parameters:
  • features (int) – The number of features.

  • context (int) – The number of context features.

  • transforms (int) – The number of autoregressive transformations.

  • randperm (bool) – Whether features are randomly permuted between transformations or not. If False, the ordering alternates: ascending for even-indexed transformations (0, 2, ...) and descending for odd-indexed ones (1, 3, ...).

  • kwargs – Keyword arguments passed to MaskedAutoregressiveTransform.
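
The default ordering rule can be written out as below; for features=3 and transforms=3 it gives [0, 1, 2], [2, 1, 0], [0, 1, 2], which matches the order entries in the example that follows. Changing the order between transformations matters because, within one transformation, the first feature in the order is not conditioned on any other feature; reversing the order lets it be conditioned on the others in the next transformation. The function maf_orders and its seed argument are hypothetical, and the sketch does not imply that zuko draws random permutations in exactly this way.

```python
import random
from typing import List, Optional


def maf_orders(features: int, transforms: int, randperm: bool = False,
               seed: Optional[int] = None) -> List[List[int]]:
    rng = random.Random(seed)
    orders = []
    for k in range(transforms):
        order = list(range(features))
        if randperm:
            rng.shuffle(order)  # a fresh random permutation per transformation
        elif k % 2 == 1:
            order.reverse()  # odd-indexed transformations use descending order
        orders.append(order)
    return orders


print(maf_orders(3, 3))  # [[0, 1, 2], [2, 1, 0], [0, 1, 2]]
print(maf_orders(3, 3, randperm=True, seed=0))
```
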

Example

>>> import torch
>>> from zuko.flows.autoregressive import MAF
>>> flow = MAF(3, 4, transforms=3)
>>> flow
MAF(
  (transform): LazyComposedTransform(
    (0): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
    (1): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [2, 1, 0]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
  )
  (base): Unconditional(DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])))
)
>>> c = torch.randn(4)
>>> x = flow(c).sample()
>>> x
tensor([-1.7154, -0.4401,  0.7505])
>>> flow(c).log_prob(x)
tensor(-4.4630, grad_fn=<AddBackward0>)
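
The value returned by log_prob follows the change-of-variables formula, log p(x) = log N(z; 0, I) + sum over layers k of log|det J_k|. The sketch below assumes that each transformation maps data x toward the base variable z, which is the direction MAF evaluates in a single pass, and it uses a standard normal base instead of the learned DiagNormal shown in the example. For an affine autoregressive layer the Jacobian is triangular up to a permutation, so log|det J_k| is simply the sum of its log-scales. The names ar_affine, flow_log_prob and toy_conditioner are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

Conditioner = Callable[[List[float], int], Tuple[float, float]]


def ar_affine(x: List[float], rank: Sequence[int], cond: Conditioner) -> Tuple[List[float], float]:
    """One affine autoregressive layer; returns (y, log|det J|).

    Feature i only sees features with a smaller rank, so the Jacobian is
    triangular up to a permutation and log|det J| is the sum of the log-scales.
    """
    y, ladj = [], 0.0
    for i, xi in enumerate(x):
        visible = [xj if rank[j] < rank[i] else 0.0 for j, xj in enumerate(x)]
        shift, log_scale = cond(visible, i)
        y.append(xi * math.exp(log_scale) + shift)
        ladj += log_scale
    return y, ladj


def std_normal_log_prob(z: List[float]) -> float:
    return sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)


def flow_log_prob(x: List[float], ranks: Sequence[Sequence[int]], cond: Conditioner) -> float:
    """log p(x) = log N(z; 0, I) + sum of log|det J| over the layers."""
    z, total_ladj = list(x), 0.0
    for rank in ranks:
        z, ladj = ar_affine(z, rank, cond)
        total_ladj += ladj
    return std_normal_log_prob(z) + total_ladj


def toy_conditioner(visible: List[float], i: int) -> Tuple[float, float]:
    # Hypothetical stand-in for a trained masked MLP
    s = sum(visible)
    return 0.3 * s - 0.1 * i, 0.2 * math.tanh(s)


ranks = [[0, 1, 2], [2, 1, 0], [0, 1, 2]]  # alternating orders, as in the MAF example
print(flow_log_prob([-1.7154, -0.4401, 0.7505], ranks, toy_conditioner))
```

Sampling runs the other way: draw z from the base and apply the inverse layers in reverse order. Because each autoregressive inverse is sequential (see passes above), MAF evaluates densities quickly but samples more slowly, a trade-off discussed by Papamakarios et al. (2017).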