zuko.utils

General purpose helpers.

Functions

bisection

Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).

broadcast

Broadcasts tensors together.

gauss_legendre

Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.

odeint

Integrates a system of first-order ordinary differential equations (ODEs).

unpack

Unpacks a packed tensor.

Descriptions

zuko.utils.bisection(f, y, a, b, n=16, phi=())

Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).
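The core loop can be sketched in pure Python on scalar floats. This is a minimal illustration, not zuko's implementation: it assumes a continuous \(f\) with \(f(a) \leq y \leq f(b)\), as required by the parameters below, works on plain floats and tracks no gradients. The name bisect_scalar is hypothetical. Because only that invariant matters, \(a\) may be larger than \(b\), as in the example below.

```python
import math
from typing import Callable


def bisect_scalar(
    f: Callable[[float], float], y: float, a: float, b: float, n: int = 16
) -> float:
    """Approximate x with f(x) close to y, given f(a) <= y <= f(b).

    a may be larger than b: only the invariant f(a) <= y <= f(b) matters.
    """
    for _ in range(n):
        m = (a + b) / 2
        if f(m) < y:
            a = m  # f(m) < y <= f(b): a solution lies between m and b
        else:
            b = m  # f(a) <= y <= f(m): a solution lies between a and m
    return (a + b) / 2  # midpoint of the final bracket


# Same setting as the documented example: cos crosses 0 between 1 and 2.
x = bisect_scalar(math.cos, 0.0, 2.0, 1.0, n=16)  # close to math.pi / 2
```

Each iteration halves the bracket, so after \(n\) iterations the returned midpoint is within \(|b - a| / 2^{n + 1}\) of a solution.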

Gradients are propagated through \(y\) and \(\phi\) via implicit differentiation.
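A sketch of why implicit differentiation suffices, under the assumptions that the returned \(x^\star\) solves \(f_\phi(x^\star) = y\) exactly and that \(\partial f_\phi / \partial x \neq 0\) at \(x^\star\). This is the standard implicit function theorem argument, not necessarily how zuko arranges the computation internally. Differentiating the identity \(f_\phi(x^\star) = y\) with respect to \(y\) and to \(\phi\) gives

\[\frac{\partial f_\phi}{\partial x}(x^\star) \, \frac{\partial x^\star}{\partial y} = 1 , \qquad \frac{\partial f_\phi}{\partial x}(x^\star) \, \frac{\partial x^\star}{\partial \phi} + \frac{\partial f_\phi}{\partial \phi}(x^\star) = 0 ,\]

hence

\[\frac{\partial x^\star}{\partial y} = \left( \frac{\partial f_\phi}{\partial x}(x^\star) \right)^{-1} , \qquad \frac{\partial x^\star}{\partial \phi} = - \left( \frac{\partial f_\phi}{\partial x}(x^\star) \right)^{-1} \frac{\partial f_\phi}{\partial \phi}(x^\star) .\]

For the example below (\(f = \cos\), \(y = 0\), \(x^\star = \pi / 2\)), this gives \(\partial x^\star / \partial y = -1 / \sin(\pi / 2) = -1\). No gradient flows through the iterations themselves, so the cost does not grow with \(n\).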

Wikipedia

https://wikipedia.org/wiki/Bisection_method

Parameters:
  • f (Callable[[Tensor], Tensor]) – A univariate function \(f_\phi\).

  • y (Tensor) – The target \(y\).

  • a (float | Tensor) – The bound \(a\) such that \(f_\phi(a) \leq y\).

  • b (float | Tensor) – The bound \(b\) such that \(y \leq f_\phi(b)\).

  • n (int) – The number of iterations.

  • phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).

Returns:

The solution \(x\).

Return type:

Tensor

Example

>>> import torch
>>> from zuko.utils import bisection
>>> f = torch.cos
>>> y = torch.tensor(0.0)
>>> bisection(f, y, 2.0, 1.0, n=16)
tensor(1.5708)

zuko.utils.broadcast(*tensors, ignore=0)

Broadcasts tensors together.

The term broadcasting describes how PyTorch treats tensors with different shapes during arithmetic operations. In short, shapes are aligned from their last dimension, and wherever one size is 1 (or missing) it is expanded, without making copies, to match the other; any other size mismatch is an error.

Parameters:
  • tensors (Tensor) – The tensors to broadcast.

  • ignore (int | Sequence[int]) – The number(s) of trailing dimensions to exclude from broadcasting: a single value for all tensors, or one value per tensor.
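The shape arithmetic can be sketched in pure Python. This minimal illustration computes shapes only, no tensors; it assumes the usual right-aligned broadcasting rule and reads ignore as a per-tensor count of trailing dimensions left out of broadcasting, which is what the example below shows. The helper names broadcast_shapes and broadcast_with_ignore are hypothetical.

```python
from typing import List, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def broadcast_shapes(*shapes: Sequence[int]) -> Shape:
    """Right-align the shapes and merge each column of sizes."""
    ndim = max((len(s) for s in shapes), default=0)
    # Missing leading dimensions behave like size-1 dimensions.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        distinct = {d for d in sizes if d != 1}
        if len(distinct) > 1:
            raise ValueError(f"incompatible sizes {sizes}")
        result.append(distinct.pop() if distinct else 1)
    return tuple(result)


def broadcast_with_ignore(
    shapes: Sequence[Sequence[int]], ignore: Union[int, Sequence[int]] = 0
) -> List[Shape]:
    """Broadcast all but the last `ignore` dimensions of each shape."""
    if isinstance(ignore, int):
        ignore = [ignore] * len(shapes)
    splits = [len(s) - i for s, i in zip(shapes, ignore)]
    batch = broadcast_shapes(*(tuple(s)[:k] for s, k in zip(shapes, splits)))
    # The ignored trailing dimensions are appended unchanged.
    return [batch + tuple(s)[k:] for s, k in zip(shapes, splits)]
```

For instance, broadcast_with_ignore([(3, 1, 2), (4, 5)], ignore=1) broadcasts the leading shapes (3, 1) and (4,) to (3, 4) and returns [(3, 4, 2), (3, 4, 5)], matching the example below. PyTorch exposes the first rule directly as torch.broadcast_shapes.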

Returns:

The broadcasted tensors.

Return type:

List[Tensor]

Example

>>> import torch
>>> from zuko.utils import broadcast
>>> x = torch.rand(3, 1, 2)
>>> y = torch.rand(4, 5)
>>> x, y = broadcast(x, y, ignore=1)
>>> x.shape
torch.Size([3, 4, 2])
>>> y.shape
torch.Size([3, 4, 5])

zuko.utils.gauss_legendre(f, a, b, n=3, phi=())

Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.

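\[\int_a^b f_\phi(x) ~ dx \approx (b - a) \sum_{i = 1}^n w_i f_\phi(x_i)\]

where the nodes \(x_i\) are the roots of the Legendre polynomial \(P_n\) mapped from \([-1, 1]\) to \([a, b]\), and the weights \(w_i\) sum to one.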
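The factor \((b - a)\) comes from an affine change of variables. Writing \(\xi_i\) for the roots of \(P_n\) and \(W_i\) for the textbook weights of the rule on \([-1, 1]\) (which sum to 2), substituting \(x = \frac{b - a}{2} \xi + \frac{a + b}{2}\) gives

\[\int_a^b f_\phi(x) ~ dx = \frac{b - a}{2} \int_{-1}^{1} f_\phi\left(\frac{b - a}{2} \xi + \frac{a + b}{2}\right) d\xi \approx (b - a) \sum_{i = 1}^n \frac{W_i}{2} f_\phi(x_i) , \qquad x_i = \frac{b - a}{2} \xi_i + \frac{a + b}{2} ,\]

so the formula above corresponds to \(w_i = W_i / 2\).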

Wikipedia

https://wikipedia.org/wiki/Gauss-Legendre_quadrature

Parameters:
  • f (Callable[[Tensor], Tensor]) – A univariate function \(f_\phi\).

  • a (Tensor) – The lower limit \(a\).

  • b (Tensor) – The upper limit \(b\).

  • n (int) – The number of points \(n\) at which the function is evaluated.

  • phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).
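A minimal pure-Python sketch of the same quadrature, assuming the textbook construction: nodes found by Newton's method on \(P_n\) from the usual cosine initial guesses, and weights \(2 / ((1 - \xi_i^2) P_n'(\xi_i)^2)\) on \([-1, 1]\). It is not zuko's implementation: it works on plain floats, has no \(\phi\) argument, and the helper names legendre, legendre_nodes and gl_integrate are hypothetical.

```python
import math
from typing import Callable, List, Tuple


def legendre(n: int, x: float) -> Tuple[float, float]:
    """Return P_n(x) and P_n'(x) via the three-term recurrence (n >= 1, |x| != 1)."""
    p_prev, p = 1.0, x  # P_0(x), P_1(x)
    for k in range(2, n + 1):
        p_prev, p = p, ((2 * k - 1) * x * p - (k - 1) * p_prev) / k
    dp = n * (x * p - p_prev) / (x * x - 1)  # (x^2 - 1) P_n' = n (x P_n - P_{n-1})
    return p, dp


def legendre_nodes(n: int) -> Tuple[List[float], List[float]]:
    """Nodes and weights of the n-point Gauss-Legendre rule on [-1, 1]."""
    nodes, weights = [], []
    for i in range(1, n + 1):
        x = math.cos(math.pi * (i - 0.25) / (n + 0.5))  # guess near the i-th root
        for _ in range(100):  # Newton iterations on P_n
            p, dp = legendre(n, x)
            dx = p / dp
            x -= dx
            if abs(dx) < 1e-15:
                break
        _, dp = legendre(n, x)  # derivative at the converged node
        nodes.append(x)
        weights.append(2.0 / ((1.0 - x * x) * dp * dp))
    return nodes, weights


def gl_integrate(f: Callable[[float], float], a: float, b: float, n: int = 3) -> float:
    """Approximate the integral of f from a to b with an n-point rule."""
    nodes, weights = legendre_nodes(n)
    half, mid = (b - a) / 2, (a + b) / 2
    # Map [-1, 1] onto [a, b]; half * W_i equals (b - a) * w_i with w_i = W_i / 2.
    return half * sum(w * f(half * xi + mid) for xi, w in zip(nodes, weights))
```

An \(n\)-point rule is exact for polynomials of degree up to \(2n - 1\), so the default \(n = 3\) integrates \(x^5\) exactly. With \(n = 16\), gl_integrate(lambda x: math.exp(-x * x), -0.69, 4.2, n=16) agrees with the closed form \(\frac{\sqrt{\pi}}{2} (\operatorname{erf}(4.2) + \operatorname{erf}(0.69)) \approx 1.4807\).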

Returns:

The estimate of the definite integral.

Return type:

Tensor

Example

>>> import torch
>>> from zuko.utils import gauss_legendre
>>> f = lambda x: torch.exp(-x**2)
>>> a, b = torch.tensor([-0.69, 4.2])
>>> gauss_legendre(f, a, b, n=16)
tensor(1.4807)

zuko.utils.odeint(f, x, t0, t1, phi=(), atol=1e-06, rtol=1e-05)

Integrates a system of first-order ordinary differential equations (ODEs)

\[\frac{dx}{dt} = f_\phi(t, x) ,\]

from \(t_0\) to \(t_1\) using the adaptive Dormand-Prince method. The output is the final state

\[x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) ~ dt .\]
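For intuition, here is a minimal pure-Python sketch of an adaptive Dormand-Prince 5(4) integrator on a state stored as a list of floats. It assumes the standard Dormand-Prince Butcher tableau, accepts a step when the RMS of the scaled embedded error estimate is at most 1, and uses a common step-size controller (safety factor 0.9, growth clipped to [0.2, 5]). The initial step of one hundredth of the interval and the name dopri5 are arbitrary choices of this sketch, and it computes no gradients.

```python
import math
from typing import Callable, List

State = List[float]

# Butcher tableau of the Dormand-Prince 5(4) pair.
C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
B5 = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.0]  # 5th order
B4 = [5179 / 57600, 0.0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40]  # 4th order


def dopri5(
    f: Callable[[float, State], State],
    x0: State,
    t0: float,
    t1: float,
    atol: float = 1e-6,
    rtol: float = 1e-5,
) -> State:
    """Integrate dx/dt = f(t, x) from t0 to t1 and return x(t1)."""
    t, x = t0, list(x0)
    h = (t1 - t0) / 100  # arbitrary initial step, signed like t1 - t0
    while t != t1:
        remaining = t1 - t
        step = h if abs(h) < abs(remaining) else remaining  # never overshoot t1
        k: List[State] = []
        for i in range(7):  # the seven stages
            xi = [x[j] + step * sum(A[i][m] * k[m][j] for m in range(i)) for j in range(len(x))]
            k.append(f(t + C[i] * step, xi))
        x5 = [x[j] + step * sum(B5[m] * k[m][j] for m in range(7)) for j in range(len(x))]
        x4 = [x[j] + step * sum(B4[m] * k[m][j] for m in range(7)) for j in range(len(x))]
        # RMS of the embedded error estimate, scaled by the mixed tolerance.
        err = math.sqrt(
            sum(
                ((x5[j] - x4[j]) / (atol + rtol * max(abs(x[j]), abs(x5[j])))) ** 2
                for j in range(len(x))
            )
            / len(x)
        )
        if err <= 1.0:  # accept and advance with the 5th-order solution
            t = t1 if step == remaining else t + step
            x = x5
        # Grow or shrink the next step from the error estimate.
        factor = 0.9 * err ** -0.2 if err > 0 else 5.0
        h = step * min(5.0, max(0.2, factor))
    return x
```

For example, dopri5(lambda t, x: [-x[0]], [1.0], 0.0, 1.0) approximates \(e^{-1}\). Since each accepted step advances with the fifth-order solution while the fourth-order one only estimates the error, the actual error is typically below the requested tolerances.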

Gradients are propagated through \(x_0\), \(t_0\), \(t_1\) and \(\phi\) via the adaptive checkpoint adjoint (ACA) method.

References

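Neural Ordinary Differential Equations (Chen et al., 2018)
Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE (Zhuang et al., 2020)

Parameters:
  • f (Callable) – The system of first-order ODEs \(f_\phi(t, x)\).

  • x (Tensor | Sequence[Tensor]) – The initial state \(x_0\).

  • t0 (float | Tensor) – The initial integration time \(t_0\).

  • t1 (float | Tensor) – The final integration time \(t_1\).

  • phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).

  • atol (float) – The absolute tolerance.

  • rtol (float) – The relative tolerance.
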
Returns:

The final state \(x(t_1)\).

Return type:

Tensor | Sequence[Tensor]

Example

>>> import torch
>>> from zuko.utils import odeint
>>> A = torch.randn(3, 3)
>>> f = lambda t, x: x @ A
>>> x0 = torch.randn(3)
>>> x1 = odeint(f, x0, 0.0, 1.0)
>>> x1
tensor([-3.7454, -0.4140,  0.2677])

zuko.utils.unpack(x, shapes)

Unpacks a packed tensor.

Parameters:
  • x (Tensor) – A packed tensor, with shape \((*, D)\).

  • shapes (Sequence[Size]) – A sequence of shapes \(S_i\) whose numbers of elements sum to \(D\).
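The splitting logic can be sketched in pure Python on a flat list, ignoring the batch dimensions \(*\): consecutive chunks holding \(\prod S_i\) elements are cut from the input and each is reshaped in row-major order, as torch.reshape does. The helpers reshape and unpack_flat are hypothetical names for this illustration, not part of zuko.

```python
import math
from itertools import accumulate
from typing import Any, List, Sequence


def reshape(flat: Sequence[float], shape: Sequence[int]) -> Any:
    """Nest a flat row-major sequence into lists with the given shape."""
    if len(shape) == 0:
        return flat[0]  # a 0-d shape holds exactly one element
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0] if shape[0] else 0
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def unpack_flat(x: Sequence[float], shapes: Sequence[Sequence[int]]) -> List[Any]:
    """Split x into consecutive chunks of prod(S_i) elements and reshape each."""
    sizes = [math.prod(s) for s in shapes]
    if sum(sizes) != len(x):
        raise ValueError(f"expected {sum(sizes)} elements, got {len(x)}")
    ends = list(accumulate(sizes))
    starts = [0] + ends[:-1]
    return [reshape(x[i:j], s) for i, j, s in zip(starts, ends, shapes)]
```

With x = list(range(26)) and the shapes of the example below, the first 6 elements form y with shape (1, 2, 3) and the remaining 20 form z with shape (4, 5).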

Returns:

The unpacked tensors, with shapes \((*, S_i)\).

Return type:

Sequence[Tensor]

Example

>>> import torch
>>> from zuko.utils import unpack
>>> x = torch.randn(26)
>>> y, z = unpack(x, ((1, 2, 3), (4, 5)))
>>> y.shape
torch.Size([1, 2, 3])
>>> z.shape
torch.Size([4, 5])