zuko.flows.coupling
Coupling flows and transformations.
Classes
GeneralCouplingTransform | Creates a lazy general coupling transformation.
NICE | Creates a NICE flow.
Descriptions
- class zuko.flows.coupling.GeneralCouplingTransform(features=None, context=0, mask=None, *args, **kwargs)
Creates a lazy general coupling transformation.
References
NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)
Parameters:
features (int) – The number of features.
context (int) – The number of context features.
mask (BoolTensor) – The coupling mask. If None, use a checkered mask.
univariate – The univariate transformation constructor.
shapes – The shapes of the univariate transformation parameters.
kwargs – Keyword arguments passed to zuko.nn.MLP.
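The parameters above determine the size of the hyper-network. The dependency-free sketch below assumes, consistent with the mask [0, 1, 0] printed in the example, that the default checkered mask marks every other feature starting at index 1; that the marked features are kept constant and fed to the MLP together with the context; and that the MLP emits one parameter set, sized by `shapes`, per transformed feature. The helpers `checkered_mask` and `hyper_sizes` are hypothetical and not part of zuko.

```python
import math


def checkered_mask(features: int) -> list[bool]:
    # Hypothetical default mask: every other feature, starting at index 1, stays constant.
    return [i % 2 == 1 for i in range(features)]


def hyper_sizes(mask: list[bool], context: int, shapes: tuple) -> tuple[int, int]:
    # Assumed layout: constant features (mask True) plus context go into the MLP.
    n_constant = sum(mask)
    n_transformed = len(mask) - n_constant
    in_features = n_constant + context
    # Each transformed feature needs one parameter of every shape in `shapes`;
    # math.prod(()) == 1, so a scalar shape () counts as a single output.
    per_feature = sum(math.prod(shape) for shape in shapes)
    return in_features, n_transformed * per_feature


mask = checkered_mask(3)
print([int(m) for m in mask])  # [0, 1, 0]
# Two scalar parameters (shift and scale), i.e. shapes=((), ()).
print(hyper_sizes(mask, context=4, shapes=((), ())))  # (5, 4)
```

Under these assumptions the result matches the in_features=5 and out_features=4 of the MLP printed in the example.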
Example
>>> t = GeneralCouplingTransform(3, 4)
>>> t
GeneralCouplingTransform(
  (base): MonotonicAffineTransform()
  (mask): [0, 1, 0]
  (hyper): MLP(
    (0): Linear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=4, bias=True)
  )
)
>>> x = torch.randn(3)
>>> x
tensor([-0.8743, 0.6232, 1.2439])
>>> c = torch.randn(4)
>>> y = t(c)(x)
>>> t(c).inv(y)
tensor([-0.8743, 0.6232, 1.2439])
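The example shows that the inverse recovers x. The mechanism behind this can be sketched without torch: the constant features pass through unchanged and, together with the context, determine the parameters of a monotonic map applied to the other features, so the inverse can recompute the same parameters from y. This is a minimal sketch, not zuko's implementation; it assumes a plain affine map y = x * exp(log_scale) + shift and replaces the MLP with a hypothetical `conditioner` function.

```python
import math
import random


def conditioner(x_const, context, n_out):
    # Hypothetical stand-in for the MLP hyper-network. It can be any function of
    # the constant features and the context because it is never inverted.
    s = sum(x_const) + sum(context)
    # One (shift, log_scale) pair per transformed feature.
    return [(0.1 * (k + 1) * s, math.tanh(s + k)) for k in range(n_out)]


def coupling_forward(x, mask, context):
    const = [xi for xi, m in zip(x, mask) if m]
    params = iter(conditioner(const, context, len(x) - len(const)))
    y = []
    for xi, m in zip(x, mask):
        if m:
            y.append(xi)  # constant feature: passed through unchanged
        else:
            shift, log_scale = next(params)
            y.append(xi * math.exp(log_scale) + shift)  # monotonic affine map
    return y


def coupling_inverse(y, mask, context):
    # The constant features are identical in x and y, so the same
    # parameters can be recomputed from y alone.
    const = [yi for yi, m in zip(y, mask) if m]
    params = iter(conditioner(const, context, len(y) - len(const)))
    x = []
    for yi, m in zip(y, mask):
        if m:
            x.append(yi)
        else:
            shift, log_scale = next(params)
            x.append((yi - shift) * math.exp(-log_scale))  # undo the affine map
    return x


random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(3)]
c = [random.gauss(0.0, 1.0) for _ in range(4)]
mask = [False, True, False]  # same pattern as the mask printed in the example
y = coupling_forward(x, mask, c)
x_rec = coupling_inverse(y, mask, c)
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-9) for a, b in zip(x, x_rec)))  # True
```

Because the conditioner is only ever evaluated, never inverted, it can be an arbitrary network; this is what lets a coupling transform use an MLP while remaining exactly invertible.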
- class zuko.flows.coupling.NICE(features, context=0, transforms=3, randmask=False, **kwargs)
Creates a NICE flow.
By default, affine transformations are used instead of the additive transformations originally proposed by Dinh et al. (2014).
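The practical difference between the two choices shows up in the log-determinant of the Jacobian. For a coupling layer that keeps x_a fixed, an additive update y_b = x_b + t(x_a) has a unit Jacobian determinant, while an affine update y_b = x_b * exp(s(x_a)) + t(x_a) adds s(x_a) to log|det J|. The two-feature sketch below checks this numerically with central finite differences. The functions `s` and `t` are hypothetical placeholders, and the plain exponential scale is a simplification; zuko's affine transform may constrain the scale differently.

```python
import math


def s(xa):
    # Hypothetical log-scale "network".
    return math.sin(xa)


def t(xa):
    # Hypothetical shift "network".
    return xa ** 2


def additive(x):
    xa, xb = x
    return (xa, xb + t(xa))  # additive coupling, as in the original NICE


def affine(x):
    xa, xb = x
    return (xa, xb * math.exp(s(xa)) + t(xa))  # affine coupling


def log_abs_det_fd(f, x, h=1e-6):
    # Central finite-difference Jacobian of a 2-D map; cols[j][i] = dy_i / dx_j.
    cols = []
    for j in range(2):
        xp = list(x)
        xm = list(x)
        xp[j] += h
        xm[j] -= h
        fp, fm = f(xp), f(xm)
        cols.append([(fp[i] - fm[i]) / (2 * h) for i in range(2)])
    det = cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]
    return math.log(abs(det))


x = (0.3, -1.2)
print(log_abs_det_fd(additive, x))  # close to 0
print(log_abs_det_fd(affine, x), s(x[0]))  # both about 0.2955
```

An additive layer is therefore volume-preserving, whereas an affine layer can rescale densities, which is one reason to prefer it as the default.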
References
NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)
Parameters:
features (int) – The number of features.
context (int) – The number of context features.
transforms (int) – The number of coupling transformations.
randmask (bool) – Whether random coupling masks are used or not. If False, use alternating checkered masks.
kwargs – Keyword arguments passed to GeneralCouplingTransform.
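The two mask options can be illustrated with a dependency-free sketch. The function `nice_masks` is hypothetical: it assumes that "alternating checkered" means flipping the parity of the checkered pattern from one coupling layer to the next, and that a random mask splits a random permutation of the feature indices by parity. zuko's own construction may differ in detail.

```python
import random


def nice_masks(features, transforms=3, randmask=False, seed=None):
    # Hypothetical mask schedule for a stack of `transforms` coupling layers.
    rng = random.Random(seed)
    masks = []
    for i in range(transforms):
        if randmask:
            order = list(range(features))
            rng.shuffle(order)  # random split, still roughly half constant
            mask = [order[j] % 2 == i % 2 for j in range(features)]
        else:
            # Alternating checkered: the parity flips at every layer.
            mask = [j % 2 == i % 2 for j in range(features)]
        masks.append(mask)
    return masks


for m in nice_masks(4):
    print([int(b) for b in m])
# [1, 0, 1, 0]
# [0, 1, 0, 1]
# [1, 0, 1, 0]
```

With alternating masks, consecutive layers are complementary, so every feature is transformed by at least one of any two successive coupling layers; random masks trade this guarantee for more varied feature interactions.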