zuko.utils
General purpose helpers.
Functions
bisection | Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).
broadcast | Broadcasts tensors together.
gauss_legendre | Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.
odeint | Integrates a system of first-order ordinary differential equations (ODEs).
unpack | Unpacks a packed tensor.
Descriptions
- zuko.utils.bisection(f, y, a, b, n=16, phi=())
Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).
Gradients are propagated through \(y\) and \(\phi\) via implicit differentiation.
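These gradients follow from the implicit function theorem. Assuming \(f_\phi\) is differentiable with \(f_\phi'(x) \neq 0\) at the returned solution, differentiating both sides of \(f_\phi(x) = y\) with respect to \(y\) and \(\phi\) gives the standard identities below. This is the textbook result, not a statement about how zuko computes it internally.
\[\frac{\partial x}{\partial y} = \frac{1}{f_\phi'(x)} , \qquad \frac{\partial x}{\partial \phi} = -\frac{1}{f_\phi'(x)} \frac{\partial f_\phi}{\partial \phi}(x) .\]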
Wikipedia
https://wikipedia.org/wiki/Bisection_method
- Parameters:
f (Callable[[Tensor], Tensor]) – A univariate function \(f_\phi\).
y (Tensor) – The target \(y\).
a (float | Tensor) – The bound \(a\) such that \(f_\phi(a) \leq y\).
b (float | Tensor) – The bound \(b\) such that \(y \leq f_\phi(b)\).
n (int) – The number of iterations.
phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).
- Returns:
The solution \(x\).
- Return type:
Tensor
Example
>>> import torch
>>> from zuko.utils import bisection
>>> f = torch.cos
>>> y = torch.tensor(0.0)
>>> bisection(f, y, 2.0, 1.0, n=16)  # cos(2) <= 0 <= cos(1)
tensor(1.5708)
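The iteration itself can be sketched in plain Python. This is a minimal illustration under stated assumptions, not zuko's implementation: `bisection_sketch` is a hypothetical name, it works on Python floats rather than tensors, it keeps the convention \(f(a) \leq y \leq f(b)\) (so \(a\) may be larger than \(b\), as in the example above), and it does not implement the implicit differentiation used for gradients.

```python
import math
from typing import Callable


def bisection_sketch(
    f: Callable[[float], float], y: float, a: float, b: float, n: int = 16
) -> float:
    """Plain-float bisection, assuming f is continuous and f(a) <= y <= f(b)."""
    for _ in range(n):
        c = (a + b) / 2  # midpoint of the current bracket
        if f(c) < y:
            a = c  # f(c) <= y, so a solution lies between c and b
        else:
            b = c  # y <= f(c), so a solution lies between a and c
    return (a + b) / 2  # midpoint of the final, halved-n-times bracket


# Same setting as the doctest above: cos(2.0) <= 0 <= cos(1.0).
x = bisection_sketch(math.cos, 0.0, 2.0, 1.0, n=16)
print(round(x, 4))  # 1.5708
```

Each iteration halves the bracket, so after \(n\) iterations the returned midpoint lies within \(|b - a| / 2^{n + 1}\) of a solution.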
- zuko.utils.broadcast(*tensors, ignore=0)
Broadcasts tensors together.
The term broadcasting describes how PyTorch treats tensors with different shapes during arithmetic operations. In short, if possible, missing leading dimensions are added and dimensions of size 1 are expanded (without making copies) so that all shapes become compatible.
- Parameters:
- Returns:
The broadcasted tensors.
- Return type:
Example
>>> import torch
>>> from zuko.utils import broadcast
>>> x = torch.rand(3, 1, 2)
>>> y = torch.rand(4, 5)
>>> x, y = broadcast(x, y, ignore=1)
>>> x.shape
torch.Size([3, 4, 2])
>>> y.shape
torch.Size([3, 4, 5])
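The shape rule, with the last `ignore` dimensions excluded, can be sketched in plain Python. The hypothetical helper `broadcast_shapes_sketch` only computes the target shapes under the standard broadcasting rules and assumes a single integer `ignore` shared by all tensors; the expansion of the tensors themselves would then be done separately, for example with `torch.broadcast_to`.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def broadcast_shapes_sketch(shapes: Sequence[Shape], ignore: int = 0) -> List[Shape]:
    """Target shapes when all but the last `ignore` dims are broadcast together."""
    if any(ignore > len(s) for s in shapes):
        raise ValueError("ignore exceeds the number of dimensions of a shape")
    # Split every shape into a broadcastable batch part and a kept trailing part.
    batch = [tuple(s[: len(s) - ignore]) for s in shapes]
    kept = [tuple(s[len(s) - ignore :]) for s in shapes]
    # Left-pad the batch parts with 1s so they all have the same length.
    ndim = max(len(b) for b in batch)
    padded = [(1,) * (ndim - len(b)) + b for b in batch]
    common = []
    for sizes in zip(*padded):
        # Sizes of 1 stretch to match; two different sizes > 1 are incompatible.
        others = {d for d in sizes if d != 1}
        if len(others) > 1:
            raise ValueError(f"incompatible sizes {sizes}")
        common.append(others.pop() if others else 1)
    return [tuple(common) + k for k in kept]


print(broadcast_shapes_sketch([(3, 1, 2), (4, 5)], ignore=1))
# [(3, 4, 2), (3, 4, 5)]
```

With `ignore=1`, the batch shapes \((3, 1)\) and \((4)\) broadcast to \((3, 4)\), while the trailing sizes 2 and 5 are left untouched, which reproduces the shapes in the doctest above.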
- zuko.utils.gauss_legendre(f, a, b, n=3, phi=())
Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.
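\[\int_a^b f_\phi(x) ~ dx \approx (b - a) \sum_{i = 1}^n w_i f_\phi(x_i) ,\]
where \(x_i = a + (b - a) \xi_i\), and \(\xi_i\) and \(w_i\) are the \(n\) Gauss-Legendre nodes and weights rescaled from \([-1, 1]\) to \([0, 1]\), so that the weights sum to 1.
Wikipedia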
https://wikipedia.org/wiki/Gauss-Legendre_quadrature
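- Parameters:
f – A univariate function \(f_\phi\).
a – The lower limit \(a\).
b – The upper limit \(b\).
n – The number of points \(n\) at which \(f_\phi\) is evaluated.
phi – The parameters \(\phi\) of \(f_\phi\).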
- Returns:
The definite integral estimation.
- Return type:
Tensor
Example
>>> import torch
>>> from zuko.utils import gauss_legendre
>>> f = lambda x: torch.exp(-x**2)
>>> a, b = torch.tensor([-0.69, 4.2])
>>> gauss_legendre(f, a, b, n=16)
tensor(1.4807)
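As an illustration of the quadrature formula, the sketch below hard-codes the 3-point rule (the default `n=3` in the signature) in plain Python, with the nodes and weights rescaled from \([-1, 1]\) to \([0, 1]\) so that the \((b - a)\) prefactor matches the formula. `gauss_legendre_3`, `NODES` and `WEIGHTS` are hypothetical names; the sketch has no autograd support, and other values of \(n\) would need their own nodes and weights (for instance from `numpy.polynomial.legendre.leggauss`).

```python
import math
from typing import Callable

# 3-point Gauss-Legendre rule on [-1, 1]: nodes 0 and ±sqrt(3/5), weights 8/9 and 5/9.
# Rescaled here to [0, 1]: nodes mapped by (xi + 1) / 2, weights halved to sum to 1.
_R = math.sqrt(3 / 5)
NODES = [(1 - _R) / 2, 1 / 2, (1 + _R) / 2]
WEIGHTS = [5 / 18, 8 / 18, 5 / 18]


def gauss_legendre_3(f: Callable[[float], float], a: float, b: float) -> float:
    """Estimate the integral of f from a to b as (b - a) * sum_i w_i f(x_i)."""
    return (b - a) * sum(w * f(a + (b - a) * xi) for xi, w in zip(NODES, WEIGHTS))


print(gauss_legendre_3(lambda x: x**5, 0.0, 1.0))  # close to 1/6
```

An \(n\)-point rule integrates polynomials of degree up to \(2n - 1\) exactly, which is why the degree-5 integrand above is recovered up to floating-point rounding.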
- zuko.utils.odeint(f, x, t0, t1, phi=(), atol=1e-06, rtol=1e-05)
Integrates a system of first-order ordinary differential equations (ODEs)
\[\frac{dx}{dt} = f_\phi(t, x) ,\]
from \(t_0\) to \(t_1\) using the adaptive Dormand-Prince method. The output is the final state
\[x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) ~ dt ,\]
where \(x_0 = x(t_0)\) is the initial state. Gradients are propagated through \(x_0\), \(t_0\), \(t_1\) and \(\phi\) via the adaptive checkpoint adjoint (ACA) method.
References
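Neural Ordinary Differential Equations (Chen et al., 2018)
Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE (Zhuang et al., 2020)
- Parameters:
f – The system of first-order ODEs \(f_\phi(t, x)\).
x – The initial state \(x_0\).
t0 – The initial integration time \(t_0\).
t1 – The final integration time \(t_1\).
phi – The parameters \(\phi\) of \(f_\phi\).
atol – The absolute tolerance of the adaptive step-size control.
rtol – The relative tolerance of the adaptive step-size control.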
- Returns:
The final state \(x(t_1)\).
- Return type:
Example
>>> import torch
>>> from zuko.utils import odeint
>>> A = torch.randn(3, 3)
>>> f = lambda t, x: x @ A
>>> x0 = torch.randn(3)
>>> x1 = odeint(f, x0, 0.0, 1.0)
>>> x1  # values depend on the random draws of A and x0
tensor([-3.7454, -0.4140, 0.2677])
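The forward integration can be sketched with a textbook adaptive Dormand-Prince 5(4) stepper in plain Python. This is a minimal sketch, not zuko's implementation: `dopri5_sketch` is a hypothetical name, the state is a list of floats, it assumes \(t_1 > t_0\), it uses a naive initial step and common step-size control constants (a safety factor of 0.9, growth clamped to \([0.2, 5]\)), and it computes no gradients, so the adjoint method is not covered.

```python
import math
from typing import Callable, List

Vector = List[float]

# Dormand-Prince 5(4) Butcher tableau: nodes C, stage coefficients A,
# 5th-order weights B5 (used to advance) and embedded 4th-order weights B4.
C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
B5 = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.0]
B4 = [5179 / 57600, 0.0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40]


def dopri5_sketch(
    f: Callable[[float, Vector], Vector],
    x0: Vector,
    t0: float,
    t1: float,
    atol: float = 1e-6,
    rtol: float = 1e-5,
) -> Vector:
    """Integrate dx/dt = f(t, x) from t0 to t1, assuming t1 > t0."""
    t, x = t0, list(x0)
    h = (t1 - t0) / 100  # naive initial step size
    while t < t1:
        h = min(h, t1 - t)  # never step past t1
        # Seven stages: k_i = f(t + c_i h, x + h * sum_m a_im k_m).
        k: List[Vector] = []
        for i in range(7):
            xi = [x[j] + h * sum(A[i][m] * k[m][j] for m in range(i)) for j in range(len(x))]
            k.append(f(t + C[i] * h, xi))
        x5 = [x[j] + h * sum(B5[i] * k[i][j] for i in range(7)) for j in range(len(x))]
        x4 = [x[j] + h * sum(B4[i] * k[i][j] for i in range(7)) for j in range(len(x))]
        # Root-mean-square of the local error estimate, scaled by the tolerances.
        err = math.sqrt(
            sum(
                ((u5 - u4) / (atol + rtol * max(abs(u0), abs(u5)))) ** 2
                for u0, u4, u5 in zip(x, x4, x5)
            )
            / len(x)
        )
        if err <= 1.0:
            t, x = t + h, x5  # accept the step and keep the 5th-order solution
        # Rescale h with a 0.9 safety factor, clamped to [0.2, 5] times the old h.
        h *= min(5.0, max(0.2, 0.9 * err ** -0.2)) if err > 0 else 5.0
    return x


# Exponential decay dx/dt = -x, whose exact solution is x(t) = x0 * exp(-t).
x1 = dopri5_sketch(lambda t, x: [-x[0]], [1.0], 0.0, 1.0)
print(x1[0], math.exp(-1.0))  # the two numbers should be close
```

Rejected steps (error estimate above 1) leave \(t\) and \(x\) unchanged and retry with a smaller \(h\); accepted steps advance with the 5th-order solution, as is usual for this method.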
- zuko.utils.unpack(x, shapes)
Unpacks a packed tensor.
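- Parameters:
x – A packed tensor, with shape \((*, D)\), where \(D\) is the total number of elements of the shapes \(S_i\).
shapes – The shapes \(S_i\) of the unpacked tensors.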
- Returns:
The unpacked tensors, with shapes \((*, S_i)\).
- Return type:
Example
>>> import torch
>>> from zuko.utils import unpack
>>> x = torch.randn(26)
>>> y, z = unpack(x, ((1, 2, 3), (4, 5)))
>>> y.shape
torch.Size([1, 2, 3])
>>> z.shape
torch.Size([4, 5])
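The packing convention can be sketched in plain Python for an unbatched input, that is, a flat sequence whose length is the total number of elements of the shapes \(S_i\) (here \(6 + 20 = 26\)). The hypothetical `unpack_sketch` returns nested lists in row-major order instead of tensors and does not handle the leading \(*\) dimensions; with tensors, the same split could be done with `torch.Tensor.split` along the last dimension followed by a reshape.

```python
from math import prod
from typing import Any, List, Sequence, Tuple


def _reshape(flat: Sequence[float], shape: Tuple[int, ...]) -> Any:
    """Turn a flat sequence into nested lists of the given shape (row-major order)."""
    if len(shape) == 0:
        return flat[0]  # the empty shape () holds a single scalar
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])  # number of elements under each leading index
    return [_reshape(flat[i * step : (i + 1) * step], shape[1:]) for i in range(shape[0])]


def unpack_sketch(x: Sequence[float], shapes: Sequence[Tuple[int, ...]]) -> List[Any]:
    """Split a flat sequence of length sum_i prod(S_i) into pieces of shapes S_i."""
    sizes = [prod(s) for s in shapes]
    if sum(sizes) != len(x):
        raise ValueError(f"expected {sum(sizes)} elements, got {len(x)}")
    pieces, start = [], 0
    for shape, size in zip(shapes, sizes):
        pieces.append(_reshape(x[start : start + size], tuple(shape)))
        start += size
    return pieces


y, z = unpack_sketch(list(range(26)), ((1, 2, 3), (4, 5)))
print(y)  # [[[0, 1, 2], [3, 4, 5]]]
print(z[0])  # [6, 7, 8, 9, 10]
```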