In computer vision and medical image analysis, semantic segmentation plays a pivotal role in understanding and interpreting visual data. One of the prominent architectures in this domain is ResUNet, a fusion of U-Net and ResNet architectures, renowned for its ability to efficiently capture local and global features. In this blog post, we’ll delve into the implementation of ResUNet using TensorFlow, dissecting each block of code to provide a comprehensive understanding.
Understanding ResUNet Architecture
ResUNet combines the encoder-decoder structure of U-Net with residual blocks from ResNet, offering a robust framework for semantic segmentation tasks. The architecture consists of an encoder, a bridge, and a decoder, facilitating the extraction of intricate features at various scales.
Original Paper: Road Extraction by Deep Residual U-Net
Importing Libraries
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Concatenate, Input
from tensorflow.keras.models import Model
Here, we import all the required layers. We also import the Model class to build the final architecture.
Batch Normalization & ReLU Function
def batchnorm_relu(inputs):
    """ Batch Normalization & ReLU """
    x = BatchNormalization()(inputs)
    x = Activation("relu")(x)
    return x
This function applies batch normalization followed by ReLU activation to the input tensor.
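As a quick sanity check (the input shape below is an arbitrary illustration), both operations are element-wise, so the spatial dimensions and channel count pass through unchanged:

    t = Input((128, 128, 32))        # hypothetical feature map
    print(batchnorm_relu(t).shape)   # (None, 128, 128, 32)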
Residual Block
This function defines a residual block comprising convolutional layers with a shortcut connection for identity mapping, enhancing feature propagation.
def residual_block(inputs, num_filters, strides=1):
    """ Convolutional Layers """
    x = batchnorm_relu(inputs)
    x = Conv2D(num_filters, 3, padding="same", strides=strides)(x)
    x = batchnorm_relu(x)
    x = Conv2D(num_filters, 3, padding="same", strides=1)(x)

    """ Shortcut Connection (Identity Mapping) """
    s = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)

    """ Addition """
    x = x + s
    return x
The residual block in the ResUNet architecture serves a critical role in facilitating efficient feature propagation and addressing the vanishing gradient problem. By incorporating shortcut connections, it allows gradients to flow smoothly during backpropagation, which improves optimization and enables the training of deeper networks.
Read More: What is Residual Network or ResNet?
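To make the block's behavior concrete, here is a quick shape check using the function defined above (the 64×64×64 input is an arbitrary example). With strides=2, the convolutional path halves the spatial resolution, and the 1×1 shortcut convolution does the same, so the two tensors remain addition-compatible:

    inputs = Input((64, 64, 64))                # hypothetical feature map
    x = residual_block(inputs, 128, strides=2)
    print(x.shape)                              # (None, 32, 32, 128)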
Decoder Block
This function constructs a decoder block by upsampling the input, concatenating skip connections from the encoder, and applying a residual block.
def decoder_block(inputs, skip_features, num_filters):
    x = UpSampling2D((2, 2))(inputs)
    x = Concatenate()([x, skip_features])
    x = residual_block(x, num_filters, strides=1)
    return x
The decoder block recovers spatial information and refines the segmentation masks. By upsampling the input and concatenating it with the corresponding skip connection, it fuses high-level semantic features with low-level spatial detail, which helps the model produce accurate, detailed segmentation predictions.
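The shape arithmetic can be sketched as follows (the shapes below are illustrative, matching the bridge and third encoder stage of the 224×224 model built next): the bridge tensor is upsampled to the skip tensor's resolution, the two are concatenated along the channel axis, and the residual block then projects the result to num_filters channels:

    bridge = Input((28, 28, 512))         # hypothetical bridge output
    skip = Input((56, 56, 256))           # hypothetical encoder skip
    y = decoder_block(bridge, skip, 256)
    print(y.shape)                        # (None, 56, 56, 256)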
Building ResUNet
def build_resunet(input_shape):
    """ RESUNET Architecture """
    inputs = Input(input_shape)

    """ Encoder 1 """
    x = Conv2D(64, 3, padding="same", strides=1)(inputs)
    x = batchnorm_relu(x)
    x = Conv2D(64, 3, padding="same", strides=1)(x)
    s = Conv2D(64, 1, padding="same")(inputs)
    s1 = x + s

    """ Encoder 2, 3 """
    s2 = residual_block(s1, 128, strides=2)
    s3 = residual_block(s2, 256, strides=2)

    """ Bridge """
    b = residual_block(s3, 512, strides=2)

    """ Decoder 1, 2, 3 """
    x = decoder_block(b, s3, 256)
    x = decoder_block(x, s2, 128)
    x = decoder_block(x, s1, 64)

    """ Classifier """
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)

    """ Model """
    model = Model(inputs, outputs, name="RESUNET")
    return model
This function assembles the ResUNet architecture by defining encoder, bridge, and decoder components, ultimately generating the model.
Model Compilation
if __name__ == "__main__":
    shape = (224, 224, 3)
    model = build_resunet(shape)
    model.summary()
In the main block, we specify the input shape, build the ResUNet model, and display its summary.
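Note that this script only builds and summarizes the network; before training, the model must still be compiled. A minimal sketch, assuming binary segmentation to match the sigmoid output above (the optimizer, loss, and metric choices here are assumptions, not part of the original code):

    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )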
Conclusion
ResUNet stands as a formidable architecture in semantic segmentation, blending the strengths of U-Net and ResNet. By implementing it in TensorFlow, we’ve elucidated its intricacies, offering insights into each code block. Armed with this understanding, practitioners can leverage ResUNet for diverse image segmentation tasks, empowering advancements in computer vision and medical imaging.