The Squeeze and Excitation (SE) network is a channel-wise attention mechanism that can be added to a convolutional neural network to improve its performance. In today’s article, we are going to implement the Squeeze and Excitation module in TensorFlow and PyTorch.
What is a Squeeze and Excitation Network?
The squeeze and excitation attention mechanism was introduced by Hu et al. in their paper “Squeeze-and-Excitation Networks”, published at CVPR 2018, with an extended journal version in TPAMI. It is one of the most influential papers in the field of attention mechanisms and has been cited more than 8000 times.
The Squeeze and Excitation Network introduces a channel-wise attention mechanism for CNNs (Convolutional Neural Networks) that explicitly models the interdependencies between channels. Instead of treating all channels equally, the SE block computes a weight for each channel from the input itself and rescales that channel with it. This makes the network more sensitive to informative features while suppressing less useful ones.
In short, the Squeeze and Excitation Network recalibrates the feature channels, emphasizing the important ones, to create a more robust representation.
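The three steps behind the name can be written out as equations. The sketch below follows the notation of Hu et al. for an input feature map $U$ with $C$ channels of size $H \times W$; biases are omitted, as in the implementations that follow, and $r$ is the reduction ratio (the `ratio` argument in the code).

$$z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j) \qquad \text{(squeeze: global average pooling)}$$

$$s = \sigma\big(W_2\, \delta(W_1 z)\big), \qquad W_1 \in \mathbb{R}^{\frac{C}{r} \times C}, \quad W_2 \in \mathbb{R}^{C \times \frac{C}{r}} \qquad \text{(excitation)}$$

$$\tilde{x}_c = s_c \cdot u_c \qquad \text{(scale)}$$

Here $\delta$ is the ReLU and $\sigma$ the sigmoid, so every $s_c$ lies between 0 and 1 and acts as a per-channel gate.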
Squeeze and Excitation Implementation in TensorFlow
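```python
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Multiply, Input

def SqueezeAndExcitation(inputs, ratio=8):
    # Channels-last layout: (batch, height, width, channels)
    _, _, _, c = inputs.shape
    # Squeeze: global average pooling -> (batch, c)
    x = GlobalAveragePooling2D()(inputs)
    # Excitation: bottleneck MLP; the sigmoid gives one weight in (0, 1) per channel
    x = Dense(c // ratio, activation="relu", use_bias=False)(x)
    x = Dense(c, activation="sigmoid", use_bias=False)(x)
    # Reshape to (batch, 1, 1, c) so the weights broadcast over height and width
    x = Reshape((1, 1, c))(x)
    # Scale: re-weight each input channel
    x = Multiply()([inputs, x])
    return x

if __name__ == "__main__":
    inputs = Input(shape=(128, 128, 32))
    y = SqueezeAndExcitation(inputs)
    print(y.shape)  # (None, 128, 128, 32)
```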
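To make the tensor shapes explicit, here is a framework-free NumPy sketch of the same forward pass. It assumes a channels-last input and uses randomly initialised matrices `w1` and `w2` (hypothetical names standing in for the two bias-free `Dense` kernels); it is meant for illustration only, not for training.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def se_forward(inputs, w1, w2):
    """NumPy sketch of an SE block forward pass (channels-last).

    inputs: (B, H, W, C) feature map
    w1: (C, C // r) reduction weights (hypothetical, no bias)
    w2: (C // r, C) expansion weights (hypothetical, no bias)
    Returns the rescaled feature map and the (B, C) channel gates.
    """
    # Squeeze: average over the spatial dimensions -> (B, C)
    z = inputs.mean(axis=(1, 2))
    # Excitation: ReLU bottleneck, then sigmoid gate -> (B, C)
    s = sigmoid(np.maximum(z @ w1, 0.0) @ w2)
    # Scale: give the gate shape (B, 1, 1, C) so it broadcasts over H and W
    return inputs * s[:, None, None, :], s


rng = np.random.default_rng(0)
B, H, W, C, r = 8, 16, 16, 32, 8
x = rng.standard_normal((B, H, W, C))
w1 = rng.standard_normal((C, C // r)) * 0.1
w2 = rng.standard_normal((C // r, C)) * 0.1

y, s = se_forward(x, w1, w2)
print(y.shape)  # (8, 16, 16, 32)

# Without the reshape, a (8, 16, 16, 32) tensor times a (8, 32) gate does not broadcast
try:
    x * s
except ValueError as err:
    print("broadcast error:", err)
```

The last lines show why the gate has to be given a `(B, 1, 1, C)` shape before the multiplication: NumPy aligns shapes from the right, so a `(B, C)` gate would be matched against the width and channel axes instead of the batch and channel axes.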
Squeeze and Excitation Implementation in PyTorch
```python
import torch
import torch.nn as nn

class SqueezeAndExcitation(nn.Module):
    def __init__(self, channel, ratio=8):
        super().__init__()
        # Squeeze: global average pooling to a 1x1 map per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP ending in a sigmoid gate per channel
        self.network = nn.Sequential(
            nn.Linear(channel, channel // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // ratio, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        b, c, _, _ = inputs.shape
        x = self.avg_pool(inputs)   # (b, c, 1, 1)
        x = x.view(b, c)            # (b, c)
        x = self.network(x)         # (b, c), values in (0, 1)
        x = x.view(b, c, 1, 1)      # broadcast over height and width
        x = inputs * x              # Scale: re-weight each channel
        return x

if __name__ == "__main__":
    inputs = torch.randn((8, 32, 128, 128))
    se = SqueezeAndExcitation(32, ratio=8)
    y = se(inputs)
    print(y.shape)  # torch.Size([8, 32, 128, 128])
```
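A common next step is placing the block inside your own CNN. The sketch below shows one simple way to do that for image classification in PyTorch: each hypothetical `conv_se_block` applies convolution, batch normalization and ReLU, then lets the SE block recalibrate the output channels before pooling. `SmallSECNN`, its layer widths and the input size are illustrative assumptions, not a tuned architecture, and the SE class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class SqueezeAndExcitation(nn.Module):
    """SE block with the same structure as the class above (bias-free MLP)."""

    def __init__(self, channel, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.network = nn.Sequential(
            nn.Linear(channel, channel // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // ratio, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, inputs):
        b, c, _, _ = inputs.shape
        x = self.avg_pool(inputs).view(b, c)
        x = self.network(x).view(b, c, 1, 1)
        return inputs * x


def conv_se_block(in_ch, out_ch, ratio=8):
    # Hypothetical helper: Conv -> BN -> ReLU -> SE recalibration -> 2x downsampling
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        SqueezeAndExcitation(out_ch, ratio),
        nn.MaxPool2d(2),
    )


class SmallSECNN(nn.Module):
    """Hypothetical small classifier; layer widths are illustrative only."""

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            conv_se_block(in_channels, 32),
            conv_se_block(32, 64),
            conv_se_block(64, 128),
        )
        # Global average pooling + linear layer produce one logit per class
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


if __name__ == "__main__":
    model = SmallSECNN(in_channels=3, num_classes=10)
    images = torch.randn(4, 3, 64, 64)
    logits = model(images)
    print(logits.shape)  # torch.Size([4, 10])
```

The logits can be trained with `nn.CrossEntropyLoss` like any other classifier. The overhead is small: with bias-free layers, each SE block adds 2C²/r weights, for example 2 · 32 · 4 = 256 for C = 32 and r = 8. In the original paper the block is also integrated into residual networks, where it rescales the output of the residual branch before the identity shortcut is added.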
In this coding tutorial, you have learned about one of the most widely used channel-wise attention mechanisms, the Squeeze and Excitation Network, and how to implement it in TensorFlow and PyTorch.
Still have some questions or queries? Just comment below. For more updates, follow me.