ResUNet: A TensorFlow Implementation for Semantic Segmentation

In computer vision and medical image analysis, semantic segmentation plays a pivotal role in understanding and interpreting visual data. One of the prominent architectures in this domain is ResUNet, a fusion of U-Net and ResNet architectures, renowned for its ability to efficiently capture local and global features. In this blog post, we’ll delve into the implementation of ResUNet using TensorFlow, dissecting each block of code to provide a comprehensive understanding.

Understanding ResUNet Architecture

ResUNet combines the encoder-decoder structure of U-Net with residual blocks from ResNet, offering a robust framework for semantic segmentation tasks. The architecture consists of an encoder, a bridge, and a decoder, facilitating the extraction of intricate features at various scales.

The block diagram of the ResUNet from its research paper.

Original Paper: Road Extraction by Deep Residual U-Net


Importing Libraries

from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Concatenate, Input
from tensorflow.keras.models import Model

Here, we import all the required layers, along with the Model class used to build the final architecture. Note that ResUNet uses no pooling layers; downsampling is handled by strided convolutions inside the residual blocks.

Batch Normalization & ReLU Function

def batchnorm_relu(inputs):
    """ Batch Normalization & ReLU """
    x = BatchNormalization()(inputs)
    x = Activation("relu")(x)
    return x

This function applies batch normalization followed by ReLU activation to the input tensor. This BN-then-ReLU ordering is applied before each convolution in the residual blocks below, following the pre-activation residual design.

Residual Block

This function defines a residual block comprising convolutional layers with a shortcut connection for identity mapping, enhancing feature propagation.

def residual_block(inputs, num_filters, strides=1):
    """ Convolutional Layers """
    x = batchnorm_relu(inputs)
    x = Conv2D(num_filters, 3, padding="same", strides=strides)(x)
    x = batchnorm_relu(x)
    x = Conv2D(num_filters, 3, padding="same", strides=1)(x)

    """ Shortcut Connection (Identity Mapping) """
    s = Conv2D(num_filters, 1, padding="same", strides=strides)(inputs)

    """ Addition """
    x = x + s

    return x
The block diagram showing the residual block.

The residual block plays a critical role in the ResUNet architecture, facilitating efficient feature propagation and mitigating the vanishing gradient problem. By incorporating shortcut connections, it allows gradients to flow smoothly during backpropagation, improving optimization and enabling the training of deeper networks.
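As a quick sanity check, the sketch below builds a single residual block on a dummy input and confirms that strides=2 halves the spatial resolution while the 1x1 shortcut convolution projects the input to the same channel count, so the addition is valid. It reuses the residual_block function and the imports defined above; the shapes are purely illustrative:

# Dummy 64x64 feature map with 32 channels (illustrative shapes)
inputs = Input((64, 64, 32))
outputs = residual_block(inputs, num_filters=64, strides=2)
block = Model(inputs, outputs)
print(block.output_shape)  # (None, 32, 32, 64): resolution halved, channels projected to 64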


Read More: What is Residual Network or ResNet?


Decoder Block

This function constructs a decoder block by upsampling the input, concatenating skip connections from the encoder, and applying a residual block.

def decoder_block(inputs, skip_features, num_filters):
    x = UpSampling2D((2, 2))(inputs)
    x = Concatenate()([x, skip_features])
    x = residual_block(x, num_filters, strides=1)
    return x

The decoder block recovers spatial information and refines the segmentation mask. By upsampling the input and concatenating the corresponding skip connection from the encoder, it fuses high-level semantic features with low-level spatial detail, helping the model generate accurate and detailed segmentation predictions.
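The same kind of shape check works here: upsampling doubles the spatial resolution so that the skip features from the encoder can be concatenated along the channel axis before the residual block refines the fused result. A minimal sketch with illustrative shapes, reusing the functions defined above:

# 28x28 bridge-like features and their 56x56 encoder counterpart (illustrative shapes)
low_res = Input((28, 28, 512))
skip = Input((56, 56, 256))
outputs = decoder_block(low_res, skip, num_filters=256)
print(outputs.shape)  # (None, 56, 56, 256): resolution doubled, channels refined to 256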

Building ResUNet

def build_resunet(input_shape):
    """ RESUNET Architecture """

    inputs = Input(input_shape)

    """ Endoder 1 """
    x = Conv2D(64, 3, padding="same", strides=1)(inputs)
    x = batchnorm_relu(x)
    x = Conv2D(64, 3, padding="same", strides=1)(x)
    s = Conv2D(64, 1, padding="same")(inputs)
    s1 = x + s

    """ Encoder 2, 3 """
    s2 = residual_block(s1, 128, strides=2)
    s3 = residual_block(s2, 256, strides=2)

    """ Bridge """
    b = residual_block(s3, 512, strides=2)

    """ Decoder 1, 2, 3 """
    x = decoder_block(b, s3, 256)
    x = decoder_block(x, s2, 128)
    x = decoder_block(x, s1, 64)

    """ Classifier """
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)

    """ Model """
    model = Model(inputs, outputs, name="RESUNET")

    return model

This function assembles the ResUNet architecture by defining encoder, bridge, and decoder components, ultimately generating the model.
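To make the data flow concrete, here is how the tensor shapes evolve for a 224 x 224 x 3 input: the spatial resolution halves at each stride-2 block and doubles at each decoder stage, while the skip connections s1, s2, and s3 carry encoder features across to the decoder.

Input     : 224 x 224 x 3
Encoder 1 : 224 x 224 x 64   (s1)
Encoder 2 : 112 x 112 x 128  (s2)
Encoder 3 :  56 x  56 x 256  (s3)
Bridge    :  28 x  28 x 512
Decoder 1 :  56 x  56 x 256
Decoder 2 : 112 x 112 x 128
Decoder 3 : 224 x 224 x 64
Output    : 224 x 224 x 1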

Building and Compiling the Model

if __name__ == "__main__":
    shape = (224, 224, 3)
    model = build_resunet(shape)
    model.summary()

In the main block, we specify the input shape, build the ResUNet model, and display its summary.
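To prepare the model for training, it still needs to be compiled with an optimizer and a loss. Since the final layer uses a sigmoid activation, binary cross-entropy is a natural fit for binary segmentation masks. A minimal sketch; the optimizer, loss, and metric choices below are illustrative, not part of the original script:

model.compile(
    optimizer="adam",              # illustrative choice
    loss="binary_crossentropy",    # matches the sigmoid output for binary masks
    metrics=["accuracy"],
)
# model.fit(train_images, train_masks, batch_size=8, epochs=50)  # placeholder data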

Conclusion

ResUNet stands as a formidable architecture in semantic segmentation, blending the strengths of U-Net and ResNet. By implementing it in TensorFlow, we’ve elucidated its intricacies, offering insights into each code block. Armed with this understanding, practitioners can leverage ResUNet for diverse image segmentation tasks, empowering advancements in computer vision and medical imaging.

Nikhil Tomar

I am an independent researcher in the field of Artificial Intelligence. I love to write about the technology I am working on.
