Feature/assume straight text #1723

Closed
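The diff below adds an assume_straight_text option to the OCR and KIE predictors: when pages are not assumed straight (assume_straight_pages=False) but the text inside the detected boxes is, crops are extracted with a new extract_dewarped_crops helper (a perspective warp of each quadrilateral box) and the crop-orientation classification/rectification step is skipped. A minimal usage sketch, assuming the ocr_predictor factory forwards the new keyword to OCRPredictor (the factory wiring is not part of this diff) and using a placeholder image path:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# Sketch only: boxes may be skewed quadrilaterals, but the text inside them is upright,
# so the crop-orientation model is neither loaded nor run.
predictor = ocr_predictor(
    pretrained=True,
    assume_straight_pages=False,
    assume_straight_text=True,
)
doc = DocumentFile.from_images("page.jpg")  # placeholder path
result = predictor(doc)
print(result.render())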
9 changes: 8 additions & 1 deletion doctr/models/kie_predictor/base.py
@@ -36,14 +36,21 @@ class _KIEPredictor(_OCRPredictor):
def __init__(
self,
assume_straight_pages: bool = True,
assume_straight_text: bool = False,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
assume_straight_pages,
assume_straight_text,
straighten_pages,
preserve_aspect_ratio,
symmetric_pad,
detect_orientation,
**kwargs,
)

self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
7 changes: 6 additions & 1 deletion doctr/models/kie_predictor/pytorch.py
@@ -29,6 +29,8 @@ class KIEPredictor(nn.Module, _KIEPredictor):
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
assume_straight_text: if True, speeds up the inference by assuming you only pass straight text
without rotated textual elements.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
@@ -44,6 +46,7 @@ def __init__(
det_predictor: DetectionPredictor,
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
assume_straight_text: bool = False,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
@@ -57,6 +60,7 @@ def __init__(
_KIEPredictor.__init__(
self,
assume_straight_pages,
assume_straight_text,
straighten_pages,
preserve_aspect_ratio,
symmetric_pad,
@@ -129,10 +133,11 @@ def forward(
dict_loc_preds[class_name],
channels_last=channels_last,
assume_straight_pages=self.assume_straight_pages,
assume_straight_text=self.assume_straight_text,
)
# Rectify crop orientation
crop_orientations: Any = {}
if not self.assume_straight_pages:
if not self.assume_straight_pages and not self.assume_straight_text:
for class_name in dict_loc_preds.keys():
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
crops[class_name], dict_loc_preds[class_name]
12 changes: 10 additions & 2 deletions doctr/models/kie_predictor/tensorflow.py
@@ -29,6 +29,8 @@ class KIEPredictor(NestedObject, _KIEPredictor):
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
assume_straight_text: if True, speeds up the inference by assuming you only pass straight text
without rotated textual elements.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
@@ -46,6 +48,7 @@ def __init__(
det_predictor: DetectionPredictor,
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
assume_straight_text: bool = False,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
@@ -58,6 +61,7 @@ def __init__(
_KIEPredictor.__init__(
self,
assume_straight_pages,
assume_straight_text,
straighten_pages,
preserve_aspect_ratio,
symmetric_pad,
@@ -122,12 +126,16 @@ def __call__(
crops = {}
for class_name in dict_loc_preds.keys():
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
pages,
dict_loc_preds[class_name],
channels_last=True,
assume_straight_pages=self.assume_straight_pages,
assume_straight_text=self.assume_straight_text,
)

# Rectify crop orientation
crop_orientations: Any = {}
if not self.assume_straight_pages:
if not self.assume_straight_pages and not self.assume_straight_text:
for class_name in dict_loc_preds.keys():
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
crops[class_name], dict_loc_preds[class_name]
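The KIE predictor receives the same flag in both the PyTorch and TensorFlow backends. A usage sketch, assuming the kie_predictor factory forwards the keyword (the factory itself is not touched in this diff) and using a placeholder image path:

from doctr.io import DocumentFile
from doctr.models import kie_predictor

# Sketch: rotated page layout, but upright text inside each predicted box.
predictor = kie_predictor(
    pretrained=True,
    assume_straight_pages=False,
    assume_straight_text=True,
)
doc = DocumentFile.from_images("form.jpg")  # placeholder path
result = predictor(doc)
for page in result.pages:
    print(list(page.predictions.keys()))  # one entry per detection class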
26 changes: 22 additions & 4 deletions doctr/models/predictor/base.py
@@ -8,7 +8,7 @@
import numpy as np

from doctr.models.builder import DocumentBuilder
from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image
from doctr.utils.geometry import extract_crops, extract_dewarped_crops, extract_rcrops, rotate_image

from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor, page_orientation_predictor
@@ -24,6 +24,8 @@ class _OCRPredictor:
----
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
assume_straight_text: if True, speeds up the inference by assuming you only pass straight text
without rotated textual elements.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
@@ -40,15 +42,21 @@ class _OCRPredictor:
def __init__(
self,
assume_straight_pages: bool = True,
assume_straight_text: bool = False,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
**kwargs: Any,
) -> None:
self.assume_straight_pages = assume_straight_pages
self.assume_straight_text = assume_straight_text
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
self.crop_orientation_predictor = (
None
if assume_straight_pages or (not assume_straight_pages and assume_straight_text)
else crop_orientation_predictor(pretrained=True)
)
self.page_orientation_predictor = (
page_orientation_predictor(pretrained=True)
if detect_orientation or straighten_pages or not assume_straight_pages
@@ -112,8 +120,15 @@ def _generate_crops(
loc_preds: List[np.ndarray],
channels_last: bool,
assume_straight_pages: bool = False,
assume_straight_text: bool = False,
) -> List[List[np.ndarray]]:
extraction_fn = extract_crops if assume_straight_pages else extract_rcrops
if assume_straight_pages:
extraction_fn = extract_crops
else:
if assume_straight_text:
extraction_fn = extract_dewarped_crops
else:
extraction_fn = extract_rcrops

crops = [
extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
@@ -127,8 +142,11 @@ def _prepare_crops(
loc_preds: List[np.ndarray],
channels_last: bool,
assume_straight_pages: bool = False,
assume_straight_text: bool = False,
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
crops = _OCRPredictor._generate_crops(
pages, loc_preds, channels_last, assume_straight_pages, assume_straight_text
)

# Avoid sending zero-sized crops
is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
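Restated as a standalone sketch, _generate_crops now dispatches between three extraction strategies (this mirrors the branch above; it is not library code):

from doctr.utils.geometry import extract_crops, extract_dewarped_crops, extract_rcrops


def pick_extraction_fn(assume_straight_pages: bool, assume_straight_text: bool):
    # Mirrors the branch added to _OCRPredictor._generate_crops
    if assume_straight_pages:
        return extract_crops  # axis-aligned boxes: plain slicing
    if assume_straight_text:
        return extract_dewarped_crops  # skewed boxes, straight text: perspective dewarp
    return extract_rcrops  # rotated boxes: affine crop, then orientation rectification

Consistently with this dispatch, crop_orientation_predictor is only instantiated in the extract_rcrops case: the new condition assume_straight_pages or (not assume_straight_pages and assume_straight_text) reduces to assume_straight_pages or assume_straight_text.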
7 changes: 6 additions & 1 deletion doctr/models/predictor/pytorch.py
@@ -29,6 +29,8 @@ class OCRPredictor(nn.Module, _OCRPredictor):
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
assume_straight_text: if True, speeds up the inference by assuming you only pass straight text
without rotated textual elements.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
@@ -44,6 +46,7 @@ def __init__(
det_predictor: DetectionPredictor,
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
assume_straight_text: bool = False,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
@@ -57,6 +60,7 @@ def __init__(
_OCRPredictor.__init__(
self,
assume_straight_pages,
assume_straight_text,
straighten_pages,
preserve_aspect_ratio,
symmetric_pad,
@@ -123,10 +127,11 @@ def forward(
loc_preds,
channels_last=channels_last,
assume_straight_pages=self.assume_straight_pages,
assume_straight_text=self.assume_straight_text,
)
# Rectify crop orientation and get crop orientation predictions
crop_orientations: Any = []
if not self.assume_straight_pages:
if not self.assume_straight_pages and not self.assume_straight_text:
crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
crop_orientations = [
{"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
2 changes: 2 additions & 0 deletions doctr/models/predictor/tensorflow.py
@@ -29,6 +29,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
assume_straight_text: if True, speeds up the inference by assuming you only pass straight text
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
@@ -46,6 +47,7 @@ def __init__(
det_predictor: DetectionPredictor,
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
assume_straight_text: bool = False,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
86 changes: 85 additions & 1 deletion doctr/utils/geometry.py
@@ -458,6 +458,8 @@ def extract_rcrops(
_boxes[:, :, 0] *= width
_boxes[:, :, 1] *= height

src_img = img if channels_last else img.transpose(1, 2, 0)

src_pts = _boxes[:, :3].astype(np.float32)
# Preserve size
d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1)
@@ -469,11 +471,93 @@
# Use a warp transformation to extract the crop
crops = [
cv2.warpAffine(
img if channels_last else img.transpose(1, 2, 0),
src_img,
# Transformation matrix
cv2.getAffineTransform(src_pts[idx], dst_pts[idx]),
(int(d1[idx]), int(d2[idx])),
)
for idx in range(_boxes.shape[0])
]
return crops # type: ignore[return-value]


def extract_dewarped_crops(
img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
) -> List[np.ndarray]:
"""Created cropped images from list of skewed/warped bounding boxes,
but containing straight text

Args:
----
img: input image
polys: bounding boxes of shape (N, 4, 2)
dtype: target data type of bounding boxes
channels_last: whether the channel dimension is the last one instead of the first one

Returns:
-------
list of cropped images
"""
if polys.shape[0] == 0:
return []
if polys.shape[1:] != (4, 2):
raise AssertionError("polys are expected to be quadrilateral, of shape (N, 4, 2)")

# Project relative coordinates
_boxes = polys.copy()
height, width = img.shape[:2] if channels_last else img.shape[-2:]
if not np.issubdtype(_boxes.dtype, np.integer):
_boxes[:, :, 0] *= width
_boxes[:, :, 1] *= height

src_img = img if channels_last else img.transpose(1, 2, 0)

crops = []

for box in _boxes:
# Sort the points according to the x-axis
box_points = box[np.argsort(box[:, 0])]

# Divide the points into left and right
left_points = box_points[:2]
right_points = box_points[2:]

# Sort the left points according to the y-axis
left_points = left_points[np.argsort(left_points[:, 1])]
# Sort the right points according to the y-axis
right_points = right_points[np.argsort(right_points[:, 1])]
box_points = np.concatenate([left_points, right_points])

# Get the width and height of the rectangle that will contain the warped quadrilateral
# Designate the width and height based on maximum side of the quadrilateral
width_upper = np.linalg.norm(box_points[0] - box_points[2])
width_lower = np.linalg.norm(box_points[1] - box_points[3])
height_left = np.linalg.norm(box_points[0] - box_points[1])
height_right = np.linalg.norm(box_points[2] - box_points[3])

# Get the maximum width and height
rect_width = int(max(width_upper, width_lower))
rect_height = int(max(height_left, height_right))

dst_pts = np.array(
[
[0, 0], # top-left
# bottom-left
[0, rect_height - 1],
# top-right
[rect_width - 1, 0],
# bottom-right
[rect_width - 1, rect_height - 1],
],
dtype=dtype,
)

# Get the perspective transform matrix using the box points
affine_mat = cv2.getPerspectiveTransform(box_points.astype(np.float32), dst_pts)

# Perform the perspective warp to get the rectified crop
crop = cv2.warpPerspective(src_img, affine_mat, (rect_width, rect_height))

# Add the crop to the list of crops
crops.append(crop)
return crops # type: ignore[return-value]
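A quick standalone check of the new helper on a synthetic skewed quadrilateral; the image path is a placeholder, and relative coordinates are scaled to the page size internally, as in extract_rcrops:

import numpy as np

from doctr.io import DocumentFile
from doctr.utils.geometry import extract_dewarped_crops

page = DocumentFile.from_images("scan.jpg")[0]  # placeholder path; HWC uint8 page
# One slightly skewed quadrilateral in relative coordinates, shape (N, 4, 2)
polys = np.array([[[0.10, 0.10], [0.42, 0.12], [0.43, 0.20], [0.11, 0.18]]], dtype=np.float32)
crops = extract_dewarped_crops(page, polys)
print(crops[0].shape)  # (rect_height, rect_width, 3) after the perspective warp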
34 changes: 34 additions & 0 deletions tests/common/test_utils_geometry.py
@@ -266,3 +266,37 @@ def test_extract_rcrops(mock_pdf):

# No box
assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2))) == []


def test_extract_dewarped_crops(mock_pdf):
doc_img = DocumentFile.from_pdf(mock_pdf)[0]
num_crops = 2
rel_boxes = np.array(
[
[
[idx / num_crops, idx / num_crops],
[idx / num_crops + 0.1, idx / num_crops],
[idx / num_crops + 0.1, idx / num_crops + 0.1],
[idx / num_crops, idx / num_crops],
]
for idx in range(num_crops)
],
dtype=np.float32,
)
abs_boxes = deepcopy(rel_boxes)
abs_boxes[:, :, 0] *= doc_img.shape[1]
abs_boxes[:, :, 1] *= doc_img.shape[0]
abs_boxes = abs_boxes.astype(np.int64)

with pytest.raises(AssertionError):
geometry.extract_dewarped_crops(doc_img, np.zeros((1, 8)))
for boxes in (rel_boxes, abs_boxes):
croped_imgs = geometry.extract_dewarped_crops(doc_img, boxes)
# Number of crops
assert len(croped_imgs) == num_crops
# Data type and shape
assert all(isinstance(crop, np.ndarray) for crop in croped_imgs)
assert all(crop.ndim == 3 for crop in croped_imgs)

# No box
assert geometry.extract_dewarped_crops(doc_img, np.zeros((0, 4, 2))) == []