Turning SymPy expressions into JAX functions
Turn SymPy expressions into parametrized, differentiable, vectorizable JAX functions.
Every SymPy float in an expression becomes a trainable input parameter, and each SymPy symbol is read from a column of the input matrix.
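To make that mapping concrete, here is a minimal hand-written sketch of the kind of function such a conversion might produce. It does not call sympy2jax and is not the library's implementation. The expression `2.0 * sin(x) + 0.5 * y`, the function name `f`, and the `(X, params)` argument order are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


# Hypothetical target expression (an assumption for illustration):
#     2.0 * sin(x) + 0.5 * y
# The two floats become entries of `params`; the symbols x and y
# become columns 0 and 1 of the input matrix X.
def f(X, params):
    x = X[:, 0]  # symbol x -> column 0
    y = X[:, 1]  # symbol y -> column 1
    return params[0] * jnp.sin(x) + params[1] * y


params = jnp.array([2.0, 0.5])  # the floats from the expression
X = jnp.ones((5, 2))            # 5 rows, one column per symbol

# Vectorized over rows and compatible with jit: one output per row of X.
out = jax.jit(f)(X, params)


# Differentiable with respect to the parameters.
def loss(p):
    return jnp.sum(f(X, p))


grads = jax.grad(loss)(params)
```

Because the parameters are an ordinary array argument, they can be passed to `jax.grad` and updated by any optimizer, which is what makes the floats trainable.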
Installation
pip install git+https://github.com/MilesCranmer/sympy2jax.git
Example
import sympy
from sympy import symbols
import jax
import jax.numpy as jnp
from jax import random
from sympy2jax import sympy2jax
Let’s create an expression in SymPy:
x, y = symbols('x y')
expression = 1.0 * sympy.cos(x) +