diff --git a/egsis/metrics.py b/egsis/metrics.py index d8d94a5..6e8c605 100644 --- a/egsis/metrics.py +++ b/egsis/metrics.py @@ -25,7 +25,7 @@ def iou(y_true: np.ndarray, y_pred: np.ndarray) -> float: def f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: _check_if_they_are_binarized(y_true, y_pred) intersection = y_true & y_pred - total_size = y_true.size + y_pred.size + total_size = y_true.sum() + y_pred.sum() return (2 * intersection.sum()) / total_size diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 3151853..7cf57e3 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -32,7 +32,7 @@ def test_iou(x, y): def test_f1(x, y): - assert metrics.f1(x, y) == pytest.approx(0.11111, 0.001) + assert metrics.f1(x, y) == pytest.approx(0.33333, 0.001) def test_iou_raise_error_on_non_binarized_labels(x, y):