Clockwork Variational Autoencoders using JAX and Flax
Clockwork VAEs in JAX/Flax
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported from the official TensorFlow implementation.
On a single TPU v3, training is 10x faster than reported in the paper (60h -> 6h on minerl).
Method
Clockwork VAEs are deep generative models that learn long-term dependencies in video by leveraging hierarchies of representations that progress at different clock speeds. In contrast to prior video prediction methods, which typically focus on predicting sharp but short sequences into the future, Clockwork VAEs can accurately predict high-level content, such as object positions and identities, for 1000 frames.
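To make "different clock speeds" concrete, here is a minimal sketch of one possible update schedule. It assumes that level l of the hierarchy (0 being the fastest level, closest to the frames) updates once every factor**l time steps. The names `active_levels`, `tick_schedule`, `num_levels`, and `factor`, as well as the default values, are hypothetical illustrations and are not taken from this repository or its configs.

```python
def active_levels(t, num_levels=3, factor=4):
    """Return the hierarchy levels that update at time step t.

    Assumption (illustrative only): level l ticks every factor**l steps,
    so level 0 updates at every frame and higher levels update less often.
    """
    return [l for l in range(num_levels) if t % (factor ** l) == 0]


def tick_schedule(num_steps, num_levels=3, factor=4):
    """List the active levels for each of the first num_steps time steps."""
    return [active_levels(t, num_levels, factor) for t in range(num_steps)]


if __name__ == "__main__":
    # Print which levels update at each of the first 17 frames.
    for t, levels in enumerate(tick_schedule(17)):
        print(f"t={t:2d} active levels: {levels}")
```

Under these assumptions, the slowest level changes its state only once per 16 frames, which is one way a hierarchy can hold slowly varying content while the lowest level tracks frame-to-frame detail.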
Clockwork VAEs build upon the Recurrent State Space Model (RSSM), so each state contains a deterministic component