Implementation of Graph Transformer in Pytorch

Graph Transformer – Pytorch
Implementation of Graph Transformer in Pytorch, for potential use in replicating AlphaFold2. This was recently used by both Costa et al. and Baker's lab for transforming MSA and pairwise embeddings into 3D coordinates.
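The description above does not spell out how the pairwise (edge) embeddings enter attention. Below is a minimal, dependency-free sketch assuming one common formulation from the graph transformer literature: edge features are added to the keys and values before attention, with a single head, identity projections (no learned weights) and edge features of the same size as node features. The function name `edge_aware_attention` is hypothetical and is not part of graph-transformer-pytorch; the library's actual layers may differ.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def softmax(scores: List[float]) -> List[float]:
    # subtract the max for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def edge_aware_attention(nodes: List[Vector], edges: List[List[Vector]]) -> List[Vector]:
    """Hypothetical single-head attention with edge features added to keys and values.

    nodes: n node vectors, each of size d
    edges: n x n grid; edges[i][j] is the pairwise feature for (i, j), also size d
    Assumes identity projections (no learned weights) and edge_dim == dim.
    """
    n, d = len(nodes), len(nodes[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, q in enumerate(nodes):
        # the key (and value) for neighbor j depends on the pair (i, j)
        keys = [add(nodes[j], edges[i][j]) for j in range(n)]
        values = keys  # identity value projection in this sketch
        weights = softmax([dot(q, k) * scale for k in keys])
        # weighted sum of the edge-conditioned values
        out.append([sum(w * v[c] for w, v in zip(weights, values)) for c in range(d)])
    return out
```

Because each key depends on the pair (i, j), the same node j can be weighted differently by different query nodes i, which is how pairwise information reaches the updated node embeddings.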
Todo
- add rotary embeddings for injecting adjacency information
Install
$ pip install graph-transformer-pytorch
Usage
import torch
from graph_transformer_pytorch import GraphTransformer
model = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,             # optional - if left out, the edge dimension is assumed to be the same as the node dimension (dim) above
    with_feedforwards = True,   # whether to add a feedforward after each attention layer, suggested by the literature to be necessary
    gated_residual = True       # whether to use a gated residual to prevent over-smoothing
)
nodes =