zuko.lazy

Lazy distributions and transformations.

Classes

Flow

Creates a lazy normalizing flow.

LazyComposedTransform

Creates a lazy composed transformation.

LazyDistribution

Abstract lazy distribution.

LazyTransform

Abstract lazy transformation.

Unconditional

Creates an unconditional lazy module from a constructor.

UnconditionalDistribution

Creates an unconditional lazy distribution from a constructor.

UnconditionalTransform

Creates an unconditional lazy transformation from a constructor.

Descriptions

class zuko.lazy.Flow(transform, base)

Creates a lazy normalizing flow.

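Parameters:
  • transform (LazyTransform | Sequence[LazyTransform]) – A lazy transformation or sequence of lazy transformations.

  • base (LazyDistribution) – A lazy distribution.
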
forward(c=None)
Parameters:

c (Tensor | None) – A context \(c\).

Returns:

A normalizing flow \(p(X | c)\).

Return type:

NormalizingFlow

class zuko.lazy.LazyComposedTransform(*transforms)

Creates a lazy composed transformation.

Parameters:

transforms (LazyTransform) – A sequence of lazy transformations \(f_i\).

forward(c=None)
Parameters:

c (Any | None) – A context \(c\).

Returns:

A transformation \(y = f_n \circ \dots \circ f_0(x | c)\).

Return type:

Transform

class zuko.lazy.LazyDistribution(*args, **kwargs)

Abstract lazy distribution.

A lazy distribution is a module that builds and returns a distribution \(p(X | c)\) within its forward pass, given a context \(c\).

abstract forward(c=None)
Parameters:

c (Any | None) – A context \(c\).

Returns:

A distribution \(p(X | c)\).

Return type:

Distribution

class zuko.lazy.LazyTransform(*args, **kwargs)

Abstract lazy transformation.

A lazy transformation is a module that builds and returns a transformation \(y = f(x | c)\) within its forward pass, given a context \(c\).

abstract forward(c=None)
Parameters:

c (Any | None) – A context \(c\).

Returns:

A transformation \(y = f(x | c)\).

Return type:

Transform

property inv: LazyTransform

A lazy inverse transformation \(x = f^{-1}(y | c)\).

class zuko.lazy.Unconditional(meta, *args, buffer=False, **kwargs)

Creates an unconditional lazy module from a constructor.

Typically, the constructor returns a distribution or transformation. The positional arguments of the constructor are registered as buffers or parameters.

Warning

Unconditional is deprecated and will be removed in the future. Use UnconditionalDistribution or UnconditionalTransform instead.

Parameters:
  • meta (Callable[..., Any]) – An arbitrary constructor function.

  • args (Tensor) – The positional tensor arguments passed to meta.

  • buffer (bool) – Whether tensors are registered as buffers or parameters.

  • kwargs – The keyword arguments passed to meta.

forward(c=None)
Parameters:

c (Tensor | None) – A context \(c\). This argument is always ignored.

Returns:

meta(*args, **kwargs)

Return type:

Any

class zuko.lazy.UnconditionalDistribution(f, *args, buffer=False, **kwargs)

Creates an unconditional lazy distribution from a constructor.

The tensor arguments of the constructor are registered as buffers or parameters.

Parameters:
  • f (Callable[..., Distribution]) – A distribution constructor. If f is a module, it is registered as a submodule.

  • args (Any) – The positional arguments passed to f.

  • buffer (bool) – Whether tensor arguments are registered as buffers or parameters.

  • kwargs (Any) – The keyword arguments passed to f.

Examples

>>> import torch
>>> import zuko
>>> from zuko.lazy import UnconditionalDistribution
>>> f = zuko.distributions.DiagNormal
>>> mu, sigma = torch.zeros(3), torch.ones(3)
>>> base = UnconditionalDistribution(f, mu, sigma, buffer=True)
>>> base()
DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> base().sample()
tensor([ 1.5410, -0.2934, -2.1788])

forward(c=None)
Parameters:

c (Tensor | None) – A context \(c\). This argument is always ignored.

Returns:

self.f(*self.args, **self.kwargs)

Return type:

Distribution

class zuko.lazy.UnconditionalTransform(f, *args, buffer=False, **kwargs)

Creates an unconditional lazy transformation from a constructor.

The tensor arguments of the constructor are registered as buffers or parameters.

Parameters:
  • f (Callable[..., Transform]) – A transformation constructor. If f is a module, it is registered as a submodule.

  • args (Any) – The positional arguments passed to f.

  • buffer (bool) – Whether tensor arguments are registered as buffers or parameters.

  • kwargs (Any) – The keyword arguments passed to f.

Examples

>>> import torch
>>> import zuko
>>> from zuko.lazy import UnconditionalTransform
>>> f = zuko.transforms.ExpTransform
>>> t = UnconditionalTransform(f)
>>> t()
ExpTransform()
>>> x = torch.randn(3)
>>> t()(x)
tensor([4.6692, 0.7457, 0.1132])

forward(c=None)
Parameters:

c (Tensor | None) – A context \(c\). This argument is always ignored.

Returns:

self.f(*self.args, **self.kwargs)

Return type:

Transform