Model summary in PyTorch similar to model.summary() in Keras
Keras style model.summary() in PyTorch
Keras has a neat API to view the visualization of the model which is very helpful while debugging your network. Here is a barebone code to try and mimic the same in PyTorch. The aim is to provide information complementary to, what is not provided by print(your_model)
in PyTorch.
Usage
pip install torchsummary
orgit clone https://github.com/sksq96/pytorch-summary
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
- Note that the
input_size
is required to make a forward pass through the network.
Examples
CNN for MNIST
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 =