A Mixed Precision Library for JAX in Python
Mixed precision training in JAX
Mixed precision training [0] is a technique that mixes full- and half-precision
floating-point numbers during training to reduce memory bandwidth requirements
and improve the computational efficiency of a given model.
This library implements support for mixed precision training in JAX by providing
two key abstractions: mixed precision “policies” and loss scaling. Neural
network libraries (such as Haiku) can integrate with jmp and provide
“Automatic Mixed Precision (AMP)” support, automating or simplifying the
application of policies to modules. A short sketch of both abstractions is
given after the setup code below.
All code examples below assume the following:
```python
import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32
```
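For example, under those assumptions, the two abstractions could be combined roughly as follows. This is an illustrative sketch rather than the library's canonical training loop; the exact constructor arguments (for instance the initial value passed to `jmp.DynamicLossScale`) and the handling of non-finite gradients are simplified here and should be checked against the jmp documentation.

```python
# A policy records which dtype to use for parameters, compute and outputs.
policy = jmp.Policy(param_dtype=full, compute_dtype=half, output_dtype=full)

params = jnp.ones([4, 4], dtype=policy.param_dtype)  # "Master" weights kept in float32.
x = jnp.ones([2, 4], dtype=half)

def forward(params, x):
  # Cast parameters and inputs to the compute dtype before the heavy math...
  params, x = policy.cast_to_compute((params, x))
  y = x @ params
  # ...and cast the result back to the output dtype.
  return policy.cast_to_output(y)

# A dynamic loss scale keeps small half-precision gradients from underflowing;
# 2 ** 15 is a common initial value (an assumption, not a required default).
loss_scale = jmp.DynamicLossScale(jnp.float32(2 ** 15))

def loss_fn(params, x):
  # Scale the loss before differentiation so the gradients are scaled too.
  return loss_scale.scale(jnp.mean(forward(params, x)))

grads = jax.grad(loss_fn)(params, x)
grads = loss_scale.unscale(grads)  # Undo the scaling on the gradients.
# Grow or shrink the scale depending on whether the gradients were finite.
loss_scale = loss_scale.adjust(jmp.all_finite(grads))
```

Keeping the parameters in full precision while running the forward pass in half precision is what preserves model quality, and adjusting the loss scale over time keeps half-precision gradients from underflowing or overflowing.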
Installation
JMP is written in pure Python, but depends on C++ code via JAX and NumPy.
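Because JAX selects its accelerator support (CPU, GPU or TPU) at install time, JAX should be installed first by following its own installation instructions. Assuming the package is published on PyPI under the name `jmp` (and that the source lives at the deepmind/jmp GitHub repository), installation might then look like:

```shell
# Install JAX first, following the official instructions for your accelerator,
# then install jmp itself (package name assumed to be "jmp" on PyPI):
pip install jmp

# Or, to track the latest development version from GitHub (URL assumed):
pip install git+https://github.com/deepmind/jmp
```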