zuko.flows.core
Core building blocks.
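Classes
| Class | Description |
|---|---|
| LazyDistribution | Abstract lazy distribution. |
| LazyTransform | Abstract lazy transformation. |
| LazyComposedTransform | Creates a lazy composed transformation. |
| Flow | Creates a lazy normalizing flow. |
| Unconditional | Creates an unconditional lazy module from a constructor. |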
Descriptions
- class zuko.flows.core.LazyDistribution(*args, **kwargs)
Abstract lazy distribution.
A lazy distribution is a module that builds and returns a distribution \(p(X | c)\) within its forward pass, given a context \(c\).
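As a rough illustration of this pattern, the plain-PyTorch sketch below builds a diagonal Gaussian \(p(X | c)\) whose mean and scale are computed from the context. ConditionalDiagNormal is a hypothetical name and this is not how zuko implements LazyDistribution; the point is only that a fresh distribution object is created in every forward call, so gradients of its log-density reach the module's parameters.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Independent, Normal


class ConditionalDiagNormal(nn.Module):
    """Hypothetical lazy distribution: forward(c) builds p(X | c)."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Maps the context c to a mean and a log-scale for each feature of X.
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Distribution:
        loc, log_scale = self.net(c).chunk(2, dim=-1)
        # Independent(..., 1) treats the last dimension as the event dimension.
        return Independent(Normal(loc, log_scale.exp()), 1)


torch.manual_seed(0)
lazy = ConditionalDiagNormal(features=3, context=5)
c = torch.randn(5)
p = lazy(c)  # a new distribution object, built for this context
x = p.sample()
print(x.shape, p.log_prob(x).shape)  # torch.Size([3]) torch.Size([])
```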
- class zuko.flows.core.LazyTransform(*args, **kwargs)
Abstract lazy transformation.
A lazy transformation is a module that builds and returns a transformation \(y = f(x | c)\) within its forward pass, given a context \(c\).
- abstract forward(c=None)
Builds and returns the transformation \(y = f(x | c)\) for the context \(c\).
- property inv: LazyTransform
A lazy inverse transformation \(x = f^{-1}(y | c)\).
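The sketch below illustrates both ideas with plain PyTorch, under the assumption that an affine map is a representative transformation: ConditionalAffine builds \(y = f(x | c)\) in its forward pass, and LazyInverseSketch wraps it so that its forward pass returns \(f^{-1}(\cdot | c)\). Both class names are hypothetical and neither is zuko's implementation of LazyTransform or of its inv property.

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform, Transform


class ConditionalAffine(nn.Module):
    """Hypothetical lazy transform: forward(c) builds y = loc(c) + scale(c) * x."""

    def __init__(self, features: int, context: int):
        super().__init__()
        # Maps the context c to a shift and a log-scale for each feature.
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Transform:
        loc, log_scale = self.net(c).chunk(2, dim=-1)
        # exp keeps the scale positive, so the transform is invertible.
        return AffineTransform(loc, log_scale.exp(), event_dim=1)


class LazyInverseSketch(nn.Module):
    """Hypothetical lazy inverse: forward(c) builds x = f^{-1}(y | c)."""

    def __init__(self, lazy: nn.Module):
        super().__init__()
        self.lazy = lazy  # registered as a submodule, so parameters are shared

    def forward(self, c: torch.Tensor) -> Transform:
        return self.lazy(c).inv


torch.manual_seed(0)
lazy_f = ConditionalAffine(features=3, context=5)
lazy_f_inv = LazyInverseSketch(lazy_f)

c, x = torch.randn(5), torch.randn(3)
y = lazy_f(c)(x)
x_back = lazy_f_inv(c)(y)
print(torch.allclose(x, x_back, atol=1e-6))  # True
```

Because the inverse module holds a reference to the forward module rather than a copy, training either one updates the same parameters.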
- class zuko.flows.core.LazyComposedTransform(*transforms)#
Creates a lazy composed transformation.
See also
- Parameters:
transforms (LazyTransform) – A sequence of lazy transformations \(f_i\).
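A minimal plain-PyTorch sketch of a lazy composition, assuming each \(f_i\) is a module whose forward pass returns a torch Transform: every \(f_i\) is built for the same context and the results are chained with torch's ComposeTransform, which applies its parts in list order. LazyShift and LazyComposedSketch are hypothetical names, and the application order is a property of this sketch, not a statement about zuko's LazyComposedTransform.

```python
import torch
import torch.nn as nn
from torch.distributions.transforms import AffineTransform, ComposeTransform, Transform


class LazyShift(nn.Module):
    """Hypothetical lazy transform: y = x + shift(c)."""

    def __init__(self, features: int, context: int):
        super().__init__()
        self.net = nn.Linear(context, features)

    def forward(self, c: torch.Tensor) -> Transform:
        return AffineTransform(self.net(c), 1.0, event_dim=1)


class LazyComposedSketch(nn.Module):
    """Hypothetical lazy composition of lazy transforms f_1, ..., f_n."""

    def __init__(self, *transforms: nn.Module):
        super().__init__()
        # ModuleList registers the parameters of every f_i.
        self.transforms = nn.ModuleList(transforms)

    def forward(self, c: torch.Tensor) -> Transform:
        # Build each f_i for the same context; ComposeTransform applies them in list order.
        return ComposeTransform([t(c) for t in self.transforms])


torch.manual_seed(0)
f1, f2 = LazyShift(3, 5), LazyShift(3, 5)
composed = LazyComposedSketch(f1, f2)

c, x = torch.randn(5), torch.randn(3)
y = composed(c)(x)
print(torch.allclose(y, x + f1.net(c) + f2.net(c)))  # True
```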
- class zuko.flows.core.Flow(transform, base)#
Creates a lazy normalizing flow.
See also
- Parameters:
transform (Union[LazyTransform, Sequence[LazyTransform]]) – A lazy transformation or sequence of lazy transformations.
base (LazyDistribution) – A lazy distribution.
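To show how a lazy transformation and a lazy base distribution combine into a lazy flow, the sketch below builds both for a context and wraps them in torch's TransformedDistribution. LazyStandardNormal, LazyAffine and FlowSketch are hypothetical names. One assumption to flag: TransformedDistribution pushes base samples through the transform (base to data direction), which may not be the direction zuko's Flow uses for its transform, so this illustrates the structure only.

```python
import torch
import torch.nn as nn
from torch.distributions import Distribution, Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, Transform


class LazyStandardNormal(nn.Module):
    """Hypothetical lazy base distribution N(0, I) that ignores the context."""

    def __init__(self, features: int):
        super().__init__()
        # Buffers follow the module across devices but are not trained.
        self.register_buffer("loc", torch.zeros(features))
        self.register_buffer("scale", torch.ones(features))

    def forward(self, c=None) -> Distribution:
        return Independent(Normal(self.loc, self.scale), 1)


class LazyAffine(nn.Module):
    """Hypothetical lazy transform: y = loc(c) + scale(c) * z."""

    def __init__(self, features: int, context: int):
        super().__init__()
        self.net = nn.Linear(context, 2 * features)

    def forward(self, c: torch.Tensor) -> Transform:
        loc, log_scale = self.net(c).chunk(2, dim=-1)
        return AffineTransform(loc, log_scale.exp(), event_dim=1)


class FlowSketch(nn.Module):
    """Hypothetical lazy flow: forward(c) builds p(X | c) from a base and a transform."""

    def __init__(self, transform: nn.Module, base: nn.Module):
        super().__init__()
        self.transform = transform
        self.base = base

    def forward(self, c: torch.Tensor) -> Distribution:
        # Base samples z ~ base(c) are mapped to x = transform(c)(z).
        return TransformedDistribution(self.base(c), [self.transform(c)])


torch.manual_seed(0)
flow = FlowSketch(LazyAffine(3, 5), LazyStandardNormal(3))
c = torch.randn(5)
p = flow(c)
x = p.sample((4,))
print(x.shape, p.log_prob(x).shape)  # torch.Size([4, 3]) torch.Size([4])
```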
- class zuko.flows.core.Unconditional(meta, *args, buffer=False, **kwargs)#
Creates an unconditional lazy module from a constructor.
Typically, the constructor returns a distribution or a transformation. The positional arguments of the constructor are registered as buffers (if buffer=True) or as parameters (otherwise).
- Parameters:
meta (Callable[..., Any]) – An arbitrary constructor function.
args (Tensor) – The positional tensor arguments passed to meta.
buffer (bool) – Whether tensors are registered as buffers or parameters.
kwargs – The keyword arguments passed to meta.
Examples
>>> mu, sigma = torch.zeros(3), torch.ones(3)
>>> d = Unconditional(DiagNormal, mu, sigma, buffer=True)
>>> d()
DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> d().sample()
tensor([-0.6687, -0.9690, 1.7461])

>>> t = Unconditional(ExpTransform)
>>> t()
ExpTransform()
>>> x = torch.randn(3)
>>> t()(x)
tensor([0.5523, 0.7997, 0.9189])
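The behaviour described for Unconditional can be sketched in plain PyTorch as follows. UnconditionalSketch and diag_normal are hypothetical names (diag_normal stands in for a diagonal Gaussian constructor), and the buffer/parameter naming scheme is an assumption of this sketch, not zuko's code: positional tensors are stored as buffers or parameters, and the forward pass ignores the context and calls the constructor with them.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from torch.distributions.transforms import ExpTransform


class UnconditionalSketch(nn.Module):
    """Hypothetical re-creation of the described behaviour, not zuko's code."""

    def __init__(self, meta, *args: torch.Tensor, buffer: bool = False, **kwargs):
        super().__init__()
        self.meta = meta
        self.kwargs = kwargs
        self.n_args = len(args)
        for i, arg in enumerate(args):
            if buffer:
                self.register_buffer(f"_{i}", arg)  # moved with the module, never trained
            else:
                self.register_parameter(f"_{i}", nn.Parameter(arg))  # trainable

    def forward(self, c=None):
        # The context c is ignored: the module is unconditional.
        args = [getattr(self, f"_{i}") for i in range(self.n_args)]
        return self.meta(*args, **self.kwargs)


def diag_normal(loc, scale):
    # Hypothetical stand-in for a diagonal Gaussian constructor.
    return Independent(Normal(loc, scale), 1)


mu, sigma = torch.zeros(3), torch.ones(3)
d = UnconditionalSketch(diag_normal, mu, sigma, buffer=True)
print(d().sample().shape)  # torch.Size([3])
print(len(list(d.parameters())), len(list(d.buffers())))  # 0 2

learnable = UnconditionalSketch(diag_normal, torch.zeros(3), torch.ones(3))
print(len(list(learnable.parameters())))  # 2

t = UnconditionalSketch(ExpTransform)
x = torch.randn(3)
print(torch.allclose(t()(x), x.exp()))  # True
```

Registering the tensors as buffers keeps them fixed while still letting .to(device) move them; registering them as parameters exposes them through .parameters(), so an optimizer updates them.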