ColonSegNet Implementation In TensorFlow

In this article, we take a deep dive into the implementation of ColonSegNet in TensorFlow. ColonSegNet is a lightweight, real-time encoder-decoder architecture for segmenting colonoscopy images, and it has attracted attention for its efficiency in medical image analysis. In our previous post, we introduced the architecture and its components. Now, let's walk through the technical details of implementing ColonSegNet in TensorFlow.


Previous Post: ColonSegNet: A Lightweight Real-Time Colon Segmentation Architecture


Install TensorFlow

pip install tensorflow

Now, let’s go through the code step by step.

Importing Libraries and Functions

Import the necessary libraries and suppress TensorFlow log messages.

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D, Dense
from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.models import Model

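The code starts by importing the Keras layers and the Model class used throughout the implementation. Setting the TF_CPP_MIN_LOG_LEVEL environment variable to "2" hides INFO and WARNING messages from TensorFlow's C++ backend for cleaner output. Because the variable is read when TensorFlow loads, it must be set before TensorFlow is imported.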

Squeeze and Excitation (SE) Block

The se_layer function defines the Squeeze-and-Excitation (SE) block, a key component that enhances the representational power of ColonSegNet by reweighting feature channels. Let's break down its implementation:

The block diagram of Squeeze and Excitation Network.
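def se_layer(x, num_filters, reduction=16):
    x_init = x

    # Squeeze: average each channel over H and W -> (batch, 1, 1, num_filters).
    # keepdims=True keeps the gate broadcastable against x_init (batch, H, W, num_filters).
    x = GlobalAveragePooling2D(keepdims=True)(x)
    # Excitation: bottleneck to num_filters // reduction units, then expand back
    # to one sigmoid gate in (0, 1) per channel.
    x = Dense(num_filters//reduction, use_bias=False, activation="relu")(x)
    x = Dense(num_filters, use_bias=False, activation="sigmoid")(x)
    # Scale: reweight every channel of the input by its gate.
    x = x * x_init
    return x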
  • Global Average Pooling (squeeze): Averages each channel over its spatial dimensions, producing a single value per channel that summarizes the global context.
  • Bottleneck: A fully connected layer with num_filters // reduction units and a ReLU activation compresses the channel descriptor (with the default reduction=16, 64 channels become 4).
  • Excitation: A second fully connected layer with num_filters units and a sigmoid activation produces a per-channel weight between 0 and 1, indicating the importance of each channel.
  • Scaling: The input tensor is multiplied channel-wise by these weights, allowing the network to emphasize informative channels and suppress less useful ones.
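To make the shapes concrete, here is a NumPy sketch of the same squeeze, excitation, and scaling steps on random data. The helper se_numpy and the weight matrices w1 and w2 are hypothetical stand-ins for the two bias-free Dense layers, not trained weights. It also illustrates a shape detail: Keras's GlobalAveragePooling2D returns a (batch, C) tensor by default, and a gate of that shape lines up with the wrong axes of a (batch, H, W, C) feature map, so it has to be reshaped to (batch, 1, 1, C) (or pooled with keepdims=True) before the multiplication.

```python
import numpy as np


def se_numpy(x, w1, w2):
    """Squeeze-and-excitation on feature maps x of shape (B, H, W, C).

    w1: (C, C // r) weights of the bottleneck layer (no bias, ReLU).
    w2: (C // r, C) weights of the expansion layer (no bias, sigmoid).
    """
    z = x.mean(axis=(1, 2))               # squeeze: (B, C), one value per channel
    h = np.maximum(z @ w1, 0.0)           # bottleneck + ReLU: (B, C // r)
    g = 1.0 / (1.0 + np.exp(-(h @ w2)))   # excitation: (B, C), gates in (0, 1)
    g = g[:, np.newaxis, np.newaxis, :]   # (B, 1, 1, C): broadcasts over H and W
    return x * g, g


rng = np.random.default_rng(0)
B, H, W, C, r = 2, 8, 8, 64, 16
x = rng.standard_normal((B, H, W, C))
w1 = rng.standard_normal((C, C // r)) * 0.1   # hypothetical random weights
w2 = rng.standard_normal((C // r, C)) * 0.1

y, s = se_numpy(x, w1, w2)
print(y.shape, s.shape)  # (2, 8, 8, 64) (2, 1, 1, 64)

# A (B, C) gate is aligned with the trailing (W, C) axes of x, not with (B, C).
# Here that raises an error; if B happened to equal W, it would silently
# scale the wrong entries instead.
try:
    x * s.reshape(B, C)
except ValueError as err:
    print("broadcast error:", err)
```

The output has exactly the input's shape: the SE block only rescales channels and never changes the spatial size or channel count.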



Residual Block

The residual_block function implements a residual block. Its shortcut connection gives gradients and information a direct path around the stacked convolutions, which helps mitigate the vanishing gradient problem. Let's explore its details:

The block diagram of the residual block, which uses a squeeze-and-excitation attention mechanism to improve its performance.
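def residual_block(x, num_filters):
    x_init = x

    # Main path: two 3x3 convolutions, each followed by batch normalization
    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)

    # Shortcut: a 1x1 convolution projects the input to num_filters channels,
    # then batch normalization and SE channel attention are applied
    s = Conv2D(num_filters, 1, padding="same")(x_init)
    s = BatchNormalization()(s)
    s = se_layer(s, num_filters)

    # Add both paths, then apply the final ReLU
    x = Activation("relu")(x + s)
    return x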
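  • The block starts with a 3×3 convolution, followed by batch normalization and ReLU activation.
  • A second 3×3 convolution is applied, again followed by batch normalization; its ReLU is deferred until after the shortcut is added.
  • A shortcut connection is created using a 1×1 convolution followed by batch normalization and the SE layer. The 1×1 convolution projects the input to num_filters channels, so the shortcut can be added to the main path even when the input has a different number of channels.
  • The final output is obtained by adding the shortcut connection to the output of the second convolution (after batch normalization), followed by ReLU activation.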
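A 1×1 convolution is a matrix multiplication over the channel axis, applied independently at every pixel. The NumPy sketch below (hypothetical helper conv1x1, random weights) projects a 192-channel input, the size of the first concatenation in Decoder 1 (128 + 64 channels), down to 128 channels so it can be added to a 128-channel main path. The main path here is just a random stand-in for the two 3×3 convolutions.

```python
import numpy as np


def conv1x1(x, kernel, bias):
    """1x1 convolution of x (H, W, C_in) with kernel (C_in, C_out) and bias (C_out,)."""
    # Each output pixel depends only on the same input pixel, so the whole
    # operation is one matrix multiplication over the channel axis.
    return x @ kernel + bias


rng = np.random.default_rng(0)
H, W, c_in, c_out = 4, 4, 192, 128   # 192 = 128 + 64 channels after concatenation
x = rng.standard_normal((H, W, c_in))
kernel = rng.standard_normal((c_in, c_out))  # hypothetical random weights
bias = np.zeros(c_out)

shortcut = conv1x1(x, kernel, bias)             # (4, 4, 128)
main_path = rng.standard_normal((H, W, c_out))  # stand-in for the two 3x3 convolutions
out = np.maximum(main_path + shortcut, 0.0)     # add both paths, then ReLU
print(shortcut.shape, out.shape)  # (4, 4, 128) (4, 4, 128)
```

Adding x directly to main_path would fail, because 192 and 128 channels cannot be summed element-wise; the projection is what makes the residual addition possible.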

Strided Convolution Block

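The strided_conv_block function applies a 3×3 convolution with a stride of 2, followed by batch normalization and ReLU. The stride halves the height and width of the feature map (for example, 512×512 becomes 256×256), providing learnable downsampling so that deeper layers can capture larger-scale, hierarchical features.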

def strided_conv_block(x, num_filters):
    x = Conv2D(num_filters, 3, strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    return x
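Why does a stride of 2 halve the output? A stride-2 convolution computes the same dot products as a stride-1 convolution, but only at every second position. The NumPy sketch below checks this for a single channel with no padding, using a hypothetical helper conv2d_valid (a cross-correlation, as in Keras). Keras's padding="same" additionally zero-pads the borders so that the output size is ceil(n / stride), which is why 512 becomes 256.

```python
import numpy as np


def conv2d_valid(image, kernel, stride=1):
    """Naive single-channel 2D cross-correlation with no padding ('valid')."""
    kh, kw = kernel.shape
    out_h = (image.shape[0] - kh) // stride + 1
    out_w = (image.shape[1] - kw) // stride + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Window whose top-left corner moves `stride` pixels per step
            patch = image[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = np.sum(patch * kernel)
    return out


rng = np.random.default_rng(0)
image = rng.standard_normal((9, 9))
kernel = rng.standard_normal((3, 3))

full = conv2d_valid(image, kernel, stride=1)     # every position
strided = conv2d_valid(image, kernel, stride=2)  # every second position
print(full.shape, strided.shape)  # (7, 7) (4, 4)
```

The strided result equals full[::2, ::2]: the stride simply skips positions, so the network learns its downsampling filter instead of using a fixed rule such as max pooling.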

Encoder Block

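The encoder_block function chains a residual block, a strided convolution block, a second residual block, and a 2×2 max pooling layer. It returns three tensors: x1 at the input resolution, x3 at half the resolution, and p at a quarter of the resolution. In build_colonsegnet, x1 and x3 serve as skip connections for the decoder, and p from the first encoder block is the input to the second.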

def encoder_block(x, num_filters):
    x1 = residual_block(x, num_filters)
    x2 = strided_conv_block(x1, num_filters)
    x3 = residual_block(x2, num_filters)
    p = MaxPool2D((2, 2))(x3)

    return x1, x3, p
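The tensors returned by encoder_block differ in spatial size. The pure-Python sketch below (hypothetical helper names) traces (height, width, channels) for both encoder blocks of a 512×512 input using Keras's output-size rules: Conv2D with padding="same" gives ceil(n / stride), and MaxPool2D((2, 2)) with its default padding="valid" gives floor((n - 2) / 2) + 1.

```python
import math


def conv_same(size, stride):
    """Output length of a Keras Conv2D with padding='same'."""
    return math.ceil(size / stride)


def maxpool_valid(size, pool=2):
    """Output length of MaxPool2D((pool, pool)) with default strides and padding='valid'."""
    return (size - pool) // pool + 1


def encoder_block_shapes(h, w, num_filters):
    """Trace (height, width, channels) of x1, x3 and p returned by encoder_block."""
    x1 = (h, w, num_filters)                               # residual_block keeps H and W
    x2 = (conv_same(h, 2), conv_same(w, 2), num_filters)   # strided_conv_block halves H and W
    x3 = x2                                                # residual_block keeps H and W
    p = (maxpool_valid(x3[0]), maxpool_valid(x3[1]), num_filters)
    return x1, x3, p


s11, s12, p1 = encoder_block_shapes(512, 512, 64)
s21, s22, p2 = encoder_block_shapes(p1[0], p1[1], 256)
print(s11, s12, p1)  # (512, 512, 64) (256, 256, 64) (128, 128, 64)
print(s21, s22, p2)  # (128, 128, 256) (64, 64, 256) (32, 32, 256)
```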

ColonSegNet Architecture

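The build_colonsegnet function assembles the complete ColonSegNet architecture from two encoder blocks (64 and 256 filters) and two decoder stages. Each decoder stage upsamples features with transposed convolutions, concatenates them with encoder skip connections or earlier decoder outputs, and refines the result with residual blocks, producing a single-channel segmentation map at the input resolution.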

def build_colonsegnet(input_shape):
    """ Input """
    inputs = Input(input_shape)

    """ Encoder """
    s11, s12, p1 = encoder_block(inputs, 64)
    s21, s22, p2 = encoder_block(p1, 256)  # p2 is not used by the decoder

    """ Decoder 1 """
    x = Conv2DTranspose(128, 4, strides=4, padding="same")(s22)
    x = Concatenate()([x, s12])
    x = residual_block(x, 128)
    r1 = x

    x = Conv2DTranspose(128, 4, strides=2, padding="same")(s21)
    x = Concatenate()([x, r1])
    x = residual_block(x, 128)

    """ Decoder 2 """
    x = Conv2DTranspose(64, 4, strides=2, padding="same")(x)
    x = Concatenate()([x, s11])
    x = residual_block(x, 64)
    r2 = x

    x = Conv2DTranspose(32, 4, strides=2, padding="same")(s12)
    x = Concatenate()([x, r2])
    x = residual_block(x, 32)

    """ Output """
    # 1x1 convolution to a single channel; no activation, so the output is raw logits
    output = Conv2D(1, 1, padding="same")(x)

    """ Model """
    model = Model(inputs, output)

    return model
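Each Concatenate layer only works if the upsampled tensor matches its partner in height and width. The pure-Python sketch below (hypothetical helpers; one spatial side, since height and width are handled identically) traces sizes through build_colonsegnet using Keras's output-size rules, including Conv2DTranspose with padding="same", which multiplies the size by the stride. It raises an error at the first concatenation whose sizes disagree.

```python
import math


def conv_same(n, stride=1):
    return math.ceil(n / stride)          # Conv2D, padding="same"


def pool_valid(n):
    return (n - 2) // 2 + 1               # MaxPool2D((2, 2)), padding="valid"


def deconv_same(n, stride):
    return n * stride                     # Conv2DTranspose, padding="same"


def colonsegnet_sizes(n):
    """Trace one spatial side through build_colonsegnet; raise if a Concatenate would fail."""
    # Encoder 1 and encoder 2
    s11 = n
    s12 = conv_same(s11, 2)
    p1 = pool_valid(s12)
    s21 = p1
    s22 = conv_same(s21, 2)

    def concat(a, b, name):
        if a != b:
            raise ValueError(f"{name}: cannot concatenate sizes {a} and {b}")
        return a

    # Decoder 1
    r1 = concat(deconv_same(s22, 4), s12, "decoder 1, first concat")
    d1 = concat(deconv_same(s21, 2), r1, "decoder 1, second concat")
    # Decoder 2
    r2 = concat(deconv_same(d1, 2), s11, "decoder 2, first concat")
    out = concat(deconv_same(s12, 2), r2, "decoder 2, second concat")
    return out


print(colonsegnet_sizes(512))  # 512
```

Under these rules the trace succeeds exactly when the side length is a multiple of 8: 512×512 or 256×256 inputs work, while 500×500 fails at the first concatenation in Decoder 1 (252 vs. 250).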

Running the ColonSegNet

In the main block, we specify the input shape, build the ColonSegNet model using the defined function, and print the model summary.

if __name__ == "__main__":
    input_shape = (512, 512, 3)
    model = build_colonsegnet(input_shape)
    model.summary()
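The final Conv2D(1, 1) layer has no activation, so the model outputs unbounded logits rather than probabilities. A minimal NumPy sketch of turning such an output (for example, the array returned by model.predict) into a binary mask is shown below. The helper logits_to_mask and the 0.5 threshold are assumptions for illustration, not part of the original code. For training, a loss that accepts logits, such as tf.keras.losses.BinaryCrossentropy(from_logits=True), matches this output.

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """Convert raw model outputs of shape (B, H, W, 1) into a binary mask."""
    probs = 1.0 / (1.0 + np.exp(-logits))          # sigmoid -> per-pixel probability
    mask = (probs > threshold).astype(np.uint8)    # 1 = foreground, 0 = background
    return mask, probs


logits = np.array([-3.0, -0.1, 0.0, 0.2, 4.0]).reshape(1, 1, 5, 1)
mask, probs = logits_to_mask(logits)
print(mask.ravel())  # [0 0 0 1 1]
```

A logit of exactly 0 maps to a probability of 0.5, which this strict comparison treats as background.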

Conclusion

Congratulations! You’ve just explored the implementation details of ColonSegNet in TensorFlow. Feel free to experiment with the code, integrate it into your projects, and contribute to the growing field of medical image segmentation.

Happy coding!

Nikhil Tomar

I am an independent researcher in the field of Artificial Intelligence. I love to write about the technology I am working on.
