Squeeze and Excitation Implementation in TensorFlow and PyTorch

The Squeeze and Excitation network is a channel-wise attention mechanism used to improve the overall performance of a convolutional neural network at a small computational cost. In today's article, we are going to implement the Squeeze and Excitation module in both TensorFlow and PyTorch.

What is a Squeeze and Excitation Network?

The squeeze and excitation attention mechanism was introduced by Hu et al. in their paper “Squeeze-and-Excitation Networks”, presented at CVPR 2018, with a journal version later published in TPAMI. It is one of the most influential papers in the field of attention mechanisms and has been cited more than 8,000 times.

The Squeeze and Excitation Network introduces a channel-wise attention mechanism for CNNs (Convolutional Neural Networks) that explicitly models the interdependencies between channels. The block adds a small number of parameters that learn to re-weight each channel, making the network more sensitive to informative features while suppressing less useful ones.

In short, the Squeeze and Excitation block recalibrates each channel of a feature map, creating a more robust representation by emphasizing the important features.


Figure: The block diagram of the Squeeze and Excitation block.
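Before looking at the framework-specific code, it helps to see the three steps (squeeze, excitation, scale) in isolation. Below is a minimal NumPy sketch of the computation, not a trainable implementation: w1 and w2 are hypothetical weight matrices standing in for the two learned fully-connected layers.

import numpy as np

def se_block_sketch(feature_map, w1, w2):
    # feature_map: (H, W, C), w1: (C, C // r), w2: (C // r, C)
    z = feature_map.mean(axis=(0, 1))        # squeeze: one value per channel, shape (C,)
    s = np.maximum(z @ w1, 0)                # excitation: bottleneck + ReLU
    s = 1 / (1 + np.exp(-(s @ w2)))          # excitation: sigmoid gate in (0, 1), shape (C,)
    return feature_map * s                   # scale: re-weight each channel

The squeeze step compresses each channel's spatial information into a single descriptor, and the excitation step turns those descriptors into per-channel weights that rescale the original feature map.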

Squeeze and Excitation Implementation in TensorFlow

from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Input

def SqueezeAndExcitation(inputs, ratio=8):
    _, _, _, c = inputs.shape
    # Squeeze: aggregate spatial information into one descriptor per channel.
    x = GlobalAveragePooling2D()(inputs)
    # Excitation: bottleneck MLP that produces a weight in (0, 1) per channel.
    x = Dense(c//ratio, activation="relu", use_bias=False)(x)
    x = Dense(c, activation="sigmoid", use_bias=False)(x)
    # Reshape to (1, 1, c) so the weights broadcast over the spatial dimensions.
    x = Reshape((1, 1, c))(x)
    # Scale: recalibrate each channel of the input feature map.
    x = inputs * x
    return x

if __name__ == "__main__":
    inputs = Input(shape=(128, 128, 32))
    y = SqueezeAndExcitation(inputs)
    print(y.shape)
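As a quick illustration of how the function might slot into a model, here is a hedged sketch (continuing from the function above) that applies a convolution followed by SE recalibration. The conv_se_block helper is our own example, not part of the original paper.

from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Model

def conv_se_block(inputs, filters, ratio=8):
    # A conv layer whose output channels are recalibrated by the SE block.
    x = Conv2D(filters, 3, padding="same", activation="relu")(inputs)
    return SqueezeAndExcitation(x, ratio=ratio)

inputs = Input(shape=(128, 128, 3))
outputs = conv_se_block(inputs, 32)
model = Model(inputs, outputs)
model.summary()

Note how little the SE branch adds to the parameter count: only the two small Dense layers, since the pooling, reshape, and multiply are parameter-free.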

Squeeze and Excitation Implementation in PyTorch

import torch
import torch.nn as nn

class SqueezeAndExcitation(nn.Module):
    def __init__(self, channel, ratio=8):
        super().__init__()

        # Squeeze: global average pooling reduces each channel to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP that outputs a weight in (0, 1) per channel.
        self.network = nn.Sequential(
            nn.Linear(channel, channel//ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel//ratio, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        b, c, _, _ = inputs.shape
        x = self.avg_pool(inputs)    # (b, c, 1, 1)
        x = x.view(b, c)             # flatten for the linear layers
        x = self.network(x)          # per-channel weights, (b, c)
        x = x.view(b, c, 1, 1)       # reshape so the weights broadcast spatially
        x = inputs * x               # scale: recalibrate the channels
        return x

if __name__ == "__main__":
    inputs = torch.randn((8, 32, 128, 128))
    se = SqueezeAndExcitation(32, ratio=8)
    y = se(inputs)
    print(y.shape)
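In practice, the module is usually dropped into a larger block, most famously the residual block of SE-ResNet, where the SE recalibration is applied to the residual branch before the skip connection is added. The SEBasicBlock below is a minimal sketch of that idea, assuming equal input and output channels; the name and exact layout are our own illustration.

class SEBasicBlock(nn.Module):
    def __init__(self, channels, ratio=8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SqueezeAndExcitation(channels, ratio)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)               # recalibrate channels before the skip connection
        return self.relu(out + identity)

block = SEBasicBlock(32)
print(block(torch.randn((8, 32, 128, 128))).shape)  # torch.Size([8, 32, 128, 128])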

Conclusion

In this coding tutorial, you have learned about one of the most widely used channel-wise attention mechanisms, the “Squeeze and Excitation Network”, and implemented it in both TensorFlow and PyTorch.

Still have some questions or queries? Just comment below. For more updates, follow me.

Nikhil Tomar

I am an independent researcher in the field of Artificial Intelligence. I love to write about the technology I am working on.
