diff --git a/doctr/documents/elements.py b/doctr/documents/elements.py index 6553384ad5..5baa6b5df2 100644 --- a/doctr/documents/elements.py +++ b/doctr/documents/elements.py @@ -113,7 +113,7 @@ def __init__( if geometry is None: # Check whether this is a rotated or straight box box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 5 else resolve_enclosing_bbox - geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator] + geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator, misc] super().__init__(words=words) self.geometry = geometry @@ -149,7 +149,7 @@ def __init__( line_boxes = [word.geometry for line in lines for word in line.words] artefact_boxes = [artefact.geometry for artefact in artefacts] box_resolution_fn = resolve_enclosing_rbbox if len(lines[0].geometry) == 5 else resolve_enclosing_bbox - geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator] + geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator, arg-type] super().__init__(lines=lines, artefacts=artefacts) self.geometry = geometry diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index 84f5e1ad55..e20f4851e8 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. 
from typing import TYPE_CHECKING, List, Union

import numpy as np

if TYPE_CHECKING:
    # Project-local geometry aliases: only static type checkers need them here,
    # the annotations below are written as strings.
    from .common_types import BoundingBox, Polygon4P, RotatedBbox


def resolve_enclosing_bbox(
    bboxes: "Union[List[BoundingBox], np.ndarray]",
) -> "Union[BoundingBox, np.ndarray]":
    """Compute the straight box enclosing a collection of boxes.

    Two input layouts are supported:

    - an array of boxes of shape (*, 5), each row being
      (xmin, ymin, xmax, ymax, score)
    - a list of BoundingBox, i.e. pairs of (x, y) points

    Args:
        bboxes: the boxes to enclose, in one of the two layouts above

    Returns:
        for an array input, an array of shape (5,) holding
        (xmin, ymin, xmax, ymax, mean score); for a list input, a BoundingBox
        ((xmin, ymin), (xmax, ymax))

    Raises:
        ValueError: if `bboxes` is empty
    """
    # Fail early with a readable message instead of an unpacking / reduction error
    if len(bboxes) == 0:
        raise ValueError("cannot resolve the enclosing box of an empty collection of boxes")

    if isinstance(bboxes, np.ndarray):
        xmin, ymin, xmax, ymax, score = np.split(bboxes, 5, axis=1)
        # The enclosing box takes the extreme coordinates and the average score
        return np.array([xmin.min(), ymin.min(), xmax.max(), ymax.max(), score.mean()])

    x, y = zip(*[point for box in bboxes for point in box])
    return (min(x), min(y)), (max(x), max(y))


def box_ioa(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray:
    """Compute the IoA (intersection over area) between two sets of bounding boxes:
        ioa(i, j) = inter(i, j) / area(i)

    Boxes of `boxes_1` whose area is not strictly positive get an IoA of 0
    against every box, instead of a NaN / inf produced by a division by zero.

    Args:
        boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax)
        boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax)
    Returns:
        the IoA matrix of shape (N, M)
    """

    ioa_mat = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32)

    if boxes_1.shape[0] > 0 and boxes_2.shape[0] > 0:
        l1, t1, r1, b1 = np.split(boxes_1, 4, axis=1)
        l2, t2, r2, b2 = np.split(boxes_2, 4, axis=1)

        # Broadcast (N, 1) against (1, M) to get every pairwise intersection
        left = np.maximum(l1, l2.T)
        top = np.maximum(t1, t2.T)
        right = np.minimum(r1, r2.T)
        bot = np.minimum(b1, b2.T)

        # np.maximum(., 0) floors negative extents without relying on the
        # `np.Inf` alias, which no longer exists in NumPy 2.0
        intersection = np.maximum(right - left, 0) * np.maximum(bot - top, 0)
        area = (r1 - l1) * (b1 - t1)

        # Divide only where the reference area is positive
        has_area = area > 0
        safe_area = np.where(has_area, area, 1)
        ioa_mat = np.where(has_area, intersection / safe_area, 0.)

    return ioa_mat


def nms(boxes: np.ndarray, thresh: float = .5) -> List[int]:
    """Perform greedy non-max suppression (borrowed from an external reference
    implementation; the source link is missing from this copy).

    Boxes are visited by decreasing score; each visited box is kept and every
    remaining box whose IoU with it is strictly greater than `thresh` is dropped.
    A pair of boxes whose union area is not strictly positive has an IoU of 0.

    Args:
        boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score)
        thresh: iou threshold to perform box suppression.

    Returns:
        A list of box indexes to keep
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    scores = boxes[:, 4]

    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]

    keep: List[int] = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        # Convert to a builtin int so the result honours the List[int] contract
        keep.append(int(best))

        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])

        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h
        union = areas[best] + areas[rest] - inter

        # Guard the division so degenerate pairs yield 0 rather than NaN
        valid = union > 0
        ovr = np.where(valid, inter / np.where(valid, union, 1.), 0.)

        order = rest[ovr <= thresh]
    return keep


def test_resolve_enclosing_bbox():
    assert resolve_enclosing_bbox([((0, 0.5), (1, 0)), ((0.5, 0), (1, 0.25))]) == ((0, 0), (1, 0.5))
    pred = resolve_enclosing_bbox(np.array([[0.1, 0.1, 0.2, 0.2, 0.9], [0.15, 0.15, 0.2, 0.2, 0.8]]))
    # Compare element-wise: `a.all() == b.all()` only compares two booleans
    assert pred.shape == (5,)
    assert np.allclose(pred, np.array([0.1, 0.1, 0.2, 0.2, 0.85]))


def test_nms():
    boxes = [
        [0.1, 0.1, 0.2, 0.2, 0.95],
        [0.15, 0.15, 0.19, 0.2, 0.90],  # to suppress
        [0.5, 0.5, 0.6, 0.55, 0.90],
        [0.55, 0.5, 0.7, 0.55, 0.85],  # to suppress
    ]
    to_keep = nms(np.asarray(boxes), thresh=0.2)
    assert to_keep == [0, 2]


def test_box_ioa():
    boxes = [
        [0.1, 0.1, 0.2, 0.2],
        [0.15, 0.15, 0.2, 0.2],
    ]
    mat = box_ioa(np.array(boxes), np.array(boxes))
    assert mat[1, 0] == mat[0, 0] == mat[1, 1] == 1.
    assert abs(mat[0, 1] - .25) <= 1e-7