zuko.flows.autoregressive#
Autoregressive flows and transformations.
Classes#
MaskedAutoregressiveTransform | Creates a lazy masked autoregressive transformation.
MAF | Creates a masked autoregressive flow (MAF).
Descriptions#
- class zuko.flows.autoregressive.MaskedAutoregressiveTransform(features=None, context=0, passes=None, order=None, *args, **kwargs)#
Creates a lazy masked autoregressive transformation.
References
Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)

Parameters:
- features (int) – The number of features.
- context (int) – The number of context features.
- passes (int) – The number of sequential passes for the inverse transformation. If None, use the number of features instead, making the transformation fully autoregressive. Coupling corresponds to passes=2.
- order (LongTensor) – The feature ordering. If None, use range(features) instead.
- univariate – The univariate transformation constructor.
- shapes – The shapes of the univariate transformation parameters.
- kwargs – Keyword arguments passed to zuko.nn.MaskedMLP.
Example
>>> t = MaskedAutoregressiveTransform(3, 4)
>>> t
MaskedAutoregressiveTransform(
  (base): MonotonicAffineTransform()
  (order): [0, 1, 2]
  (hyper): MaskedMLP(
    (0): MaskedLinear(in_features=7, out_features=64, bias=True)
    (1): ReLU()
    (2): MaskedLinear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): MaskedLinear(in_features=64, out_features=6, bias=True)
  )
)
>>> x = torch.randn(3)
>>> x
tensor([-0.9485,  1.5290,  0.2018])
>>> c = torch.randn(4)
>>> y = t(c)(x)
>>> t(c).inv(y)
tensor([-0.9485,  1.5290,  0.2018])
- class zuko.flows.autoregressive.MAF(features, context=0, transforms=3, randperm=False, **kwargs)#
Creates a masked autoregressive flow (MAF).
References
Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017)

Parameters:
- features (int) – The number of features.
- context (int) – The number of context features.
- transforms (int) – The number of autoregressive transformations.
- randperm (bool) – Whether features are randomly permuted between transformations or not. If False, features are in ascending (descending) order for even (odd) transformations.
- kwargs – Keyword arguments passed to MaskedAutoregressiveTransform.
Example
>>> flow = MAF(3, 4, transforms=3)
>>> flow
MAF(
  (transform): LazyComposedTransform(
    (0): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
    (1): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [2, 1, 0]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
    (2): MaskedAutoregressiveTransform(
      (base): MonotonicAffineTransform()
      (order): [0, 1, 2]
      (hyper): MaskedMLP(
        (0): MaskedLinear(in_features=7, out_features=64, bias=True)
        (1): ReLU()
        (2): MaskedLinear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): MaskedLinear(in_features=64, out_features=6, bias=True)
      )
    )
  )
  (base): Unconditional(DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])))
)
>>> c = torch.randn(4)
>>> x = flow(c).sample()
>>> x
tensor([-1.7154, -0.4401,  0.7505])
>>> flow(c).log_prob(x)
tensor(-4.4630, grad_fn=<AddBackward0>)