zuko.distributions

Parameterizable probability distributions.

Classes

NormalizingFlow

Creates a normalizing flow for a random variable \(X\) towards a base distribution \(p(Z)\) through a transformation \(f\).

Joint

Creates a distribution for a multivariate random variable \(X\) which is the concatenation of \(n\) independent random variables \(Z_i\).

Mixture

Creates a mixture of distributions for a random variable \(X\).

GeneralizedNormal

Creates a generalized normal distribution.

DiagNormal

Creates a multivariate normal distribution parameterized by the mean \(\mu\) and standard deviation \(\sigma\) of its variables, assuming no correlation between them (diagonal covariance).

BoxUniform

Creates a distribution for a multivariate random variable \(X\) distributed uniformly over a box-shaped (hyperrectangular) domain.

TransformedUniform

Creates a distribution for a random variable \(X\), whose transformation \(f(X)\) is uniformly distributed over the interval \([f(l), f(u)]\).

Truncated

Truncates a base distribution \(p(X)\) between a lower bound \(l\) and an upper bound \(u\).

Sort

Creates a distribution for an \(n\)-d random variable \(X\), whose elements \(X_i\) are \(n\) draws from a base distribution \(p(Z)\), ordered such that \(X_i \leq X_{i + 1}\).

TopK

Creates a distribution for a \(k\)-d random variable \(X\), whose elements \(X_i\) are the top \(k\) among \(n\) draws from a base distribution \(p(Z)\), ordered such that \(X_i \leq X_{i + 1}\).

Minimum

Creates a distribution for a random variable \(X\), which is the minimum among \(n\) draws from a base distribution \(p(Z)\).

Maximum

Creates a distribution for a random variable \(X\), which is the maximum among \(n\) draws from a base distribution \(p(Z)\).

Descriptions

class zuko.distributions.NormalizingFlow(transform, base)

Creates a normalizing flow for a random variable \(X\) towards a base distribution \(p(Z)\) through a transformation \(f\).

The density of a realization \(x\) is given by the change of variables

\[p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right| .\]

To sample from \(p(X)\), realizations \(z \sim p(Z)\) are mapped through the inverse transformation \(g = f^{-1}\).

References

A Family of Non-parametric Density Estimation Algorithms (Tabak et al., 2013)
Variational Inference with Normalizing Flows (Rezende et al., 2015)
Normalizing Flows for Probabilistic Modeling and Inference (Papamakarios et al., 2021)
Parameters:
  • transform (Transform) – A transformation \(f\).

  • base (Distribution) – A base distribution \(p(Z)\).

Example

>>> d = NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))
>>> d.sample()
tensor(1.1316)
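The change of variables can be checked by hand on the example above. The sketch below is plain Python rather than zuko or PyTorch; it assumes \(f = \exp\) and a Gamma base with concentration 2 and rate 1 (density \(z e^{-z}\)), and its helper names are hypothetical. It evaluates \(\log p(X = x)\), confirms numerically that the density integrates to about 1, and samples by mapping base draws through \(g = f^{-1} = \log\).

```python
import math
import random


def base_log_prob(z: float) -> float:
    """log p(Z = z) for Gamma(concentration=2, rate=1), whose density is z * exp(-z)."""
    return math.log(z) - z


def flow_log_prob(x: float) -> float:
    """log p(X = x) = log p(Z = f(x)) + log |df(x)/dx|, with f = exp."""
    z = math.exp(x)   # f(x)
    log_abs_det = x   # log |d exp(x) / dx| = log exp(x) = x
    return base_log_prob(z) + log_abs_det


def flow_sample(rng: random.Random) -> float:
    """Draw z ~ p(Z), then map it through the inverse g = f^{-1} = log."""
    z = rng.gammavariate(2.0, 1.0)  # shape 2, scale 1 (i.e. rate 1)
    return math.log(z)


# Riemann sum over a wide grid: the density should integrate to about 1.
step = 1e-3
total = sum(math.exp(flow_log_prob(-10.0 + i * step)) * step for i in range(13_000))

rng = random.Random(0)
samples = [flow_sample(rng) for _ in range(20_000)]
```

For this base, \(E[X] = E[\log Z] = \psi(2) = 1 - \gamma \approx 0.423\), which the sample mean should approach.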

class zuko.distributions.Joint(*marginals)

Creates a distribution for a multivariate random variable \(X\) which is the concatenation of \(n\) independent random variables \(Z_i\).

\[p(X = x) = \prod_i p(Z_i = x_i)\]

Parameters:

marginals (Distribution) – The distributions \(p(Z_i)\), passed as separate positional arguments.

Example

>>> d = Joint(Uniform(0.0, 1.0), Normal(0.0, 1.0))
>>> d.event_shape
torch.Size([2])
>>> d.sample()
tensor([ 0.8969, -2.6717])
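Independence turns the product into a sum of log-densities. A plain-Python sketch for the two marginals of the example; it assumes every \(Z_i\) is scalar-valued, and the helpers are hypothetical rather than part of zuko.

```python
import math
from typing import Callable, Sequence

LogProb = Callable[[float], float]


def uniform_log_prob(x: float, low: float = 0.0, high: float = 1.0) -> float:
    """log-density of Uniform(low, high); -inf outside [low, high)."""
    return -math.log(high - low) if low <= x < high else -math.inf


def normal_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """log-density of Normal(loc, scale)."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def joint_log_prob(x: Sequence[float], marginals: Sequence[LogProb]) -> float:
    """log p(X = x) = sum_i log p(Z_i = x_i), since the Z_i are independent."""
    assert len(x) == len(marginals)
    return sum(log_p(x_i) for log_p, x_i in zip(marginals, x))


marginals = [uniform_log_prob, normal_log_prob]
inside = joint_log_prob([0.8969, -2.6717], marginals)
outside = joint_log_prob([1.5, 0.0], marginals)  # x_1 lies outside [0, 1)
```

A single marginal with zero density is enough to make the joint density zero, which shows up here as a log-density of `-inf`.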

class zuko.distributions.Mixture(base, logits)

Creates a mixture of distributions for a random variable \(X\).

\[p(X = x) = \frac{1}{\sum_i w_i} \sum_i w_i \, p(Z_i = x)\]

Wikipedia

https://wikipedia.org/wiki/Mixture_model

Parameters:
  • base (Distribution) – A batch of base distributions \(p(Z_i)\).

  • logits (Tensor) – The unnormalized log-weights \(\log w_i\).

Example

>>> d = Mixture(Normal(torch.randn(2), torch.ones(2)), torch.randn(2))
>>> d.event_shape
torch.Size([])
>>> d.sample()
tensor(2.8732)
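In terms of the logits, the mixture density can be written \(\log p(X = x) = \operatorname{logsumexp}_i \left( \log w_i + \log p(Z_i = x) \right) - \operatorname{logsumexp}_i \log w_i\), which avoids overflow when the logits are large. The plain-Python sketch below assumes normal components; the function names are hypothetical and not zuko's internals.

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def mixture_log_prob(
    x: float, locs: Sequence[float], scales: Sequence[float], logits: Sequence[float]
) -> float:
    """log p(X = x) = logsumexp(log w_i + log p(Z_i = x)) - logsumexp(log w_i)."""
    terms = [l + normal_log_prob(x, m, s) for l, m, s in zip(logits, locs, scales)]
    return logsumexp(terms) - logsumexp(logits)


locs, scales, logits = [-1.0, 2.0], [0.5, 1.5], [0.3, -0.7]
step = 1e-3
total = sum(
    math.exp(mixture_log_prob(-15.0 + i * step, locs, scales, logits)) * step
    for i in range(30_000)
)
```

Subtracting \(\operatorname{logsumexp}_i \log w_i\) is what implements the normalization \(1 / \sum_i w_i\), so the logits only need to be known up to an additive constant.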

class zuko.distributions.GeneralizedNormal(beta)

Creates a generalized normal distribution with zero location and unit scale, whose shape is controlled by \(\beta\).

\[p(X = x) = \frac{\beta}{2 \Gamma(1 / \beta)} \exp(-|x|^\beta)\]

Wikipedia

https://wikipedia.org/wiki/Generalized_normal_distribution

Parameters:

beta (Tensor) – The shape parameter \(\beta\).

Example

>>> d = GeneralizedNormal(2.0)
>>> d.sample()
tensor(0.7480)
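Two special cases help read the formula: \(\beta = 1\) gives a Laplace distribution with unit scale, and \(\beta = 2\) gives a normal distribution with variance \(1/2\). The plain-Python sketch below checks both, and draws samples using the fact that \(|X|^\beta\) follows a Gamma distribution with shape \(1/\beta\) and unit scale. This sampler is one standard approach, not necessarily the one zuko uses, and the helper names are hypothetical.

```python
import math
import random


def gen_normal_log_prob(x: float, beta: float) -> float:
    """log p(X = x) = log(beta) - log(2) - log Gamma(1 / beta) - |x|^beta."""
    return math.log(beta) - math.log(2.0) - math.lgamma(1.0 / beta) - abs(x) ** beta


def gen_normal_sample(rng: random.Random, beta: float) -> float:
    """|X|^beta ~ Gamma(shape=1/beta, scale=1), and the sign of X is a fair coin flip."""
    magnitude = rng.gammavariate(1.0 / beta, 1.0) ** (1.0 / beta)
    return magnitude if rng.random() < 0.5 else -magnitude


x = 0.7
laplace = -math.log(2.0) - abs(x)                                  # Laplace(0, 1)
normal = -x**2 / (2 * 0.5) - 0.5 * math.log(2 * math.pi * 0.5)     # Normal with variance 1/2

rng = random.Random(0)
samples = [gen_normal_sample(rng, 2.0) for _ in range(20_000)]
variance = sum(s * s for s in samples) / len(samples)  # the mean is 0 by symmetry
```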

class zuko.distributions.DiagNormal(loc, scale, ndims=1)

Creates a multivariate normal distribution parameterized by the mean \(\mu\) and standard deviation \(\sigma\) of its variables, assuming no correlation between them (diagonal covariance).

Parameters:
  • loc (Tensor) – The mean \(\mu\) of the variables.

  • scale (Tensor) – The standard deviation \(\sigma\) of the variables.

  • ndims (int) – The number of batch dimensions to interpret as event dimensions.

Example

>>> d = DiagNormal(torch.zeros(3), torch.ones(3))
>>> d.event_shape
torch.Size([3])
>>> d.sample()
tensor([ 0.7304, -0.1976, -1.7591])
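Because the variables are uncorrelated, the density factorizes over dimensions: \(\log p(X = x) = \sum_i \left[ -\frac{1}{2} \left( \frac{x_i - \mu_i}{\sigma_i} \right)^2 - \log \sigma_i - \frac{1}{2} \log 2 \pi \right]\). The plain-Python sketch below evaluates this sum and samples by reparameterization, \(x_i = \mu_i + \sigma_i \epsilon_i\). It assumes flat lists for \(\mu\) and \(\sigma\) and ignores batch dimensions, so `ndims` has no counterpart here; the function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def diag_normal_log_prob(x: Sequence[float], loc: Sequence[float], scale: Sequence[float]) -> float:
    """Sum of independent univariate normal log-densities, one per dimension."""
    return sum(
        -0.5 * ((xi - mi) / si) ** 2 - math.log(si) - 0.5 * math.log(2 * math.pi)
        for xi, mi, si in zip(x, loc, scale)
    )


def diag_normal_sample(rng: random.Random, loc: Sequence[float], scale: Sequence[float]) -> List[float]:
    """Reparameterized draw: x_i = mu_i + sigma_i * eps_i with eps_i ~ N(0, 1)."""
    return [mi + si * rng.gauss(0.0, 1.0) for mi, si in zip(loc, scale)]


loc, scale = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
at_mean = diag_normal_log_prob(loc, loc, scale)  # -(3/2) log(2 pi) up to rounding
x = diag_normal_sample(random.Random(0), loc, scale)
```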

class zuko.distributions.BoxUniform(lower, upper, ndims=1)

Creates a distribution for a multivariate random variable \(X\) distributed uniformly over a box-shaped (hyperrectangular) domain. Formally,

\[l_i \leq X_i < u_i ,\]

where \(l_i\) and \(u_i\) are respectively the lower and upper bounds of the domain in the \(i\)-th dimension.

Parameters:
  • lower (Tensor) – The lower bounds (inclusive).

  • upper (Tensor) – The upper bounds (exclusive).

  • ndims (int) – The number of batch dimensions to interpret as event dimensions.

Example

>>> d = BoxUniform(-torch.ones(3), torch.ones(3))
>>> d.event_shape
torch.Size([3])
>>> d.sample()
tensor([ 0.1859, -0.9698,  0.0665])
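Inside the box the density is constant, \(p(X = x) = \prod_i (u_i - l_i)^{-1}\), and it is zero outside. A plain-Python sketch of the log-density and of sampling, assuming flat lists for the bounds (no batch dimensions) and hypothetical function names:

```python
import math
import random
from typing import List, Sequence


def box_uniform_log_prob(x: Sequence[float], lower: Sequence[float], upper: Sequence[float]) -> float:
    """-sum_i log(u_i - l_i) if l_i <= x_i < u_i for every i, else -inf."""
    if all(l <= xi < u for xi, l, u in zip(x, lower, upper)):
        return -sum(math.log(u - l) for l, u in zip(lower, upper))
    return -math.inf


def box_uniform_sample(rng: random.Random, lower: Sequence[float], upper: Sequence[float]) -> List[float]:
    """x_i = l_i + (u_i - l_i) * U_i, with U_i uniform in [0, 1)."""
    return [l + (u - l) * rng.random() for l, u in zip(lower, upper)]


lower, upper = [-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]
inside = box_uniform_log_prob([0.1859, -0.9698, 0.0665], lower, upper)  # -3 log 2
on_upper = box_uniform_log_prob([0.0, 1.0, 0.0], lower, upper)           # upper bound is excluded
rng = random.Random(0)
samples = [box_uniform_sample(rng, lower, upper) for _ in range(100)]
```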

class zuko.distributions.TransformedUniform(f, lower, upper)

Creates a distribution for a random variable \(X\), whose transformation \(f(X)\) is uniformly distributed over the interval \([f(l), f(u)]\).

\[\begin{split}p(X = x) = \frac{1}{f(u) - f(l)} \begin{cases} f'(x) & \text{if } f(l) \leq f(x) < f(u) \\ 0 & \text{otherwise} \end{cases}\end{split}\]
Parameters:
  • f (Transform) – A transformation \(f\), monotonically increasing over \([l, u]\).

  • lower (Tensor) – A lower bound \(l\) (inclusive).

  • upper (Tensor) – An upper bound \(u\) (exclusive).

Example

>>> d = TransformedUniform(ExpTransform(), -1.0, 1.0)
>>> d.sample()
tensor(0.5594)
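Since \(f\) is increasing, \(X\) can be sampled by drawing \(v\) uniformly in \([f(l), f(u))\) and returning \(f^{-1}(v)\). The plain-Python sketch below does this for the example above (\(f = \exp\), \(l = -1\), \(u = 1\)), where \(f'(x) = e^x\), and checks that the density integrates to about 1; the names are hypothetical.

```python
import math
import random

LOWER, UPPER = -1.0, 1.0
F_L, F_U = math.exp(LOWER), math.exp(UPPER)  # f(l) and f(u) with f = exp


def transformed_uniform_log_prob(x: float) -> float:
    """log p(X = x) = log f'(x) - log(f(u) - f(l)) on [l, u), -inf elsewhere."""
    if not LOWER <= x < UPPER:  # equivalent to f(l) <= f(x) < f(u), since f is increasing
        return -math.inf
    return x - math.log(F_U - F_L)  # log f'(x) = log exp(x) = x


def transformed_uniform_sample(rng: random.Random) -> float:
    """Draw v uniformly in [f(l), f(u)), then invert f."""
    v = F_L + (F_U - F_L) * rng.random()
    return math.log(v)


step = 1e-4
total = sum(math.exp(transformed_uniform_log_prob(LOWER + i * step)) * step for i in range(20_000))
rng = random.Random(0)
samples = [transformed_uniform_sample(rng) for _ in range(1_000)]
```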

class zuko.distributions.Truncated(base, lower=-inf, upper=inf)

Truncates a base distribution \(p(X)\) between a lower bound \(l\) and an upper bound \(u\).

\[\begin{split}p(X = x | l \leq X < u) = \frac{1}{P(X \leq u) - P(X \leq l)} \begin{cases} p(X = x) & \text{if } l \leq x < u \\ 0 & \text{otherwise} \end{cases}\end{split}\]
Parameters:
  • base (Distribution) – A base distribution \(p(X)\).

  • lower (Tensor) – A lower bound \(l\) (inclusive).

  • upper (Tensor) – An upper bound \(u\) (exclusive).

Example

>>> d = Truncated(Normal(0.0, 1.0), 1.0, 2.0)
>>> d.sample()
tensor(1.2573)
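When the base distribution has a tractable CDF \(F\) and inverse CDF, a common way to sample the truncated distribution is inverse-CDF sampling: draw \(v\) uniformly in \([F(l), F(u))\) and return \(F^{-1}(v)\). The sketch below applies this to the example above with Python's `statistics.NormalDist` (which provides `pdf`, `cdf` and `inv_cdf`); it illustrates the formula and is not necessarily how zuko implements `Truncated`.

```python
import random
from statistics import NormalDist

base = NormalDist(mu=0.0, sigma=1.0)
lower, upper = 1.0, 2.0
mass = base.cdf(upper) - base.cdf(lower)  # P(X <= u) - P(X <= l)


def truncated_pdf(x: float) -> float:
    """p(X = x | l <= X < u): the base density renormalized on [l, u)."""
    return base.pdf(x) / mass if lower <= x < upper else 0.0


def truncated_sample(rng: random.Random) -> float:
    """Inverse-CDF sampling restricted to [F(l), F(u))."""
    v = base.cdf(lower) + mass * rng.random()
    return base.inv_cdf(v)


step = 1e-4
total = sum(truncated_pdf(lower + i * step) * step for i in range(10_000))
rng = random.Random(0)
samples = [truncated_sample(rng) for _ in range(1_000)]
```

Unlike rejection sampling, this never discards draws, which matters when the interval \([l, u)\) holds little of the base mass (here about 14%).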

class zuko.distributions.Sort(base, n=2, descending=False)

Creates a distribution for an \(n\)-d random variable \(X\), whose elements \(X_i\) are \(n\) draws from a base distribution \(p(Z)\), ordered such that \(X_i \leq X_{i + 1}\).

\[\begin{split}p(X = x) = \begin{cases} n! \, \prod_{i = 1}^n p(Z = x_i) & \text{if $x$ is ordered} \\ 0 & \text{otherwise} \end{cases}\end{split}\]
Parameters:
  • base (Distribution) – A base distribution \(p(Z)\).

  • n (int) – The number of draws \(n\).

  • descending (bool) – Whether the elements are sorted in descending order or not.

Example

>>> d = Sort(Normal(0.0, 1.0), 3)
>>> d.event_shape
torch.Size([3])
>>> d.sample()
tensor([-1.4434, -0.3861,  0.2439])
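The factor \(n!\) counts the orderings of the \(n\) draws that all map to the same sorted vector. The plain-Python sketch below evaluates the log-density for a standard normal base and samples by sorting \(n\) independent draws; the `descending` flag mirrors the parameter above by flipping the order check, and the function names are hypothetical.

```python
import math
import random
from statistics import NormalDist
from typing import List, Sequence

base = NormalDist(0.0, 1.0)


def is_ordered(x: Sequence[float], descending: bool = False) -> bool:
    pairs = zip(x, x[1:])
    return all(a >= b for a, b in pairs) if descending else all(a <= b for a, b in pairs)


def sort_log_prob(x: Sequence[float], descending: bool = False) -> float:
    """log n! + sum_i log p(Z = x_i) if x is ordered, -inf otherwise."""
    if not is_ordered(x, descending):
        return -math.inf
    n = len(x)
    return math.lgamma(n + 1) + sum(math.log(base.pdf(xi)) for xi in x)


def sort_sample(rng: random.Random, n: int, descending: bool = False) -> List[float]:
    """Draw n independent values from the base and sort them."""
    return sorted((rng.gauss(0.0, 1.0) for _ in range(n)), reverse=descending)


x = [-1.4434, -0.3861, 0.2439]
ordered = sort_log_prob(x)
unordered = sort_log_prob(x[::-1])
sample = sort_sample(random.Random(0), 3)
```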

class zuko.distributions.TopK(base, k=1, n=2, **kwargs)

Creates a distribution for a \(k\)-d random variable \(X\), whose elements \(X_i\) are the top \(k\) among \(n\) draws from a base distribution \(p(Z)\), ordered such that \(X_i \leq X_{i + 1}\). With this ascending order, the top \(k\) are the \(k\) smallest draws.

\[\begin{split}p(X = x) = \begin{cases} \frac{n!}{(n - k)!} \, \prod_{i = 1}^k p(Z = x_i) \, P(Z \geq x_k)^{n - k} & \text{if $x$ is ordered} \\ 0 & \text{otherwise} \end{cases}\end{split}\]
Parameters:
  • base (Distribution) – A base distribution \(p(Z)\).

  • k (int) – The number of selected elements \(k\).

  • n (int) – The number of draws \(n\).

  • kwargs – Keyword arguments passed to Sort.

Example

>>> d = TopK(Normal(0.0, 1.0), 2, 3)
>>> d.event_shape
torch.Size([2])
>>> d.sample()
tensor([-0.2167,  0.6739])
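In the ascending case, the \(n - k\) discarded draws must all lie at or above \(x_k\), which gives the tail factor \(P(Z \geq x_k)^{n - k}\), while \(n! / (n - k)!\) counts the ordered ways of choosing which draws are kept. The plain-Python sketch below evaluates this density for a standard normal base (ascending order only) and checks that \(k = 1\) recovers the density of the minimum, \(n \, p(Z = x) \, P(Z \geq x)^{n - 1}\). The function names are hypothetical.

```python
import math
import random
from statistics import NormalDist
from typing import List, Sequence

base = NormalDist(0.0, 1.0)


def topk_log_prob(x: Sequence[float], n: int) -> float:
    """log p(X = x) for the k = len(x) smallest of n draws, sorted in ascending order."""
    k = len(x)
    if any(a > b for a, b in zip(x, x[1:])):
        return -math.inf  # not in ascending order
    log_count = math.lgamma(n + 1) - math.lgamma(n - k + 1)  # log(n! / (n - k)!)
    log_kept = sum(math.log(base.pdf(xi)) for xi in x)
    log_tail = (n - k) * math.log(1.0 - base.cdf(x[-1]))  # log P(Z >= x_k)^(n - k)
    return log_count + log_kept + log_tail


def topk_sample(rng: random.Random, k: int, n: int) -> List[float]:
    """Draw n values, sort them in ascending order and keep the first k."""
    return sorted(rng.gauss(0.0, 1.0) for _ in range(n))[:k]


x = 0.3
minimum = math.log(3) + math.log(base.pdf(x)) + 2 * math.log(1.0 - base.cdf(x))
sample = topk_sample(random.Random(0), k=2, n=3)
```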

class zuko.distributions.Minimum(base, n=2)

Creates a distribution for a random variable \(X\), which is the minimum among \(n\) draws from a base distribution \(p(Z)\).

\[p(X = x) = n \, p(Z = x) \, P(Z \geq x)^{n - 1}\]
Parameters:
  • base (Distribution) – A base distribution \(p(Z)\).

  • n (int) – The number of draws \(n\).

Example

>>> d = Minimum(Normal(0.0, 1.0), 3)
>>> d.event_shape
torch.Size([])
>>> d.sample()
tensor(-1.7552)
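The formula follows from the CDF: the minimum exceeds \(x\) only if all \(n\) draws do, so \(P(X \leq x) = 1 - P(Z \geq x)^n\), and differentiating with respect to \(x\) gives the density above. A plain-Python check with a standard normal base and \(n = 3\), for which the expected minimum is known to be \(-3 / (2 \sqrt{\pi}) \approx -0.846\); the function names are hypothetical.

```python
import math
import random
from statistics import NormalDist

base = NormalDist(0.0, 1.0)
n = 3


def minimum_pdf(x: float) -> float:
    """p(X = x) = n * p(Z = x) * P(Z >= x)^(n - 1)."""
    return n * base.pdf(x) * (1.0 - base.cdf(x)) ** (n - 1)


def minimum_sample(rng: random.Random) -> float:
    """The smallest of n independent draws from the base."""
    return min(rng.gauss(0.0, 1.0) for _ in range(n))


step = 1e-3
grid = [-8.0 + i * step for i in range(16_000)]
total = sum(minimum_pdf(x) * step for x in grid)
mean = sum(x * minimum_pdf(x) * step for x in grid)  # numerical E[X]
expected_mean = -3 / (2 * math.sqrt(math.pi))
rng = random.Random(0)
mc_mean = sum(minimum_sample(rng) for _ in range(20_000)) / 20_000
```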

class zuko.distributions.Maximum(base, n=2)

Creates a distribution for a random variable \(X\), which is the maximum among \(n\) draws from a base distribution \(p(Z)\).

\[p(X = x) = n \, p(Z = x) \, P(Z \leq x)^{n - 1}\]
Parameters:
  • base (Distribution) – A base distribution \(p(Z)\).

  • n (int) – The number of draws \(n\).

Example

>>> d = Maximum(Normal(0.0, 1.0), 3)
>>> d.event_shape
torch.Size([])
>>> d.sample()
tensor(1.1644)
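Symmetrically, the maximum is at most \(x\) only if every draw is, so \(P(X \leq x) = P(Z \leq x)^n\), whose derivative is the density above. The plain-Python sketch below checks it numerically for a standard normal base and \(n = 3\), where the expected maximum is known to be \(3 / (2 \sqrt{\pi}) \approx 0.846\); the function names are hypothetical.

```python
import math
import random
from statistics import NormalDist

base = NormalDist(0.0, 1.0)
n = 3


def maximum_pdf(x: float) -> float:
    """p(X = x) = n * p(Z = x) * P(Z <= x)^(n - 1)."""
    return n * base.pdf(x) * base.cdf(x) ** (n - 1)


def maximum_sample(rng: random.Random) -> float:
    """The largest of n independent draws from the base."""
    return max(rng.gauss(0.0, 1.0) for _ in range(n))


step = 1e-3
grid = [-8.0 + i * step for i in range(16_000)]
total = sum(maximum_pdf(x) * step for x in grid)
mean = sum(x * maximum_pdf(x) * step for x in grid)  # numerical E[X]
expected_mean = 3 / (2 * math.sqrt(math.pi))
rng = random.Random(0)
mc_mean = sum(maximum_sample(rng) for _ in range(20_000)) / 20_000
```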