Turning SymPy expressions into JAX functions
Turn SymPy expressions into parametrized, differentiable, vectorizable JAX functions.
All SymPy floats become trainable input parameters, and SymPy symbols become columns of the input matrix you pass to the function.
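To make that mapping concrete, here is a minimal sketch of the idea written directly with SymPy and JAX. It is not sympy2jax's implementation: the helper `to_jax` is hypothetical, and it only handles `Add`, `Mul`, `Pow`, and four unary functions. It walks the expression tree, turns each `sympy.Float` into an entry of a parameter vector, and reads each symbol from a column of the input matrix `X`, so the result works with `jax.jit`, `jax.grad`, and `jax.vmap`.

```python
import functools
import operator

import jax
import jax.numpy as jnp
import sympy

# Assumption: this illustrative sketch supports only these unary functions.
_UNARY = {sympy.cos: jnp.cos, sympy.sin: jnp.sin, sympy.exp: jnp.exp, sympy.log: jnp.log}


def to_jax(expr, symbols):
    """Hypothetical helper: return (f, params) with f(X, params) -> shape [nrows].

    Each sympy.Float in expr becomes one entry of params (in traversal order);
    symbols[i] is read from column i of X, which has shape [nrows, len(symbols)].
    """
    init = []  # initial parameter values, collected while walking the tree

    def build(node):
        if isinstance(node, sympy.Float):  # trainable constant
            i = len(init)
            init.append(float(node))
            return lambda X, p: p[i]
        if isinstance(node, sympy.Number):  # exact ints/rationals stay fixed
            value = float(node)
            return lambda X, p: value
        if isinstance(node, sympy.Symbol):  # symbol -> column of X
            col = symbols.index(node)
            return lambda X, p: X[:, col]
        args = [build(a) for a in node.args]
        if isinstance(node, sympy.Add):
            return lambda X, p: functools.reduce(operator.add, (a(X, p) for a in args))
        if isinstance(node, sympy.Mul):
            return lambda X, p: functools.reduce(operator.mul, (a(X, p) for a in args))
        if isinstance(node, sympy.Pow):
            base, expo = args
            return lambda X, p: base(X, p) ** expo(X, p)
        if node.func in _UNARY:
            fn, (arg,) = _UNARY[node.func], args
            return lambda X, p: fn(arg(X, p))
        raise NotImplementedError(f"unsupported node type: {node.func}")

    f = build(expr)
    return f, jnp.array(init)


x, y = sympy.symbols("x y")
f, params = to_jax(1.5 * sympy.cos(x) + 3.2 * y, [x, y])
X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))  # 10 rows, one column per symbol
y_pred = jax.jit(f)(X, params)  # shape (10,)
grads = jax.grad(lambda p: jnp.sum(f(X, p)))(params)  # one gradient entry per float
batch = jax.vmap(f, in_axes=(None, 0))(X, jnp.stack([params, 2 * params]))  # shape (2, 10)
```

Because the constants are passed in as an explicit `params` argument rather than baked into the function, `jax.grad` can differentiate with respect to them, which is what makes the floats trainable.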
Installation
```shell
pip install git+https://github.com/MilesCranmer/sympy2jax.git
```
Example
```python
import sympy
from sympy import symbols
import jax
import jax.numpy as jnp
from jax import random
from sympy2jax import sympy2jax
```
Let’s create an expression in SymPy:
x, y = symbols('x y')
expression = 1.0 * sympy.cos(x) +