Continuous-Time Meta-Learning with Forward Mode Differentiation

ICLR 2022 (Spotlight)
This repository contains the official JAX implementation of COMLN (Deleu et al., 2022), a gradient-based meta-learning algorithm in which adaptation follows a gradient flow. It includes an implementation of the memory-efficient algorithm for computing the meta-gradients, based on forward-mode differentiation. The implementation builds on jax-meta.
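To illustrate why forward-mode differentiation suits adaptation by gradient flow, here is a toy sketch. It assumes a linear least-squares task loss, an explicit Euler discretization of the flow, and the integration horizon as the only meta-parameter. All names (`inner_loss`, `gradient_flow`, `outer_loss`, `loss_and_horizon_derivative`) are hypothetical. This is not the COMLN meta-gradient algorithm. It only shows `jax.jvp` carrying a tangent through the adaptation loop alongside the parameters, so the trajectory is never stored.

```python
import jax
import jax.numpy as jnp


def inner_loss(w, x, y):
    # Hypothetical task loss: mean squared error of a linear model x @ w.
    return 0.5 * jnp.mean((x @ w - y) ** 2)


def gradient_flow(w0, x, y, horizon, num_steps=100):
    # Explicit Euler discretization of dw/dt = -grad L(w) over [0, horizon].
    dt = horizon / num_steps
    grad_fn = jax.grad(inner_loss)

    def step(_, w):
        return w - dt * grad_fn(w, x, y)

    return jax.lax.fori_loop(0, num_steps, step, w0)


def outer_loss(w0, horizon, x_s, y_s, x_q, y_q):
    # Adapt on the support set, then evaluate on the query set.
    w_T = gradient_flow(w0, x_s, y_s, horizon)
    return inner_loss(w_T, x_q, y_q)


def loss_and_horizon_derivative(w0, horizon, data):
    # Forward mode: push a unit tangent on `horizon` through the whole flow.
    # The tangent is updated together with each Euler step, so the
    # intermediate iterates do not need to be kept for a backward pass.
    f = lambda h: outer_loss(w0, h, *data)
    return jax.jvp(f, (horizon,), (jnp.ones_like(horizon),))


# Hypothetical random support and query sets for a single task.
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
x_s, y_s = jax.random.normal(k1, (20, 5)), jax.random.normal(k2, (20,))
x_q, y_q = jax.random.normal(k3, (20, 5)), jax.random.normal(k4, (20,))
w0 = jnp.zeros(5)
horizon = jnp.asarray(1.0)

value, dvalue_dhorizon = loss_and_horizon_derivative(
    w0, horizon, (x_s, y_s, x_q, y_q))
print(value, dvalue_dhorizon)
```

Reverse mode (`jax.grad`) returns the same derivative here. Without checkpointing, however, it must keep the intermediate iterates for the backward pass. The forward-mode memory cost stays constant in the number of steps. The trade-off is that naive forward mode needs one tangent per input direction.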
Installation
To avoid conflicts with your existing Python setup, we suggest working in a virtual environment:
python -m venv venv
source venv/bin/activate
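After activating the environment, a quick sanity check is to confirm that Python is actually running from it. The snippet below is a minimal sketch; the helper `in_virtualenv` is hypothetical and not part of this repository. It relies on `python -m venv` setting `sys.prefix` to the environment directory while `sys.base_prefix` keeps pointing to the base installation.

```python
import sys


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix is the environment directory, while
    # sys.base_prefix still points to the base Python installation.
    return sys.prefix != sys.base_prefix


print("virtual environment active:", in_virtualenv())
```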
Follow these instructions to install the version of JAX corresponding to your versions of CUDA and cuDNN. Note that if you want to test