UNET Implementation in TensorFlow using Keras API

In this post, you will learn how to implement the UNET architecture in TensorFlow using the Keras API. The post explains each building block of UNET so that you can use the architecture in your own research.

UNET is one of the most popular semantic segmentation architectures. Olaf Ronneberger et al. developed the network for biomedical image segmentation in 2015.

Figure: The block diagram of the original UNET architecture

To know more, read the article: What is UNET?

Import

In the first part of the post, we import all the classes required to implement the UNET architecture.

from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model

Here, we have imported all the required layer classes from TensorFlow: Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate and Input.

To connect the input and output layers of the UNET architecture, we imported the Model class.

Convolution Block

The entire UNET architecture makes repeated use of two 3 x 3 convolutions, each followed by a ReLU activation.

Here, we write the code for building the convolution block.

def conv_block(input, num_filters):
    x = Conv2D(num_filters, 3, padding="same")(input)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    return x

The convolution block or conv_block function takes two parameters:

  1. input: The input represents the feature maps from the previous block.
  2. num_filters: The num_filters refers to the number of output feature channels for the convolutional layers present in the conv_block function.

In the original UNET architecture, the two 3×3 convolutions are followed by a ReLU activation function. Here, we have introduced batch normalization in between the convolutional and the ReLU layer.

Batch normalization helps to make deep neural networks train faster and more stably by normalizing the inputs to each layer.
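As a quick sanity check, here is a minimal sketch (assuming TensorFlow 2.x and the imports and conv_block defined above) that applies the block to a dummy input and confirms that padding="same" preserves the spatial dimensions:

x = Input((128, 128, 32))       # dummy feature maps: height x width x channels
y = conv_block(x, 64)           # two 3x3 convolutions with BN and ReLU

print(y.shape)                  # (None, 128, 128, 64): spatial size unchanged, 64 channels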

Encoder Block

def encoder_block(input, num_filters):
    x = conv_block(input, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p

The encoder_block function takes two parameters:

  1. input: The input represents the feature maps from the previous block.
  2. num_filters: The num_filters refers to the number of output feature channels.

The encoder_block consists of a conv_block and a 2×2 max-pooling layers. The encoder_block returns two feature maps:

  1. x: It represents the output of the conv_block. It acts as the skip connection for the corresponding decoder block.
  2. p: It represents the reduced feature maps passed to the next block as the input.

If the encoder_block takes an input of size (128 x 128 x 32) with num_filters = 64, then it returns the outputs x = (128 x 128 x 64) and p = (64 x 64 x 64).
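The snippet below is a small sketch (assuming the imports and functions defined above) that reproduces this example and prints both output shapes:

x = Input((128, 128, 32))       # dummy input feature maps
s, p = encoder_block(x, 64)     # skip connection and pooled output

print(s.shape)                  # (None, 128, 128, 64): skip connection for the decoder
print(p.shape)                  # (None, 64, 64, 64): input for the next encoder block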

Decoder Block

def decoder_block(input, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

The decoder_block takes three parameters:

  1. input: The input represents the feature maps from the previous block.
  2. skip_features: The skip_features represents the feature maps from the encoder block that are fetched through the skip connection.
  3. num_filters: The num_filters represents the number of output feature channels.

The decoder_block function begins with a 2×2 transpose convolution, which doubles the spatial dimensions (height and width) of the incoming feature maps.

If the input size is (16 x 16 x 32) and num_filters is 64, then the output of the transpose convolution is (32 x 32 x 64).

Next, we concatenate the upsampled feature maps with the skip connection feature maps. The skip connections bring feature maps from earlier layers, which helps the network generate better semantic feature maps.

After the concatenation, a conv_block is used.
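Here is a small sketch (again assuming the functions defined above) that checks the decoder_block shapes for the example discussed: the (16 x 16 x 32) input is upsampled to 32 x 32, concatenated with a matching skip connection, and reduced to num_filters channels by the conv_block:

x = Input((16, 16, 32))         # feature maps from the previous block
skip = Input((32, 32, 64))      # skip connection fetched from the encoder
d = decoder_block(x, skip, 64)

print(d.shape)                  # (None, 32, 32, 64)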

So far, we have studied the encoder and the decoder blocks of the UNET architecture. Now, we start building the complete UNET architecture.

UNET Architecture

def build_unet(input_shape):
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)

    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="U-Net")
    return model

The build_unet function takes one parameter:

  1. input_shape: It is a tuple of the height, the width and the number of input channels. For example: (512, 512, 3).

The build_unet function returns the Model object, containing all the layers.

inputs = Input(input_shape)

The build_unet function begins with an Input layer with a specified input shape provided as the function parameter.

s1, p1 = encoder_block(inputs, 64)
s2, p2 = encoder_block(p1, 128)
s3, p3 = encoder_block(p2, 256)
s4, p4 = encoder_block(p3, 512)

Next follow the four encoder blocks, where each encoder block uses the output of the previous block as its input, along with the number of output feature channels. The number of filters begins at 64 and doubles with every subsequent encoder block.

b1 = conv_block(p4, 1024)

The output of the 4th encoder block acts as the input for the bridge. The bridge is simply a conv_block with 1024 output feature channels.

d1 = decoder_block(b1, s4, 512)
d2 = decoder_block(d1, s3, 256)
d3 = decoder_block(d2, s2, 128)
d4 = decoder_block(d3, s1, 64)

Next begins the decoder, which consists of four decoder blocks. Each decoder block takes the feature maps from the previous block as input, along with the number of output channels. The input is first upsampled using a transpose convolution. The upsampled feature maps are then concatenated with the appropriate skip connection from the encoder. After that, a conv_block is applied.
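As an illustration, the trace below lists the shapes you would expect at each stage for a (512, 512, 3) input, with the batch dimension omitted:

# Expected shapes for input_shape = (512, 512, 3), batch dimension omitted.
# Encoder:  s1 = (512, 512, 64)     p1 = (256, 256, 64)
#           s2 = (256, 256, 128)    p2 = (128, 128, 128)
#           s3 = (128, 128, 256)    p3 = (64, 64, 256)
#           s4 = (64, 64, 512)      p4 = (32, 32, 512)
# Bridge:   b1 = (32, 32, 1024)
# Decoder:  d1 = (64, 64, 512)
#           d2 = (128, 128, 256)
#           d3 = (256, 256, 128)
#           d4 = (512, 512, 64)
# Output:   (512, 512, 1)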

outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

The output of the 4th decoder block passes through a 1×1 convolutional layer with a sigmoid activation function.

For binary segmentation, we use one output feature channel with a sigmoid activation. In multiclass segmentation, we instead use as many output feature channels as there are classes, with a softmax activation function.
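For example, a multiclass variant of the output layer might look like the sketch below, where num_classes is a placeholder for the number of segmentation classes:

num_classes = 4   # placeholder: set to the number of classes in your dataset
outputs = Conv2D(num_classes, 1, padding="same", activation="softmax")(d4)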

model = Model(inputs, outputs, name="U-Net")
return model

Finally, we pass both the input and the output of the UNET architecture to the Model class. Now we have a model object containing all the layers.

Run the UNET

if __name__ == "__main__":
    input_shape = (512, 512, 3)
    model = build_unet(input_shape)
    model.summary()

Now, we have implemented the UNET architecture in TensorFlow using the Keras API. We call the build_unet function with an input_shape of (512 x 512 x 3).

Here is the summary of the UNET architecture.

Figure: Summary of the UNET architecture in TensorFlow


Nikhil Tomar

I am an independent researcher in the field of Artificial Intelligence. I love to write about the technology I am working on.


3 Responses

  1. Cel says:

    Hi, thanks for the clear explanation of the UNET architecture. Do you think this architecture can also be used to do dense regression (regression from input image to output image) if we simply replace the final sigmoid activation with a linear activation?

  2. Sonal says:

    Is kernel_size not necessary for Conv2DTranspose?

  3. Gajraj says:

    Can you show the compile step and what kind of loss we should choose? For example, should we use `from_logits` or not in SparseCategoricalCrossentropy?
