Skip to content

Commit

Permalink
Modify Jaccard, Dice and Tversky losses
Browse files Browse the repository at this point in the history
  • Loading branch information
zifuwanggg committed Sep 15, 2024
1 parent 41a6fe5 commit 58a0a8f
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions segmentation_models_pytorch/losses/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,7 @@ def soft_jaccard_score(
dims=None,
) -> torch.Tensor:
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)

union = cardinality - intersection
jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
jaccard_score = soft_tversky_score(output, target, 1.0, 1.0, smooth, eps, dims)
return jaccard_score


Expand All @@ -177,13 +169,7 @@ def soft_dice_score(
dims=None,
) -> torch.Tensor:
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
dice_score = soft_tversky_score(output, target, 0.5, 0.5, smooth, eps, dims)
return dice_score


Expand All @@ -196,15 +182,28 @@ def soft_tversky_score(
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
"""Tversky loss
References:
https://arxiv.org/pdf/2302.05666
https://arxiv.org/pdf/2303.16296
"""
assert output.size() == target.size()
if dims is not None:
intersection = torch.sum(output * target, dim=dims) # TP
fp = torch.sum(output * (1.0 - target), dim=dims)
fn = torch.sum((1 - output) * target, dim=dims)
difference = torch.norm(output - target, p=1, dim=dims)
output_sum = torch.sum(output, dim=dims)
target_sum = torch.sum(target, dim=dims)
intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection
else:
intersection = torch.sum(output * target) # TP
fp = torch.sum(output * (1.0 - target))
fn = torch.sum((1 - output) * target)
difference = torch.norm(output - target, p=1)
output_sum = torch.sum(output)
target_sum = torch.sum(target)
intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection

tversky_score = (intersection + smooth) / (
intersection + alpha * fp + beta * fn + smooth
Expand Down

0 comments on commit 58a0a8f

Please sign in to comment.