Extras for diffrax.
DiffEqSolver
: an object-oriented interface todiffrax.diffeqsolve
.VectorizedDenseInterpolation
: a vectorized form ofdiffrax.DenseInterpolation
that works on batched results fromdiffrax.diffeqsolve
.
For example,
import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver
# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))
# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)
# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
vectorize_interpolation=True)
print(soln)
# Solution(
# t0=f32[], t1=f32[], ts=f32[1],
# ys=f32[1],
# interpolation=VectorizedDenseInterpolation(
# scalar_interpolation=DenseInterpolation( ... ),
# batch_shape=(),
# y0_shape=()
# ),
# ...
# )
soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
# [0.74080871, 0.67031456]], dtype=float64)
pip install diffraxtra
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver
Construct a solver object.
>>> solver = DiffEqSolver(dfx.Dopri5(),
... stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))
And a differential equation to solve.
>>> term = dfx.ODETerm(lambda t, y, args: -y)
Then solve the differential equation.
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
ys=f64[1], ... )
The solution can be saved at specific times.
>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
ys=f64[4], ... )
The solution can be densely interpolated.
>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)
Using the VectorizedDenseInterpolation
class, the interpolation can be
vectorized, enabling evaluation of batched solutions over batches of times.
>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)
This can be more conveniently done using the vectorize_interpolation
argument.
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)
There are many ways to construct a DiffEqSolver
object. For example, we can
can make a new one from an existing DiffEqSolver
object
>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True
From a diffrax.AbstractSolver
object.
>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(solver=Dopri5())
(Where all other arguments are their default values and printed only if changed.)
From a collections.abc.Mapping
>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
... "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
solver=Dopri5(), stepsize_controller=PIDController(rtol=1e-05, atol=1e-05)
)
For a full enumeration of the ways to construct a DiffEqSolver
object, see
diffraxtra.DiffEqSolver.from_
.
Vectorized wrapper around a diffrax.DenseInterpolation
This also works on non-batched interpolations.
>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
We'll start with a non-batched interpolation:
>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)
>>> sol = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp # doctest: +SKIP
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation(
ts=f64[1,4097],
ts_size=weak_i64[1],
infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
direction=weak_i64[1],
t0_if_trivial=f64[1],
y0_if_trivial=f64[1]
),
batch_shape=(),
y0_shape=()
)
This can be evaluated by the normal means:
>>> interp.evaluate(ts[-1]) # scalar evaluation
Array(0.04978961, dtype=float64)
It also works on arrays, without needed to manually apply jax.vmap
:
>>> interp.evaluate(ts) # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0]) # t1 - t0 mixed scalar and array
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
Better yet, the time array may be arbitrarily shaped:
>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1. , 0.368],
[0.135, 0.05 ]], dtype=float64)
As a convenience, we can also apply the VectorizedDenseInterpolation
to the
solution to modify the interpolation "in-place" (when in a jitted context,
otherwise out-of-place, returning a copy):
>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True
Now we'll batch the interpolation:
>>> @jax.vmap
... def solve(y0):
... sol = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
... stepsize_controller=stepsize_controller)
... return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3) # scalar eval of batched interp
Array([0.05 , 0.1 , 0.149], dtype=float64)
>>> interp.evaluate(ts).astype(jnp.float64).round(3) # array eval of batched interp
Array([[1. , 0.368, 0.135, 0.05 ],
[2. , 0.736, 0.271, 0.1 ],
[3. , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3) # mixed scalar and array eval
Array([[0. , 0.632, 0.865, 0.95 ],
[0. , 1.264, 1.729, 1.9 ],
[0. , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3) # arbitrary shape eval
>>> ys
Array([[[1. , 0.368],
[0.135, 0.05 ]],
[[2. , 0.736],
[0.271, 0.1 ]],
[[3. , 1.104],
[0.406, 0.149]]], dtype=float64)
>>> ys.shape # (batch, *times)
(3, 2, 2)
If you enjoyed using this library and would like to cite the software you use then click the link above.
We welcome contributions!