Skip to content

Commit

Permalink
Modify the Tversky loss
Browse files Browse the repository at this point in the history
  • Loading branch information
zifuwanggg committed Sep 18, 2024
1 parent 58a0a8f commit c4cbd1e
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions segmentation_models_pytorch/losses/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional

import torch
import torch.linalg as LA
import torch.nn.functional as F

__all__ = [
Expand Down Expand Up @@ -190,20 +191,14 @@ def soft_tversky_score(
"""
assert output.size() == target.size()
if dims is not None:
difference = torch.norm(output - target, p=1, dim=dims)
output_sum = torch.sum(output, dim=dims)
target_sum = torch.sum(target, dim=dims)
intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection
else:
difference = torch.norm(output - target, p=1)
output_sum = torch.sum(output)
target_sum = torch.sum(target)
intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection

output_sum = torch.sum(output, dim=dims)
target_sum = torch.sum(target, dim=dims)
difference = LA.vector_norm(output - target, ord=1, dim=dims)

intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection

tversky_score = (intersection + smooth) / (
intersection + alpha * fp + beta * fn + smooth
Expand Down

0 comments on commit c4cbd1e

Please sign in to comment.