Attention for PyTorch with Linear Memory Footprint

Unofficial PyTorch implementation of https://arxiv.org/abs/2112.05682, which brings the memory cost of attention down to linear (plus some GPU speedup compared to the paper's reference JAX implementation).
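How it works, in brief: the paper processes keys and values in fixed-size chunks with a running, numerically stable softmax, so the full query-by-key score matrix is never materialized and peak memory scales with the chunk size. Below is a minimal single-head sketch of that algorithm, written for illustration only; it is not this repository's actual code, and `chunk_size` is an assumed knob:

```python
import torch

def chunked_attention(q, k, v, chunk_size=1024):
    """Memory-efficient attention via a running (online) softmax.

    q: (n_q, d), k/v: (n_k, d); single head, no batch or mask, for clarity.
    Peak extra memory scales with chunk_size, not with n_q * n_k.
    """
    q = q * q.shape[-1] ** -0.5
    num = torch.zeros_like(q)                       # running weighted sum of values
    den = q.new_zeros(q.shape[0], 1)                # running softmax denominator
    m = q.new_full((q.shape[0], 1), -float("inf"))  # running max, for numerical stability
    for i in range(0, k.shape[0], chunk_size):
        k_c, v_c = k[i:i + chunk_size], v[i:i + chunk_size]
        s = q @ k_c.T                                # scores for this chunk only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                     # stabilized exponentials
        scale = torch.exp(m - m_new)                 # rescale earlier partial sums
        num = num * scale + p @ v_c
        den = den * scale + p.sum(dim=-1, keepdim=True)
        m = m_new
    return num / den
```

Up to floating-point error, the result matches standard attention, `torch.softmax(q @ k.T / d ** 0.5, dim=-1) @ v`.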
Install:
```bash
git clone https://github.com/CHARM-Tx/linear_mem_attention-pytorch
cd linear_mem_attention-pytorch
python setup.py install
```
Usage:
High Level
A minimal sketch of the high-level interface: the import below is completed from the package name `linear_mem_attention_torch`, and the `attention` symbol and its signature are assumptions rather than the package's documented API; check the package source for the exact call.
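```python
import torch
# Assumed export: the actual symbol and module path inside
# linear_mem_attention_torch may differ; check the repository.
from linear_mem_attention_torch import attention

q = torch.randn(1, 4096, 64)  # (batch, query length, head dim)
k = torch.randn(1, 4096, 64)  # (batch, key length, head dim)
v = torch.randn(1, 4096, 64)  # (batch, key length, head dim)

# `mask` is an assumed keyword argument.
out = attention(q, k, v, mask=None)  # (1, 4096, 64)
```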