CLOOB training (JAX) and inference (JAX and PyTorch)
Pretrained models
PyTorch
```python
from cloob_training import model_pt, pretrained

pretrained.list_configs()
```

This returns:

```
['cloob_laion_400m_vit_b_16_16_epochs', 'cloob_laion_400m_vit_b_16_32_epochs']
```
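To load one of these models and move it to the GPU for inference:

```python
config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
model = model_pt.get_pt_model(config)  # build the PyTorch model from its config
checkpoint = pretrained.download_checkpoint(config)  # download the pretrained checkpoint
# Convert the checkpoint into PyTorch parameters and load them into the model
model.load_state_dict(model_pt.get_pt_params(config, checkpoint))
model.eval().requires_grad_(False).to('cuda')  # eval mode, no gradients, on the GPU
```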
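In CLIP-style models such as CLOOB, images and texts are usually compared by the cosine similarity of their embeddings. The following is a minimal, generic NumPy sketch of that comparison step, not this repository's API: the small arrays are hypothetical stand-ins for encoder outputs, and the 4-dimensional embedding size is an arbitrary assumption chosen for readability (a real model's embedding width will differ).

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of x to unit Euclidean length (eps avoids division by zero)."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def cosine_similarity(image_emb: np.ndarray, text_emb: np.ndarray) -> np.ndarray:
    """Return an (n_images, n_texts) matrix of cosine similarities."""
    return l2_normalize(image_emb) @ l2_normalize(text_emb).T


# Hypothetical stand-ins for encoder outputs: 2 images and 3 captions.
image_emb = np.array([[1.0, 0.0, 0.0, 0.0],
                      [0.0, 2.0, 0.0, 0.0]])
text_emb = np.array([[0.0, 0.0, 1.0, 0.0],
                     [3.0, 0.0, 0.0, 0.0],
                     [0.0, 1.0, 1.0, 0.0]])

sims = cosine_similarity(image_emb, text_emb)
best_caption = sims.argmax(axis=1)  # index of the best-matching caption for each image
print(sims.round(3))
print(best_caption)  # [1 2]
```

Normalizing first makes the dot product depend only on direction, so an embedding's magnitude (for example the second image's length of 2) does not affect which caption ranks highest.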
Model class attributes:

- `model.config`: the model config dict.
- `model.image_encoder`: the image encoder, which expects NCHW batches of normalized images (preprocessed by `model.normalize`), where C