Skip to content

Commit

Permalink
Modify Dice loss
Browse files Browse the repository at this point in the history
  • Loading branch information
zifuwanggg committed Oct 11, 2024
1 parent c98aa6b commit 1c77e98
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions training/loss_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
import torch.distributed
import torch.linalg as LA
import torch.nn as nn
import torch.nn.functional as F

Expand All @@ -20,6 +21,9 @@
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
"""
Compute the DICE loss, similar to generalized IOU for masks
Reference:
Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels.
    Wang, Z. et al. MICCAI 2023.
Args:
inputs: A float tensor of arbitrary shape.
The predictions for each example.
Expand All @@ -38,11 +42,11 @@ def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
# flatten spatial dimension while keeping multimask channel dimension
inputs = inputs.flatten(2)
targets = targets.flatten(2)
numerator = 2 * (inputs * targets).sum(-1)
else:
inputs = inputs.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
difference = LA.vector_norm(inputs - targets, ord=1, dim=-1)
numerator = (denominator - difference) / 2
loss = 1 - (numerator + 1) / (denominator + 1)
if loss_on_multimask:
return loss / num_objects
Expand Down

0 comments on commit 1c77e98

Please sign in to comment.