zuko.utils
General-purpose helpers.
Classes
Partial – A version of functools.partial that is a torch.nn.Module.
Functions
bisection – Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).
broadcast – Broadcasts tensors together.
gauss_legendre – Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.
odeint – Integrates a system of first-order ordinary differential equations (ODEs).
Unpacks a packed tensor.
Descriptions
- class zuko.utils.Partial(f, /, *args, buffer=False, **kwargs)
A version of functools.partial that is a torch.nn.Module.
- Parameters:
f (Callable) – An arbitrary callable. If f is a module, it is registered as a submodule.
args – The positional arguments passed to f.
buffer (bool) – Whether tensor arguments are registered as buffers or parameters.
kwargs – The keyword arguments passed to f.
Examples
>>> increment = Partial(torch.add, torch.tensor(1.0), buffer=True)
>>> increment(torch.arange(3))
tensor([1., 2., 3.])

>>> weight = torch.randn((5, 3))
>>> linear = Partial(torch.nn.functional.linear, weight=weight)
>>> x = torch.rand(2, 3)
>>> linear(x)
tensor([[-0.1364, -0.4034,  0.1887, -0.2045, -0.0151],
        [-2.0380, -1.5081, -0.4816,  0.0323, -0.7941]], grad_fn=<MmBackward0>)

>>> f = torch.distributions.Normal
>>> loc, scale = torch.zeros(3), torch.ones(3)
>>> dist = Partial(f, loc, scale, buffer=True)
>>> dist()
Normal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> dist().sample()
tensor([ 0.1227,  0.1494, -0.6709])
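A minimal sketch of how a module-aware partial could work, assuming tensor arguments are stored under generated attribute names such as _arg0 or _kw_weight. The class PartialSketch and those names are hypothetical, not zuko's implementation. What it illustrates is why registration matters: buffers and parameters follow .to(device) and appear in state_dict(), and parameters additionally require gradients.

```python
import torch
import torch.nn as nn


class PartialSketch(nn.Module):
    """Hypothetical minimal module-aware version of functools.partial."""

    def __init__(self, f, /, *args, buffer=False, **kwargs):
        super().__init__()
        self.f = f  # assigning an nn.Module here registers it as a submodule
        self._buffer = buffer
        # Each slot is (True, attribute_name) for tensors or (False, raw_value)
        self._arg_slots = [self._store(f"_arg{i}", a) for i, a in enumerate(args)]
        self._kwarg_slots = {k: self._store(f"_kw_{k}", v) for k, v in kwargs.items()}

    def _store(self, name, value):
        # Tensors become buffers or parameters; anything else is kept as-is
        if not torch.is_tensor(value):
            return (False, value)
        if self._buffer:
            self.register_buffer(name, value)
        else:
            self.register_parameter(name, nn.Parameter(value))
        return (True, name)

    def _load(self, slot):
        is_tensor, item = slot
        return getattr(self, item) if is_tensor else item

    def forward(self, *args, **kwargs):
        # Stored positional arguments come first, as with functools.partial
        stored = [self._load(s) for s in self._arg_slots]
        stored_kw = {k: self._load(s) for k, s in self._kwarg_slots.items()}
        return self.f(*stored, *args, **stored_kw, **kwargs)
```

For instance, PartialSketch(torch.add, torch.tensor(1.0), buffer=True) keeps the constant as a buffer, so it moves with the module but receives no gradient, whereas the default buffer=False turns it into a trainable parameter.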
- zuko.utils.bisection(f, y, a, b, n=16, phi=())
Applies the bisection method to find \(x\) between the bounds \(a\) and \(b\) such that \(f_\phi(x)\) is close to \(y\).
Gradients are propagated through \(y\) and \(\phi\) via implicit differentiation.
Wikipedia
https://wikipedia.org/wiki/Bisection_method
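For reference, the implicit-function-theorem identities that this kind of gradient rests on: if the root \(x^\star\) satisfies \(f_\phi(x^\star) = y\) and \(f_\phi'(x^\star) \neq 0\), differentiating both sides with respect to \(y\) and \(\phi\) gives the expressions below. This is the general mathematical result, not a description of how zuko wires it into autograd.

```latex
\[
f_\phi(x^\star) = y
\quad\Longrightarrow\quad
\frac{\partial x^\star}{\partial y} = \frac{1}{f_\phi'(x^\star)} ,
\qquad
\frac{\partial x^\star}{\partial \phi} = -\frac{1}{f_\phi'(x^\star)} \, \frac{\partial f_\phi}{\partial \phi}(x^\star) .
\]
```

For example, with \(f = \cos\) and \(y = 0\), the root is \(x^\star = \pi / 2\) and \(f'(x^\star) = -\sin(\pi / 2) = -1\), so \(\partial x^\star / \partial y = -1\), which matches the derivative of \(\arccos(y)\) at \(y = 0\).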
- Parameters:
f (Callable[[Tensor], Tensor]) – A univariate function \(f_\phi\).
y (Tensor) – The target \(y\).
a (float | Tensor) – The bound \(a\) such that \(f_\phi(a) \leq y\).
b (float | Tensor) – The bound \(b\) such that \(y \leq f_\phi(b)\).
n (int) – The number of iterations.
phi (Iterable[Tensor]) – The parameters \(\phi\) of \(f_\phi\).
- Returns:
The solution \(x\).
- Return type:
Tensor
Example
>>> f = torch.cos
>>> y = torch.tensor(0.0)
>>> bisection(f, y, 2.0, 1.0, n=16)
tensor(1.5708)
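A forward-only sketch of the bisection iteration, assuming f acts elementwise and the bounds satisfy f(a) <= y <= f(b) as the parameters require (a may be larger than b, as in the cos example). The name bisection_sketch is hypothetical, and the sketch omits phi and the implicit-differentiation backward pass.

```python
import torch


def bisection_sketch(f, y, a, b, n=16):
    """Hypothetical forward-only bisection; assumes f(a) <= y <= f(b) elementwise."""
    a = torch.as_tensor(a, dtype=y.dtype)
    b = torch.as_tensor(b, dtype=y.dtype)
    for _ in range(n):
        c = (a + b) / 2
        # Keep the half-interval whose endpoints still bracket y
        below = f(c) < y
        a = torch.where(below, c, a)
        b = torch.where(below, b, c)
    return (a + b) / 2
```

Each iteration halves the bracket, so after n iterations the midpoint is within \(|b - a| / 2^{n + 1}\) of a root; with the bounds 2.0 and 1.0 and n=16 that is about 7.6e-6, consistent with the tensor(1.5708) shown above.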
- zuko.utils.broadcast(*tensors, ignore=0)
Broadcasts tensors together.
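The term broadcasting describes how PyTorch treats tensors with different shapes during arithmetic operations. In short, if possible, dimensions of size 1 and missing leading dimensions are expanded (without making copies) so that the shapes become compatible. The last ignore dimensions of each tensor are left out of broadcasting: in the example below, ignore=1 keeps the trailing sizes 2 and 5, while the leading shapes (3, 1) and (4,) are broadcast to (3, 4).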
- Parameters:
- Returns:
The broadcasted tensors.
- Return type:
Example
>>> x = torch.rand(3, 1, 2)
>>> y = torch.rand(4, 5)
>>> x, y = broadcast(x, y, ignore=1)
>>> x.shape
torch.Size([3, 4, 2])
>>> y.shape
torch.Size([3, 4, 5])
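A minimal sketch of broadcasting with ignored trailing dimensions, built on torch.broadcast_shapes and Tensor.expand. It assumes ignore is a single int applied to every tensor; zuko may accept more general forms, and broadcast_sketch is a hypothetical name.

```python
import torch


def broadcast_sketch(*tensors, ignore=0):
    """Hypothetical helper: broadcast all but the last `ignore` dims of each tensor."""
    # Split each shape into a broadcastable leading part and an ignored trailing part
    batches = [t.shape[: t.dim() - ignore] for t in tensors]
    cores = [t.shape[t.dim() - ignore :] for t in tensors]
    # Common leading shape under the standard PyTorch broadcasting rules
    batch = torch.broadcast_shapes(*batches)
    # expand returns views, so no data is copied
    return [t.expand(*batch, *core) for t, core in zip(tensors, cores)]
```

Because expand only creates views, the outputs share storage with the inputs; call .contiguous() or .clone() before writing into them in place.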
- zuko.utils.gauss_legendre(f, a, b, n=3, phi=())
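Estimates the definite integral of a function \(f_\phi(x)\) from \(a\) to \(b\) using an \(n\)-point Gauss-Legendre quadrature.
\[\int_a^b f_\phi(x) ~ dx \approx (b - a) \sum_{i = 1}^n w_i f_\phi(x_i)\]
where the nodes \(x_i \in [a, b]\) and the weights \(w_i\) are those of the Gauss-Legendre rule mapped to \([a, b]\), with the weights normalized so that \(\sum_{i = 1}^n w_i = 1\).
Wikipedia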
https://wikipedia.org/wiki/Gauss-Legendre_quadrature
- Parameters:
- Returns:
The definite integral estimation.
- Return type:
Tensor
Example
>>> f = lambda x: torch.exp(-x**2)
>>> a, b = torch.tensor([-0.69, 4.2])
>>> gauss_legendre(f, a, b, n=16)
tensor(1.4807)
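A sketch of the quadrature rule, assuming scalar tensor bounds and using NumPy's numpy.polynomial.legendre.leggauss for the standard nodes and weights on [-1, 1]. Mapping them to [0, 1] gives weights that sum to 1, which is what the (b - a) prefactor in the formula requires. The name gauss_legendre_sketch is hypothetical, and phi and gradient handling are omitted.

```python
import numpy as np
import torch


def gauss_legendre_sketch(f, a, b, n=3):
    """Hypothetical n-point Gauss-Legendre estimate of the integral of f over [a, b]."""
    # Nodes and weights of the standard rule on [-1, 1]
    nodes, weights = np.polynomial.legendre.leggauss(n)
    # Map them to [0, 1]; the weights then sum to 1
    nodes = torch.as_tensor((nodes + 1) / 2, dtype=a.dtype)
    weights = torch.as_tensor(weights / 2, dtype=a.dtype)
    # x_i = a + (b - a) * node_i lies in [a, b]
    x = a + (b - a) * nodes
    return (b - a) * torch.sum(weights * f(x))
```

An \(n\)-point rule integrates polynomials of degree up to \(2n - 1\) exactly, so n=2 already suffices for a cubic, while a smooth but non-polynomial integrand such as \(e^{-x^2}\) benefits from larger n.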
- zuko.utils.odeint(f, x, t0, t1, phi=(), atol=1e-06, rtol=1e-05)
Integrates a system of first-order ordinary differential equations (ODEs)
\[\frac{dx}{dt} = f_\phi(t, x) ,\]
from \(t_0\) to \(t_1\) using the adaptive Dormand-Prince method. The output is the final state
\[x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(t, x(t)) ~ dt .\]
Gradients are propagated through \(x_0\), \(t_0\), \(t_1\) and \(\phi\) via the adaptive checkpoint adjoint (ACA) method.
References
Neural Ordinary Differential Equations (Chen et al., 2018)
Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE (Zhuang et al., 2020)
- Parameters:
- Returns:
The final state \(x(t_1)\).
- Return type:
Example
>>> A = torch.randn(3, 3)
>>> f = lambda t, x: x @ A
>>> x0 = torch.randn(3)
>>> x1 = odeint(f, x0, 0.0, 1.0)
>>> x1
tensor([-1.4596,  0.5008,  1.5828])
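A forward-only sketch of adaptive Dormand-Prince integration, assuming Python-float times with t0 < t1 and a single tensor state. It uses the standard Dormand-Prince 5(4) tableau with a common step-size controller (safety factor 0.9, step changes clamped to [0.2, 5]); these are illustrative choices, not necessarily zuko's settings, and the adjoint (ACA) backward pass is omitted entirely. The names dopri5_sketch and DP_* are hypothetical.

```python
import torch

# Dormand-Prince 5(4) Butcher tableau: nodes, stage coefficients and weights
DP_C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
DP_A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
DP_B5 = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.0]  # 5th order
DP_B4 = [5179 / 57600, 0.0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40]  # 4th order


def dopri5_sketch(f, x, t0, t1, atol=1e-6, rtol=1e-5):
    """Hypothetical forward-only solver for dx/dt = f(t, x) from t0 to t1 (t0 < t1)."""
    t, h = t0, (t1 - t0) / 10  # arbitrary initial step size
    while t < t1:
        last = t + h >= t1
        if last:
            h = t1 - t  # shrink the final step so it lands exactly on t1
        # Seven stages: k_i = f(t + c_i h, x + h * sum_j a_ij k_j)
        k = []
        for c, row in zip(DP_C, DP_A):
            dx = sum((a * kj for a, kj in zip(row, k)), torch.zeros_like(x))
            k.append(f(t + c * h, x + h * dx))
        x5 = x + h * sum(b * ki for b, ki in zip(DP_B5, k))
        x4 = x + h * sum(b * ki for b, ki in zip(DP_B4, k))
        # Embedded error estimate, scaled by the mixed tolerance atol + rtol * |x|
        scale = atol + rtol * torch.maximum(x.abs(), x5.abs())
        err = ((x5 - x4).abs() / scale).max().item()
        if err <= 1.0:  # accept the step and keep the 5th-order solution
            t, x = (t1 if last else t + h), x5
        # Grow or shrink the step based on the error (exponent -1/5 for order 4)
        h *= min(5.0, max(0.2, 0.9 * (err + 1e-12) ** -0.2))
    return x
```

With f = lambda t, x: x @ A as in the example, the exact solution is x0 @ torch.linalg.matrix_exp(A), which makes a convenient correctness check for any integrator of this kind.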