# TorchShard

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when a model has massive linear layers (e.g., ViT, BERT, and GPT) or a huge number of output classes (millions). Its API mirrors PyTorch's.
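To make the slicing concrete, here is a minimal plain-PyTorch sketch (not TorchShard code) of the two ways a linear layer's weight can be split across two shards: along output features, where each shard produces a slice of the output, or along input features, where each shard produces a partial sum. The row- and column-parallel modes of `ParallelLinear` shown in the usage below correspond to these two schemes; the shapes here are illustrative.

```python
import torch

# Sketch: shard a 30x20 linear weight across two workers.
# Shapes follow nn.Linear convention: (out_features, in_features).
weight = torch.randn(30, 20)
x = torch.randn(8, 20)

# Scheme 1: split along output features. Each shard holds a slice of the
# rows of the weight and produces a slice of the output; concatenating
# the slices recovers the full result.
w0, w1 = weight.chunk(2, dim=0)               # two (15, 20) shards
out_split = torch.cat([x @ w0.t(), x @ w1.t()], dim=1)

# Scheme 2: split along input features. Each shard sees part of the
# input and produces a partial sum; adding the partials recovers the
# full result, since matmul decomposes over the contraction dimension.
wa, wb = weight.chunk(2, dim=1)               # two (30, 10) shards
xa, xb = x.chunk(2, dim=1)
in_split = xa @ wa.t() + xb @ wb.t()

reference = x @ weight.t()
assert torch.allclose(out_split, reference, atol=1e-6)
assert torch.allclose(in_split, reference, atol=1e-5)
```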
## Installation

```bash
pip install torchshard
```

More options in INSTALL.md.
## Usage

```python
import torch
import torchshard as ts

ts.init_process_group(group_size=2)                     # init parallel groups

m = torch.nn.Sequential(
    torch.nn.Linear(20, 30, bias=True),
    ts.nn.ParallelLinear(30, 30, bias=True, dim=None),  # equal to nn.Linear()
    ts.nn.ParallelLinear(30, 30, bias=True, dim=0),     # parallel in row dimension
    ts.nn.ParallelLinear(30, 30, bias=True, dim=1),     # parallel in column dimension
).cuda()

x = torch.randn(8, 20).cuda()                           # example input batch (assumed shape)
y = torch.randint(0, 30, (8,)).cuda()                   # example target labels (assumed)

x = m(x)                                                # forward
loss = ts.nn.functional.parallel_cross_entropy(x, y)    # parallel loss function
loss.backward()                                         # backward

torch.save(
    ts.collect_state_dict(m, m.state_dict()), 'm.pt')   # save model state
```
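Two notes on running this. First, the snippet assumes it executes inside a two-process distributed job, e.g. launched with `torchrun --nproc_per_node=2 train.py` (the launcher command is an example, not prescribed by TorchShard). Second, `collect_state_dict` gathers the shards before saving, so the file on disk is an ordinary full state dict; a minimal sketch of loading it back, assuming `ts.relocate_state_dict` (the counterpart of `ts.collect_state_dict`) is available:

```python
# Sketch, assuming ts.relocate_state_dict exists: it re-slices a full
# state dict into the shards each rank owns before load_state_dict.
state_dict = torch.load('m.pt')
m.load_state_dict(ts.relocate_state_dict(m, state_dict))   # load model state
```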