DEV Community

Darth Espressius

3 Common Loss Functions for Image Segmentation

Image segmentation has a wide range of applications, from music spectrum separation and self-driving cars to biomedical imaging and brain-tumor segmentation. The aim of image segmentation is to separate (segment) an image (or image sequence) into distinct objects. For example, in the image below from the OCR: Transformer Segmentation paper, the car at the center of the image was "detected" on a pixel-wise basis. Whilst object detection would simply return the coordinates of, say, a bounding box around the car, segmentation aims to return an image mask (1 for "is car", 0 for "is not car") for a given image.

[Figure: Car segmented, from the OCR paper]

Deep learning has affected (in my opinion) the area of computer vision more so than any other field. There have been multiple innovations across the field, using a variety of techniques, over the past five years. Image segmentation can be thought of as a classification task at the pixel level, and the choice of loss function is key in determining both the speed at which a machine-learning model converges and, to some extent, the accuracy of the model.

A loss function gives feedback to the model during supervised training (learning from already-labelled data) on how well it is converging upon the optimal model parameters. It is used to guide a model in its search for the "ideal" approximation which maps the input data to the output data (images to masks in the case of image segmentation).

This review paper from Shruti Jadon (IEEE Member) bucketed loss functions into four main groupings: distribution-based, region-based, boundary-based and compounded losses. In this blog post, I will focus on three of the more commonly used loss functions for semantic image segmentation: Binary Cross-Entropy Loss, Dice Loss and the Shape-Aware Loss.

Binary Cross-Entropy

Cross-entropy is used to measure the difference between two probability distributions. It acts as a similarity metric that tells how close one distribution of random events is to another, and is used for both classification (in the more general sense) and segmentation.

The binary cross-entropy (BCE) loss therefore attempts to measure the difference in information content between the actual and predicted image masks. It is based on the Bernoulli distribution, and works best when the classes are equally distributed in the data. In other terms, image masks with very heavy class imbalance (such as in finding very small, rare tumors in X-ray images) may not be adequately evaluated by BCE.

This is because BCE treats positive (1) and negative (0) samples in the image mask equally. Since the pixels that represent a given object (say, the car from the first image above) may be far fewer than the pixels in the rest of the image, the background can dominate the loss, and the BCE loss may not effectively represent the performance of the deep-learning model.

Binary Cross Entropy is defined as:

$$L(y,\hat{y}) = -y\log(\hat{y}) - (1-y)\log(1-\hat{y})$$

Quick primer on mathematical notation: if $y$ is our target image-segmentation mask, and $\hat{y}$ is our predicted mask from our deep-learning model, the loss measures the difference between what we want ($y$) and what the model gave us ($\hat{y}$).

This has been implemented in TensorFlow's tf.keras.losses module (as BinaryCrossentropy) and, as such, can be readily used as-is in your image segmentation models.
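
To make the class-imbalance point above concrete, here is a minimal plain-Python sketch of the BCE formula, averaged over a flattened mask. The function name `bce`, the clipping constant `eps` and the 100-pixel toy mask are my own assumptions for illustration; they are not taken from any library.

```python
import math


def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over a flattened mask (illustrative sketch)."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        # Clip the prediction so that log() never receives exactly 0.
        p = min(max(p, eps), 1.0 - eps)
        total += -y * math.log(p) - (1 - y) * math.log(1 - p)
    return total / len(y_true)


# Toy mask (assumed): 100 pixels, only the last one belongs to the object.
target = [0] * 99 + [1]

# Prediction 1: confidently "background" everywhere, so the object is missed.
all_background = [0.01] * 100

# Prediction 2: the object pixel is found, but the background is less certain.
finds_object = [0.2] * 99 + [0.8]

loss_background = bce(target, all_background)
loss_object = bce(target, finds_object)

print(f"misses the object: {loss_background:.4f}")
print(f"finds the object:  {loss_object:.4f}")
```

In this toy case the prediction that misses the object entirely gets the lower loss, because the 99 background pixels outweigh the single object pixel.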

An adaptation of vanilla BCE is weighted BCE, which weights positive pixels by some coefficient. It is heavily used in medical imaging (and other areas with highly skewed datasets). It is defined as follows:

$$L(y,\hat{y}) = -\beta y\log(\hat{y}) - (1-y)\log(1-\hat{y})$$

The $\beta$ parameter can be tuned: to reduce the number of false-negative pixels, set $\beta > 1$; to reduce the number of false positives, set $\beta < 1$.
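
As a rough illustration of how β shifts the balance, the sketch below applies the weighted BCE formula to a toy 100-pixel mask with a single object pixel. The function name `weighted_bce`, the toy data and the value β = 50 are assumptions chosen only to make the effect visible.

```python
import math


def weighted_bce(y_true, y_pred, beta=1.0, eps=1e-7):
    """Mean weighted BCE; beta scales the positive-pixel term (sketch)."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        # Clip the prediction so that log() never receives exactly 0.
        p = min(max(p, eps), 1.0 - eps)
        total += -beta * y * math.log(p) - (1 - y) * math.log(1 - p)
    return total / len(y_true)


# Toy mask (assumed): 100 pixels, only the last one belongs to the object.
target = [0] * 99 + [1]
missed = [0.01] * 100          # false negative on the object pixel
found = [0.2] * 99 + [0.8]     # object pixel predicted correctly

for beta in (1.0, 50.0):
    loss_missed = weighted_bce(target, missed, beta)
    loss_found = weighted_bce(target, found, beta)
    print(f"beta={beta}: missed={loss_missed:.4f}, found={loss_found:.4f}")
```

With β = 1 the prediction that misses the object still scores the lower loss; with β = 50 the false negative becomes expensive enough that the prediction which finds the object wins.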

Dice Coefficient

The Dice coefficient is a widely used measure of the similarity between two images and is closely related to the Intersection-over-Union metric. As such, it has been adapted into a loss function, the Dice Loss:

$$DL(y, \hat{y}) = 1 - \frac{2y\hat{y}+1}{y+\hat{y}+1}$$
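
A small worked example may help here. The sketch below evaluates the formula above on a toy 8-pixel mask, assuming that the products and sums are taken over all pixels of the flattened mask (the formula leaves this implicit). The function name `dice_loss` and the example masks are hypothetical.

```python
def dice_loss(y_true, y_pred, smooth=1.0):
    """Dice loss with the +1 smoothing term from the formula (sketch)."""
    # Overlap between target and prediction, summed over all pixels.
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    # Total "mass" of both masks.
    total = sum(y_true) + sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


# Toy 8-pixel mask (assumed) with a 2-pixel object.
target = [0, 0, 0, 0, 1, 1, 0, 0]
half_overlap = [0, 0, 0, 1, 1, 0, 0, 0]   # one of two object pixels hit
empty = [0] * 8                           # object missed entirely

loss_perfect = dice_loss(target, target)
loss_half = dice_loss(target, half_overlap)
loss_empty = dice_loss(target, empty)

# Padding both masks with extra background pixels leaves the loss unchanged.
loss_half_padded = dice_loss(target + [0] * 100, half_overlap + [0] * 100)

print(loss_perfect, loss_half, loss_empty, loss_half_padded)
```

Background pixels that are correctly predicted as 0 contribute nothing to either the numerator or the denominator, which is one reason the Dice Loss is often preferred for imbalanced masks.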

A common criticism of the Dice Loss is that its search space is non-convex. Several modifications have been made to make it more tractable for optimisation with methods such as L-BFGS and Stochastic Gradient Descent. The Dice Loss can be implemented in TensorFlow by subclassing tf.keras.losses.Loss as follows:

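import tensorflow as tf


class DiceLoss(tf.keras.losses.Loss):
    """Dice loss; gama is the exponent applied to each mask in the denominator."""

    def __init__(self, smooth=1e-6, gama=2):
        # Pass the loss name to the Keras base class.
        super().__init__(name='NDL')
        self.smooth = smooth  # avoids division by zero on empty masks
        self.gama = gama      # gama=1 gives the plain-sum denominator

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        # Overlap between prediction and target (the intersection term).
        numerator = 2 * tf.reduce_sum(y_pred * y_true) + self.smooth
        # Sum of the prediction and target masks, each raised to gama.
        denominator = (tf.reduce_sum(y_pred ** self.gama)
                       + tf.reduce_sum(y_true ** self.gama) + self.smooth)
        return 1 - numerator / denominator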

Shape-Aware Loss

The U-Net paper forced its fully convolutional network to learn the small separation borders between touching objects by using a pre-computed weight map for each ground-truth pixel. The weight map also compensates for the different frequency of pixels from each class in the training data set. The separation borders are computed using morphological operations, and the weight map is computed as:

$$w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 e^{-\frac{(d_1(\mathbf{x}) + d_2(\mathbf{x}))^2}{2\sigma^2}}$$

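The $d_1$ and $d_2$ functions give the distances to the border of the nearest and second-nearest cell, respectively. $w_c$ is a weight map that balances the classes of objects within an image and is tuned according to the class distribution, while $w_0$ and $\sigma$ are constants that control the strength and width of the border emphasis.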
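
To get a feel for the numbers, the sketch below evaluates the weight formula for two single pixels. The distances are supplied by hand rather than computed from a label image, the function name `unet_weight` is hypothetical, and the defaults `w0=10` and `sigma=5` follow the values reported in the U-Net paper.

```python
import math


def unet_weight(w_c, d1, d2, w0=10.0, sigma=5.0):
    """Per-pixel weight: class weight plus a border term (sketch)."""
    border_term = w0 * math.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    return w_c + border_term


# Pixel in a narrow gap between two cells: both borders are 1 pixel away.
weight_gap = unet_weight(w_c=1.0, d1=1.0, d2=1.0)

# Pixel far from any cell: the border term is effectively zero.
weight_far = unet_weight(w_c=1.0, d1=20.0, d2=30.0)

print(f"gap pixel: {weight_gap:.4f}")
print(f"far pixel: {weight_far:.4f}")
```

The gap pixel receives roughly ten times the weight of the far pixel, which is how the map pushes the network to get thin separation borders right.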

This weight term is then used in the typical cross-entropy loss, which results in the following loss function:

$$L(y, \hat{y}) = -w(\mathbf{x})\times \left[ y\log(\hat{y}) + (1-y)\log(1-\hat{y})\right]$$
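
A minimal sketch of this weighted loss on a flattened four-pixel mask follows. Averaging over pixels, the function name `weighted_pixel_loss`, and the hand-picked weights are assumptions made for illustration; in practice the weights would come from the pre-computed weight map.

```python
import math


def weighted_pixel_loss(y_true, y_pred, weights, eps=1e-7):
    """Cross-entropy with a per-pixel weight, averaged over pixels (sketch)."""
    total = 0.0
    for y, p, w in zip(y_true, y_pred, weights):
        # Clip the prediction so that log() never receives exactly 0.
        p = min(max(p, eps), 1.0 - eps)
        total += -w * (y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


# Toy data (assumed): the second pixel is a background border pixel
# between two objects, and the model is unsure about it (0.4).
target = [1, 0, 0, 1]
prediction = [0.9, 0.4, 0.1, 0.9]

uniform = weighted_pixel_loss(target, prediction, [1.0, 1.0, 1.0, 1.0])
emphasised = weighted_pixel_loss(target, prediction, [1.0, 10.0, 1.0, 1.0])

print(f"uniform weights:  {uniform:.4f}")
print(f"border weighted:  {emphasised:.4f}")
```

Raising the weight of the uncertain border pixel makes its error dominate the loss, so the optimiser is pushed to fix that pixel first.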
