GalacticDynamics/diffraxtra

diffraxtra

diffrax extras

Extras for diffrax.

  • DiffEqSolver: an object-oriented interface to diffrax.diffeqsolve.
  • VectorizedDenseInterpolation: a vectorized form of diffrax.DenseInterpolation that works on batched results from diffrax.diffeqsolve.

For example,

import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#         [0.74080871, 0.67031456]], dtype=float64)

Installation

pip install diffraxtra

Documentation

DiffEqSolver

>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver

Construct a solver object.

>>> solver = DiffEqSolver(dfx.Dopri5(),
...                stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

And a differential equation to solve.

>>> term = dfx.ODETerm(lambda t, y, args: -y)

Then solve the differential equation.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )

The solution can be saved at specific times.

>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
          ys=f64[4], ... )

The solution can be densely interpolated.

>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)

Using the VectorizedDenseInterpolation class, the interpolation can be vectorized, enabling evaluation of batched solutions over batches of times.

>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

This can be more conveniently done using the vectorize_interpolation argument.

>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...               vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)
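As a quick sanity check, the printed values can be compared with the analytic solution y(t) = exp(-t) of dy/dt = -y with y(0) = 1. The sketch below uses only the standard library. The `printed` values are copied from the output above, and any error bound you compare against (for example something just above rtol = atol = 1e-5) is an assumption, not a guarantee made by diffrax.

```python
import math

# Times on the same 2x2 grid used in the example above.
times = [[0.1, 0.2], [0.3, 0.4]]

# Values copied from the printed output of soln.evaluate(...).
printed = [[0.90483742, 0.81872516], [0.74080871, 0.67031456]]

# Analytic solution of dy/dt = -y with y(0) = 1.
exact = [[math.exp(-t) for t in row] for row in times]

# Largest absolute deviation between the printed and analytic values.
max_err = max(
    abs(p - e)
    for p_row, e_row in zip(printed, exact)
    for p, e in zip(p_row, e_row)
)
print(f"max abs error: {max_err:.1e}")
```

The deviation is on the order of the tolerances passed to the PIDController, which is what one would expect from an adaptive solver.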

There are many ways to construct a DiffEqSolver object. For example, an existing DiffEqSolver object can be passed to from_, which returns it unchanged:

>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True

From a diffrax.AbstractSolver object:

>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(solver=Dopri5())

(All other arguments take their default values and are printed only if changed.)

From a collections.abc.Mapping:

>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
...       "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
  solver=Dopri5(), stepsize_controller=PIDController(rtol=1e-05, atol=1e-05)
)

For a full enumeration of the ways to construct a DiffEqSolver object, see diffraxtra.DiffEqSolver.from_.

VectorizedDenseInterpolation

A vectorized wrapper around a diffrax.DenseInterpolation.

This also works on non-batched interpolations.

>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation

We'll start with a non-batched interpolation:

>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> sol = dfx.diffeqsolve(
...     term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...     stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp  # doctest: +SKIP
VectorizedDenseInterpolation(
  scalar_interpolation=DenseInterpolation(
    ts=f64[1,4097],
    ts_size=weak_i64[1],
    infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
    interpolation_cls=<class 'diffrax._solver.dopri5._Dopri5Interpolation'>,
    direction=weak_i64[1],
    t0_if_trivial=f64[1],
    y0_if_trivial=f64[1]
  ),
  batch_shape=(),
  y0_shape=()
)

This can be evaluated by the normal means:

>>> interp.evaluate(ts[-1])  # scalar evaluation
Array(0.04978961, dtype=float64)

It also works on arrays, without needing to manually apply jax.vmap:

>>> interp.evaluate(ts)  # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
>>> interp.evaluate(ts, ts[0])  # increment y(t1) - y(t0): array t0, scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
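The two-argument call follows diffrax's `evaluate(t0, t1)` convention, which returns the increment y(t1) - y(t0). For this problem y(t) = exp(-t), so with t0 = ts and t1 = ts[0] = 0 the result should be close to 1 - exp(-t). A standard-library sketch of that arithmetic follows; the helper `y` is hypothetical and only stands in for the interpolation.

```python
import math


def y(t):
    """Analytic solution of dy/dt = -y with y(0) = 1 (stand-in for the interpolant)."""
    return math.exp(-t)


ts = [0.0, 1.0, 2.0, 3.0]

# evaluate(t0, t1) returns y(t1) - y(t0); here t0 runs over ts and t1 is ts[0].
increments = [y(ts[0]) - y(t0) for t0 in ts]
print([round(v, 5) for v in increments])
```

The printed list agrees with the solver output above to within the requested tolerances.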

Better yet, the time array may be arbitrarily shaped:

>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1.   , 0.368],
       [0.135, 0.05 ]], dtype=float64)

As a convenience, VectorizedDenseInterpolation.apply_to_solution wraps the interpolation of a solution directly. Inside a jitted context the interpolation is modified "in-place"; otherwise the operation is out-of-place and returns a copy:

>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True

Now we'll batch the interpolation:

>>> @jax.vmap
... def solve(y0):
...     sol = dfx.diffeqsolve(
...         term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
...         stepsize_controller=stepsize_controller)
...     return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp.evaluate(ts[-1]).round(3)  # scalar eval of batched interp
Array([0.05 , 0.1  , 0.149], dtype=float64)
>>> interp.evaluate(ts).astype(jnp.float64).round(3)  # array eval of batched interp
Array([[1.   , 0.368, 0.135, 0.05 ],
       [2.   , 0.736, 0.271, 0.1  ],
       [3.   , 1.104, 0.406, 0.149]], dtype=float64)
>>> interp.evaluate(ts, ts[0]).round(3)  # mixed scalar and array eval
Array([[0.   , 0.632, 0.865, 0.95 ],
       [0.   , 1.264, 1.729, 1.9  ],
       [0.   , 1.896, 2.594, 2.851]], dtype=float64)
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3)  # arbitrary shape eval
>>> ys
Array([[[1.   , 0.368],
        [0.135, 0.05 ]],
       [[2.   , 0.736],
        [0.271, 0.1  ]],
       [[3.   , 1.104],
        [0.406, 0.149]]], dtype=float64)
>>> ys.shape  # (batch, *times)
(3, 2, 2)
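The shape combines the batch axis with the shape of the time array. As a sketch of what the numbers mean, the same table can be rebuilt from the analytic solution y(t) = y0 * exp(-t) using plain nested lists. The names `y0s`, `t_grid`, and `expected` are illustrative and not part of diffraxtra.

```python
import math

y0s = [1.0, 2.0, 3.0]              # batch axis: one initial condition per solve
t_grid = [[0.0, 1.0], [2.0, 3.0]]  # times reshaped to (2, 2)

# Outer axis is the batch; inner axes mirror the shape of the time array.
expected = [
    [[round(y0 * math.exp(-t), 3) for t in row] for row in t_grid]
    for y0 in y0s
]

shape = (len(expected), len(expected[0]), len(expected[0][0]))
print(shape)  # (3, 2, 2)
```

Each entry of `expected` matches the corresponding rounded value in the batched output above.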

Citation

DOI

If you found this library useful and would like to cite it, click the DOI link above.

Development

We welcome contributions!