NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch
Paper: https://arxiv.org/abs/2102.06171.pdf
Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Do star this repository if it helps your work!
Note: See this comment for a generic implementation that works with any optimizer; it serves as a temporary reference for anyone who needs it.
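Adaptive Gradient Clipping, as described in the paper, is optimizer-agnostic. It works on "units" (one row of a linear weight, one output channel of a conv kernel) and rescales a unit's gradient G_i whenever ||G_i|| / max(||W_i||, eps) exceeds a clipping threshold λ. The sketch below shows that rule as a plain function called between backward() and step(). The names unitwise_norm and adaptive_grad_clip are hypothetical and this is not the nfnets-pytorch API; the defaults λ = 0.01 and eps = 1e-3 are assumptions chosen for illustration.

```python
import torch
from torch import nn


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: L2 norm per output unit, broadcastable to x."""
    if x.ndim <= 1:
        # Biases and gains: treat the whole vector as one unit.
        return x.pow(2).sum().sqrt()
    # Linear (out, in) or conv (out, in, kh, kw): one norm per output unit (dim 0).
    dims = tuple(range(1, x.ndim))
    return x.pow(2).sum(dim=dims, keepdim=True).sqrt()


@torch.no_grad()
def adaptive_grad_clip(parameters, clip_factor: float = 0.01, eps: float = 1e-3) -> None:
    """Hypothetical sketch of unit-wise AGC; modifies .grad in place."""
    for p in parameters:
        if p.grad is None:
            continue
        # max(||W_i||, eps) keeps zero-initialised weights from blocking their grads.
        max_norm = unitwise_norm(p).clamp_min(eps) * clip_factor
        g_norm = unitwise_norm(p.grad)
        # Scale offending units down so that ||G_i|| == clip_factor * max(||W_i||, eps).
        clipped = p.grad * (max_norm / g_norm.clamp_min(1e-6))
        # Units below the threshold keep their original gradient.
        p.grad.copy_(torch.where(g_norm > max_norm, clipped, p.grad))


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 6 * 6, 10))
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(4, 3, 8, 8)).pow(2).mean()
    loss.backward()
    adaptive_grad_clip(model.parameters(), clip_factor=0.01)  # clip, then step
    opt.step()
```

The paper reports that it is beneficial not to clip the final linear (classifier) layer; with a sketch like this you would simply leave that layer's parameters out of the iterable you pass in.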
Install from PyPI:
pip3 install nfnets-pytorch
or install the latest code using:
pip3 install git+https://github.com/vballoli/nfnets-pytorch
WSConv2d
Use WSConv2d and WSConvTranspose2d like any other torch.nn.Conv2d or torch.nn.ConvTranspose2d modules.
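For intuition about what a weight-standardized convolution computes, here is a minimal sketch of Scaled Weight Standardization as described in the NF-ResNet/NFNet line of work: each output channel's kernel W_i is replaced by gain_i * (W_i - mean_i) / (std_i * sqrt(N)), where N is the fan-in. The class ScaledStdConv2d and its standardized_weight method are hypothetical and are not the WSConv2d implementation shipped in nfnets; the eps guard value and the per-channel learnable gain are assumptions, and only the default padding_mode="zeros" is handled.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ScaledStdConv2d(nn.Conv2d):
    """Hypothetical sketch of a Scaled-WS conv; not the nfnets WSConv2d."""

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        # Learnable gain per output channel, initialised to 1 (assumption).
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.eps = eps

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight
        fan_in = w[0].numel()  # in_channels / groups * kh * kw
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = (w - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True)
        # 1 / (std * sqrt(fan_in)), with eps guarding against zero variance.
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
        return (w - mean) * scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Standardize on every call so the raw self.weight stays the trained parameter.
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


if __name__ == "__main__":
    conv = ScaledStdConv2d(3, 6, 3)
    y = conv(torch.randn(1, 3, 32, 32))
    print(y.shape)
```

Standardizing inside forward (rather than once at init) is what keeps each output channel zero-mean with unit sum of squares (times its gain) throughout training, which is the property that lets NFNets drop batch normalization.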
import torch
from torch import nn
from nfnets import WSConv2d
conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(