PyTorch Implementation of Differentiable ODE Solvers
This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1].
Because the solvers are implemented in PyTorch, all algorithms in this repository can run on the GPU.
Installation
To install the latest stable version:
pip install torchdiffeq
To install the latest version from GitHub:
pip install git+https://github.com/rtqichen/torchdiffeq
Examples
Examples are placed in the examples directory.
We encourage anyone interested in using this library to look at examples/ode_demo.py, which shows how to use torchdiffeq to fit a simple spiral ODE.
Basic usage
This library