Implementing Multiclass Dice Loss Function

I am doing multiclass segmentation with a U-Net. The input to the model is HxWxC, and the output layer is:

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

Using SparseCategoricalCrossentropy, I can train the network fine. Now I would also like to try the dice coefficient as the loss function. I implemented it as follows:

def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.math.sigmoid(y_pred)

    numerator = 2 * tf.reduce_sum(y_true * y_pred) + smooth
    denominator = tf.reduce_sum(y_true + y_pred) + smooth

    return 1 - numerator / denominator

However, the loss increases during training instead of decreasing. I have checked multiple sources, but all the material I find uses dice loss for binary classification, not multiclass. So my question is: is there a problem with the implementation?

Answer

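The problem is that your dice loss assumes the binary case and does not account for the number of classes you have, which may explain the increase in your loss. Your y_true holds integer class ids (that is what SparseCategoricalCrossentropy expects), so multiplying it directly by y_pred does not measure per-class overlap. Note also that the loss applies a sigmoid to outputs that have already passed through the sigmoid activation of the final layer.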
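
To make this concrete, here is a small NumPy sketch. It is an illustration under assumptions, not the asker's actual data: it assumes the labels are integer class ids of shape (H, W, 1), and it ports the loss from the question to NumPy under the hypothetical name `question_dice_loss`. Because the integer labels are broadcast against every class channel, a fully correct prediction and a fully wrong one should receive the same loss, so the value says nothing about which class was predicted.

```python
import numpy as np


def question_dice_loss(y_true, y_pred, smooth=1e-6):
    """NumPy port of the loss from the question (hypothetical name, illustration only)."""
    y_true = y_true.astype(np.float32)
    y_pred = 1.0 / (1.0 + np.exp(-y_pred))  # the extra sigmoid, as in the question
    numerator = 2 * np.sum(y_true * y_pred) + smooth
    denominator = np.sum(y_true + y_pred) + smooth
    return 1 - numerator / denominator


# Assumed toy data: a 2x2 image with 3 classes, labels stored as integer ids.
labels = np.array([[0, 1], [2, 1]])[..., np.newaxis]   # shape (H, W, 1)
correct = np.eye(3, dtype=np.float32)[labels[..., 0]]  # one-hot of the true classes
wrong = np.roll(correct, shift=1, axis=-1)             # every pixel gets a wrong class

loss_correct = question_dice_loss(labels, correct)
loss_wrong = question_dice_loss(labels, wrong)
print(loss_correct, loss_wrong)  # the two values coincide up to rounding
```

The loss is also well above zero for the fully correct prediction, which a proper dice loss would score close to zero.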

You should implement a multiclass version of dice loss that one-hot encodes the labels, accounts for all the classes, and combines them into a single value.

Something like the following:

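# Assumes the usual Keras backend alias; K is not defined in the original snippet.
from tensorflow.keras import backend as K

def dice_coef_9cat(y_true, y_pred, smooth=1e-7):
    '''
    Dice coefficient for 10 classes: background label 0 plus 9 foreground
    categories. The background channel is ignored.
    Pass to model as metric during compile statement
    '''
    # One-hot encode the integer labels, then drop the background channel.
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=10)[...,1:])
    y_pred_f = K.flatten(y_pred[...,1:])
    # Overlap and total mass, pooled over all foreground classes and pixels.
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))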

def dice_coef_9cat_loss(y_true, y_pred):
    '''
    Dice loss to minimize. Pass to model as loss during compile statement
    '''
    return 1 - dice_coef_9cat(y_true, y_pred)

This snippet is taken from https://github.com/keras-team/keras/issues/9395#issuecomment-370971561

The snippet is written for 10 classes (9 foreground categories plus background); change num_classes to match the number of classes in your data.
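
To avoid hard-coding the class count, the same idea can be wrapped in a small factory. The following is a minimal sketch, not the code from the linked issue: `make_dice_loss`, `num_classes`, and `ignore_background` are hypothetical names. It assumes `y_true` holds integer class ids of shape (batch, H, W) or (batch, H, W, 1), and that `y_pred` holds per-class probabilities of shape (batch, H, W, num_classes), for example from a softmax output layer rather than the sigmoid used in the question. Unlike the snippet above, it computes one dice score per class and averages them, so that small classes are not swamped by large ones.

```python
import tensorflow as tf


def make_dice_loss(num_classes, ignore_background=False, smooth=1e-6):
    """Build a multiclass dice loss for sparse integer labels (hypothetical helper)."""

    def dice_loss(y_true, y_pred):
        y_pred = tf.cast(y_pred, tf.float32)
        # Accept labels shaped (B, H, W) or (B, H, W, 1) by matching y_pred's leading dims.
        labels = tf.reshape(tf.cast(y_true, tf.int32), tf.shape(y_pred)[:-1])
        y_true_1h = tf.one_hot(labels, depth=num_classes, dtype=tf.float32)

        if ignore_background:
            # Assumption: class 0 is the background class.
            y_true_1h = y_true_1h[..., 1:]
            y_pred = y_pred[..., 1:]

        # Sum over batch and spatial axes, keeping the class axis.
        axes = [0, 1, 2]
        intersect = tf.reduce_sum(y_true_1h * y_pred, axis=axes)
        denom = tf.reduce_sum(y_true_1h + y_pred, axis=axes)

        # Smoothing in both terms keeps classes absent from the batch at a score of 1.
        dice_per_class = (2.0 * intersect + smooth) / (denom + smooth)
        return 1.0 - tf.reduce_mean(dice_per_class)

    return dice_loss
```

Under these assumptions it could be passed to the model as `loss=make_dice_loss(n_classes)` in the compile statement, where `n_classes` is the value used for the final Conv2D layer in the question.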
