Model summary in PyTorch

Is there any way to print the summary of a model in PyTorch, like the model.summary() method does in Keras, as follows?

    Model Summary:
    ____________________________________________________________________________________________________
    Layer (type)                     Output Shape          Param #     Connected to                     
    ====================================================================================================
    input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
    ____________________________________________________________________________________________________
    convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
    ____________________________________________________________________________________________________
    maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
    ____________________________________________________________________________________________________
    flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]             
    ____________________________________________________________________________________________________
    dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
    ====================================================================================================
    Total params: 2,385
    Trainable params: 2,385
    Non-trainable params: 0

While you will not get as detailed information about the model as in Keras' model.summary, simply printing the model will give you some idea about the different layers involved and their specifications.

For instance:

    from torchvision import models
    model = models.vgg16()
    print(model)

The output in this case would be something like the following:

    VGG (
      (features): Sequential (
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU (inplace)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU (inplace)
        (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU (inplace)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU (inplace)
        (9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU (inplace)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU (inplace)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU (inplace)
        (16): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU (inplace)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU (inplace)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU (inplace)
        (23): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU (inplace)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU (inplace)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU (inplace)
        (30): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
      )
      (classifier): Sequential (
        (0): Dropout (p = 0.5)
        (1): Linear (25088 -> 4096)
        (2): ReLU (inplace)
        (3): Dropout (p = 0.5)
        (4): Linear (4096 -> 4096)
        (5): ReLU (inplace)
        (6): Linear (4096 -> 1000)
      )
    )
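
The printed listing shows the layers but not the parameter totals that Keras reports. Here is a minimal sketch of how those totals could be recovered from `model.parameters()`. The helper name `count_parameters` and the small `toy_model` are made up for illustration (a tiny stand-in is used instead of VGG16 so the numbers are easy to check by hand):

```python
import torch.nn as nn


def count_parameters(model):
    """Return (total, trainable, non_trainable) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, total - trainable


# Hypothetical toy model, sized for a (1, 15, 27) input.
toy_model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # 8*1*3*3 weights + 8 biases = 80
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 7 * 13, 1),                   # 728 weights + 1 bias = 729
)

total, trainable, non_trainable = count_parameters(toy_model)
print(f"Total params: {total:,}")
print(f"Trainable params: {trainable:,}")
print(f"Non-trainable params: {non_trainable:,}")
```

Parameters whose `requires_grad` is set to `False` (for example, frozen layers) end up in the non-trainable count.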

Now you could, as mentioned by Kashyap, use the state_dict method to get the weights of the different layers. But this listing of the layers would perhaps provide more direction in creating a helper function to get that Keras-like model summary! Hope this helps!
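
One way such a helper could look is sketched here; it is not a built-in PyTorch feature. It registers a forward hook on every leaf module, runs one dummy forward pass, and records each layer's output shape and parameter count. The names `summarize` and `toy_model` are hypothetical, and the sketch assumes a CPU model that takes a single tensor input of the given size:

```python
import torch
import torch.nn as nn


def summarize(model, input_size):
    """Return a list of (name, type, output_shape, n_params) per leaf module."""
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Count only this module's own parameters, not its children's.
            n_params = sum(p.numel() for p in module.parameters(recurse=False))
            shape = tuple(output.shape) if isinstance(output, torch.Tensor) else None
            rows.append((name, type(module).__name__, shape, n_params))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            hooks.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()  # avoid dropout / batch-norm updates during the dummy pass
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))  # batch of one dummy sample
    finally:
        for h in hooks:
            h.remove()
        model.train(was_training)
    return rows


# Hypothetical toy model, used only to exercise the helper.
toy_model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 7 * 13, 1),
)

rows = summarize(toy_model, (1, 15, 27))
print(f"{'Layer (type)':<22}{'Output Shape':<20}{'Param #':>10}")
for name, kind, shape, n_params in rows:
    print(f"{name + ' (' + kind + ')':<22}{str(shape):<20}{n_params:>10,}")
print(f"Total params: {sum(r[3] for r in rows):,}")
```

The layer names are the ones PyTorch assigns in `named_modules()`, so for an `nn.Sequential` they are simply the indices `0`, `1`, and so on. A module that is called more than once in the forward pass would appear once per call, so its parameters would be counted repeatedly in this sketch.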

From: stackoverflow.com/q/42480111
