## Ponder(ing) Transformer
Implementation of a Transformer that learns to adapt the number of computational steps it takes based on the difficulty of the input sequence, using the scheme from the PonderNet paper. The repository will also attempt to abstract out a pondering module that can wrap any block that returns its output along with a halting probability.
This repository would not have been possible without repeated viewings of Yannic Kilcher's educational video.
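For intuition, here is a minimal sketch of the PonderNet training objective described above, independent of this repository's internals. The `step` callable and the `beta` and `prior_lambda` names are hypothetical stand-ins: `step` plays the role of a block that returns its output along with a halting probability lambda_n, the halting distribution p_n = lambda_n * prod_{m<n}(1 - lambda_m) weights a per-step task loss, and a KL term toward a geometric prior regularizes how long the model ponders.

```python
import torch
import torch.nn.functional as F

def pondernet_loss(step, state, labels, max_steps = 20, beta = 0.01, prior_lambda = 0.2):
    # step: hypothetical callable mapping state -> (state, logits, halt_prob),
    # where halt_prob is a scalar tensor lambda_n in (0, 1)
    p_continue = torch.tensor(1.)   # probability of not having halted before step n
    rec_loss = torch.tensor(0.)     # expected task loss under the halting distribution
    halt_probs = []                 # p_n = lambda_n * prod_{m<n} (1 - lambda_m)

    for n in range(max_steps):
        state, logits, lam = step(state)
        # at the last step, assign all remaining probability mass to halting
        p_n = p_continue if n == max_steps - 1 else p_continue * lam
        halt_probs.append(p_n)
        rec_loss = rec_loss + p_n * F.cross_entropy(logits, labels)
        p_continue = p_continue * (1 - lam)

    # KL between the halting distribution and a truncated geometric prior,
    # which discourages the network from always pondering for max_steps
    p = torch.stack(halt_probs)
    prior = prior_lambda * (1 - prior_lambda) ** torch.arange(max_steps)
    prior = prior / prior.sum()
    kl = (p * (p / prior).log()).sum()

    return rec_loss + beta * kl
```

In this repository the pondering block is a stack of transformer layers; the sketch above only illustrates the weighting and regularization scheme.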
## Install

```bash
$ pip install ponder-transformer
```
## Usage

```python
import torch
from ponder_transformer import PonderTransformer

model = PonderTransformer(
    num_tokens = 20000,   # vocabulary size
    dim = 512,            # model dimension
    max_seq_len = 512     # maximum sequence length
)

mask = torch.ones(1, 512).bool()        # attention mask (True = attend)

x = torch.randint(0, 20000, (1, 512))   # input token ids
y = torch.randint(0, 20000, (1, 512))   # target token ids

# passing labels returns the PonderNet training loss directly
loss = model(x, labels = y, mask = mask)
loss.backward()
```
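At inference time, one would presumably call the model without `labels` to obtain predictions; this is an assumption extrapolated from the training call above rather than a documented API, so verify it against the repository's forward signature.

```python
# assumption: with no `labels` argument, the forward pass returns logits
# of shape (1, 512, 20000) — confirm against the repository before relying on this
logits = model(x, mask = mask)
preds = logits.argmax(dim = -1)
```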