zuko.utils
General-purpose helpers.
Functions
| bisection | Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\). |
| broadcast | Broadcasts tensors together. |
| gauss_legendre | Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature. |
| odeint | Integrates a system of first-order ordinary differential equations (ODEs). |
Descriptions
- zuko.utils.bisection(f, y, a, b, n=16, phi=())
Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).
Gradients are propagated through \(y\) and \(\phi\) via implicit differentiation.
Wikipedia
https://wikipedia.org/wiki/Bisection_method
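The implicit-differentiation claim can be made concrete with the textbook implicit-function-theorem derivation. It assumes the returned \(x^\star\) satisfies \(f_\phi(x^\star) = y\) exactly and that \(f_\phi'(x^\star) = \partial f_\phi / \partial x \neq 0\); it states the standard result and is not a description of how zuko arranges its backward pass. Differentiating both sides of \(f_\phi(x^\star) = y\) with respect to \(y\) and \(\phi\) gives
\[f_\phi'(x^\star) \frac{\partial x^\star}{\partial y} = 1 , \qquad f_\phi'(x^\star) \frac{\partial x^\star}{\partial \phi} + \frac{\partial f_\phi}{\partial \phi}(x^\star) = 0 ,\]
hence
\[\frac{\partial x^\star}{\partial y} = \frac{1}{f_\phi'(x^\star)} , \qquad \frac{\partial x^\star}{\partial \phi} = -\frac{1}{f_\phi'(x^\star)} \frac{\partial f_\phi}{\partial \phi}(x^\star) .\]
Because these expressions depend only on the final \(x^\star\), the \(n\) bisection iterations do not need to be recorded for backpropagation.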
- Parameters
f (Callable[[Tensor], Tensor]) – A univariate function \(f_\phi\).
y (Tensor) – The target \(y\).
a (Union[float, Tensor]) – The bound \(a\) such that \(f_\phi(a) \leq y\).
b (Union[float, Tensor]) – The bound \(b\) such that \(y \leq f_\phi(b)\).
n (int) – The number of iterations.
phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).
- Returns
The solution \(x\).
- Return type
Tensor
Example
>>> import torch
>>> from zuko.utils import bisection
>>> f = torch.cos
>>> y = torch.tensor(0.0)
>>> bisection(f, y, 2.0, 1.0, n=16)
tensor(1.5708)
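A minimal, dependency-free sketch of the bracketing loop, assuming plain Python floats instead of tensors and leaving out the implicit gradient. `bisection_sketch` is a hypothetical name, not part of zuko. Like the documented function, it only requires \(f(a) \leq y \leq f(b)\), so \(a\) may be greater than \(b\), as in the example above.

```python
from typing import Callable


def bisection_sketch(
    f: Callable[[float], float], y: float, a: float, b: float, n: int = 16
) -> float:
    """Hypothetical sketch: bisection assuming f(a) <= y <= f(b) and f continuous."""
    for _ in range(n):
        c = (a + b) / 2
        # Keep the half whose endpoints still bracket y:
        # if f(c) < y the solution lies between c and b, otherwise between a and c.
        if f(c) < y:
            a = c
        else:
            b = c
    # Return the midpoint of the final bracket
    return (a + b) / 2
```

Each iteration halves the bracket, so the returned midpoint is within \(|b - a| / 2^{n+1}\) of a solution; with the default \(n = 16\) and a bracket of width 1, as in the example, that is about \(7.6 \times 10^{-6}\).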
- zuko.utils.broadcast(*tensors, ignore=0)
Broadcasts tensors together.
The term broadcasting describes how PyTorch treats tensors with different shapes during arithmetic operations. In short, when the shapes are compatible, dimensions of size 1 and missing leading dimensions are expanded, without copying data, so that all tensors share a common shape.
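- Parameters
tensors (Tensor) – The tensors to broadcast.
ignore (Union[int, Sequence[int]]) – The number(s) of trailing dimensions not to broadcast.
- Returns
The broadcasted tensors.
- Return type
List[Tensor]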
Example
>>> import torch
>>> from zuko.utils import broadcast
>>> x = torch.rand(3, 1, 2)
>>> y = torch.rand(4, 5)
>>> x, y = broadcast(x, y, ignore=1)
>>> x.shape
torch.Size([3, 4, 2])
>>> y.shape
torch.Size([3, 4, 5])
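The shape logic behind the example can be sketched without PyTorch. This sketch computes only the target shapes, under the assumption (consistent with the example above) that `ignore` counts trailing dimensions that are kept as-is while the remaining leading dimensions follow the usual right-aligned broadcasting rule. `broadcast_shapes_sketch` is a hypothetical name, not zuko's implementation; with real tensors, each one would then be expanded to its target shape, for instance with `torch.broadcast_to`.

```python
from typing import List, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def broadcast_shapes_sketch(
    shapes: Sequence[Shape], ignore: Union[int, Sequence[int]] = 0
) -> List[Shape]:
    """Hypothetical sketch: broadcast leading dims, keep the last `ignore` dims."""
    if isinstance(ignore, int):
        ignore = [ignore] * len(shapes)
    # Split each shape into a broadcast part (head) and an untouched part (tail).
    # len(s) - i is used instead of -i so that i = 0 keeps the whole shape.
    heads = [s[: len(s) - i] for s, i in zip(shapes, ignore)]
    tails = [s[len(s) - i :] for s, i in zip(shapes, ignore)]
    # Right-align the heads: sizes must match or be 1 (missing dims act as 1)
    ndim = max(len(h) for h in heads)
    common = []
    for d in range(ndim):
        sizes = {h[len(h) - ndim + d] for h in heads if len(h) - ndim + d >= 0}
        sizes.discard(1)
        if len(sizes) > 1:
            raise ValueError(f"incompatible sizes {sorted(sizes)} at dim {d}")
        common.append(sizes.pop() if sizes else 1)
    return [tuple(common) + t for t in tails]
```

Called as `broadcast_shapes_sketch([(3, 1, 2), (4, 5)], ignore=1)`, it returns `[(3, 4, 2), (3, 4, 5)]`, the shapes printed in the example.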
- zuko.utils.gauss_legendre(f, a, b, n=3, phi=())
Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.
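\[\int_a^b f_\phi(x) ~ dx \approx (b - a) \sum_{i = 1}^n w_i f_\phi(x_i)\]
where the nodes \(x_i\) are the \(n\) roots of the Legendre polynomial \(P_n\) mapped from \([-1, 1]\) to \([a, b]\), and the weights \(w_i\) are the corresponding Gauss-Legendre weights rescaled to sum to 1.
Wikipedia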
https://wikipedia.org/wiki/Gauss-Legendre_quadrature
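- Parameters
f (Callable[[Tensor], Tensor]) – A univariate function \(f_\phi\).
a (Tensor) – The lower limit \(a\).
b (Tensor) – The upper limit \(b\).
n (int) – The number of points \(n\) at which \(f_\phi\) is evaluated.
phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).
- Returns
The estimated definite integral.
- Return type
Tensor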
Example
>>> import torch
>>> from zuko.utils import gauss_legendre
>>> f = lambda x: torch.exp(-x**2)
>>> a, b = torch.tensor([-0.69, 4.2])
>>> gauss_legendre(f, a, b, n=16)
tensor(1.4807)
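The quadrature formula can be made concrete with a dependency-free sketch: compute the Legendre roots on \([-1, 1]\) by Newton's method, map them to \([a, b]\), and halve the standard weights so that they sum to 1, matching the \((b - a)\) prefactor above. It works on plain Python floats and does not propagate gradients. `legendre_rule` and `gauss_legendre_sketch` are hypothetical names, and the Newton initial guess and stopping threshold are common textbook choices, not zuko's implementation.

```python
import math
from typing import Callable, List, Tuple


def legendre_rule(n: int) -> Tuple[List[float], List[float]]:
    """Nodes and weights of the n-point Gauss-Legendre rule on [-1, 1]."""
    nodes, weights = [], []
    for i in range(1, n + 1):
        # Standard initial guess for the i-th root of P_n
        x = math.cos(math.pi * (i - 0.25) / (n + 0.5))
        for _ in range(100):
            # Three-term recurrence: p1 = P_n(x), p0 = P_{n-1}(x)
            p0, p1 = 1.0, x
            for k in range(2, n + 1):
                p0, p1 = p1, ((2 * k - 1) * x * p1 - (k - 1) * p0) / k
            # Identity: (x^2 - 1) P_n'(x) = n (x P_n(x) - P_{n-1}(x))
            dp = n * (x * p1 - p0) / (x * x - 1)
            step = p1 / dp  # Newton step towards the root
            x -= step
            if abs(step) < 1e-15:
                break
        nodes.append(x)
        weights.append(2 / ((1 - x * x) * dp * dp))
    return nodes, weights


def gauss_legendre_sketch(
    f: Callable[[float], float], a: float, b: float, n: int = 3
) -> float:
    """Hypothetical sketch of (b - a) * sum_i w_i f(x_i) with weights summing to 1."""
    xi, wi = legendre_rule(n)
    xs = [a + (b - a) * (x + 1) / 2 for x in xi]  # map [-1, 1] onto [a, b]
    ws = [w / 2 for w in wi]  # the weights on [-1, 1] sum to 2
    return (b - a) * sum(w * f(x) for w, x in zip(ws, xs))
```

An \(n\)-point rule is exact for polynomials of degree up to \(2n - 1\), so the default \(n = 3\) already integrates \(x^5\) exactly.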
- zuko.utils.odeint(f, x, t0, t1, phi=())
Integrates a system of first-order ordinary differential equations (ODEs)
\[\frac{dx}{dt} = f_\phi(t, x) ,\]
from \(t_0\) to \(t_1\) using the adaptive Dormand-Prince method. The output is the final state
\[x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) ~ dt .\]
Gradients are propagated through \(x_0\), \(t_0\), \(t_1\) and \(\phi\) via the adaptive checkpoint adjoint (ACA) method.
References
Neural Ordinary Differential Equations (Chen et al., 2018)
Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE (Zhuang et al., 2020)
- Parameters
f (Callable[[Tensor, Tensor], Tensor]) – A system of first-order ODEs \(f_\phi\).
x (Union[Tensor, Sequence[Tensor]]) – The initial state \(x_0\).
t0 (Union[float, Tensor]) – The initial integration time \(t_0\).
t1 (Union[float, Tensor]) – The final integration time \(t_1\).
phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).
- Returns
The final state \(x(t_1)\).
- Return type
Union[Tensor, Sequence[Tensor]]
Example
>>> import torch
>>> from zuko.utils import odeint
>>> A = torch.randn(3, 3)
>>> f = lambda t, x: x @ A
>>> x0 = torch.randn(3)
>>> x1 = odeint(f, x0, 0.0, 1.0)
>>> x1  # values vary since A and x0 are random
tensor([-3.7454, -0.4140, 0.2677])
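To illustrate the adaptive Dormand-Prince scheme named above, here is a dependency-free sketch on plain Python lists. It covers the forward integration only (no ACA gradients). The tolerances `atol` and `rtol`, the initial step of a tenth of the interval and the step-size controller constants are common textbook choices assumed here, not values taken from zuko; `odeint_sketch` and `combine` are hypothetical names.

```python
import math
from typing import Callable, List

Vec = List[float]

# Dormand-Prince RK5(4) Butcher tableau
C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
B5 = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.0]  # 5th order
B4 = [5179 / 57600, 0.0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40]  # 4th order


def combine(x: Vec, h: float, coeffs: List[float], ks: List[Vec]) -> Vec:
    """Return x + h * sum_i coeffs[i] * ks[i], component-wise."""
    return [xj + h * sum(c * k[j] for c, k in zip(coeffs, ks)) for j, xj in enumerate(x)]


def odeint_sketch(
    f: Callable[[float, Vec], Vec], x0: Vec, t0: float, t1: float,
    atol: float = 1e-6, rtol: float = 1e-5,
) -> Vec:
    """Hypothetical sketch: integrate dx/dt = f(t, x) from t0 to t1 (either direction)."""
    t, x = t0, list(x0)
    h = (t1 - t0) / 10  # initial step; its sign sets the integration direction
    while t != t1:
        last = abs(h) >= abs(t1 - t)
        if last:
            h = t1 - t  # do not overshoot t1
        ks: List[Vec] = []
        for i in range(7):
            ks.append(f(t + C[i] * h, combine(x, h, A[i], ks)))
        x5 = combine(x, h, B5, ks)
        x4 = combine(x, h, B4, ks)
        # RMS norm of the embedded error estimate, scaled by the tolerances
        err = math.sqrt(sum(((p - q) / (atol + rtol * max(abs(a), abs(p)))) ** 2
                            for a, p, q in zip(x, x5, x4)) / len(x))
        if err <= 1.0:  # accept the step and keep the 5th-order solution
            t = t1 if last else t + h
            x = x5
        # Controller: safety factor 0.9, step change bounded to [0.2, 5]
        h *= min(5.0, max(0.2, 0.9 * err ** -0.2)) if err > 0 else 5.0
    return x
```

The 4th- and 5th-order solutions share the same seven stage evaluations, so the error estimate that drives the step size comes at no extra cost in evaluations of \(f\).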