zuko.flows.core

Core building blocks.

Classes

LazyDistribution

Abstract lazy distribution.

LazyTransform

Abstract lazy transformation.

LazyComposedTransform

Creates a lazy composed transformation.

Flow

Creates a lazy normalizing flow.

Unconditional

Creates an unconditional lazy module from a constructor.

Descriptions

class zuko.flows.core.LazyDistribution(*args, **kwargs)

Abstract lazy distribution.

A lazy distribution is a module that builds and returns a distribution \(p(X | c)\) within its forward pass, given a context \(c\).

abstract forward(c=None)
Parameters:

c (Any | None) – A context \(c\).

Returns:

A distribution \(p(X | c)\).

Return type:

Distribution

class zuko.flows.core.LazyTransform(*args, **kwargs)

Abstract lazy transformation.

A lazy transformation is a module that builds and returns a transformation \(y = f(x | c)\) within its forward pass, given a context \(c\).

abstract forward(c=None)
Parameters:

c (Any | None) – A context \(c\).

Returns:

A transformation \(y = f(x | c)\).

Return type:

Transform

property inv: LazyTransform

A lazy inverse transformation \(x = f^{-1}(y | c)\).

class zuko.flows.core.LazyComposedTransform(*transforms)

Creates a lazy composed transformation.

Parameters:

transforms (LazyTransform) – A sequence of lazy transformations \(f_i\).

forward(c=None)
Parameters:

c (Any | None) – A context \(c\).

Returns:

A transformation \(y = f_n \circ \dots \circ f_0(x | c)\).

Return type:

Transform

class zuko.flows.core.Flow(transform, base)

Creates a lazy normalizing flow.

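Parameters:
  • transform (LazyTransform | Sequence[LazyTransform]) – A lazy transformation or sequence of lazy transformations.

  • base (LazyDistribution) – A lazy distribution.

forward(c=None)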
Parameters:

c (Tensor | None) – A context \(c\).

Returns:

A normalizing flow \(p(X | c)\).

Return type:

NormalizingFlow

class zuko.flows.core.Unconditional(meta, *args, buffer=False, **kwargs)

Creates an unconditional lazy module from a constructor.

Typically, the constructor returns a distribution or transformation. The positional arguments of the constructor are registered as buffers or parameters.

Parameters:
  • meta (Callable[..., Any]) – An arbitrary constructor function.

  • args (Tensor) – The positional tensor arguments passed to meta.

  • buffer (bool) – Whether the tensors are registered as buffers (True) or as parameters (False, the default).

  • kwargs – The keyword arguments passed to meta.

Examples

>>> import torch
>>> from torch.distributions.transforms import ExpTransform
>>> from zuko.distributions import DiagNormal
>>> from zuko.flows.core import Unconditional
>>> mu, sigma = torch.zeros(3), torch.ones(3)
>>> d = Unconditional(DiagNormal, mu, sigma, buffer=True)
>>> d()
DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> d().sample()
tensor([-0.6687, -0.9690,  1.7461])
>>> t = Unconditional(ExpTransform)
>>> t()
ExpTransform()
>>> x = torch.randn(3)
>>> t()(x)
tensor([0.5523, 0.7997, 0.9189])

forward(c=None)
Parameters:

c (Tensor | None) – A context \(c\). This argument is always ignored.

Returns:

meta(*args, **kwargs)

Return type:

Any