Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/assume straight text #1723

Closed

Conversation

milosacimovic
Copy link
Contributor

Modifying the ocr_predictor API to support assume_straight_text as an argument.
When used with assume_straight_pages=False this reduces the reliance on an unreliable crop orientation model when the text is almost straight and additionally reduces speed of execution. It should alleviate the issue mentioned in #1455 if the use-case is one where the text is straight i.e. no rotations of 90, 180 and 270 degrees.

The main contribution to the pipeline is the logic around a new geometry function which extracts the crops while dewarping the images based on the corners of the text detection, which returns polygons (when assume_straight_pages=False and assume_straight_text=True).

… reduces the relience on unreliable crop orientation models and reduces speed of execution
Copy link

codecov bot commented Sep 13, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.45%. Comparing base (9045dcf) to head (f1128b7).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1723      +/-   ##
==========================================
+ Coverage   96.40%   96.45%   +0.05%     
==========================================
  Files         164      164              
  Lines        7782     7818      +36     
==========================================
+ Hits         7502     7541      +39     
+ Misses        280      277       -3     
Flag Coverage Δ
unittests 96.45% <100.00%> (+0.05%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@felixdittrich92 felixdittrich92 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @milosacimovic 👋
Thanks a lot for the quick PR 👍

A problem i see is the increasing complexity with ocr_predictor / kie_predictor .
From my experience adding another assume_ kwarg would users more and more confuse.
Additional it makes it only possible to disable the crop_orientation_predictor actually.

So two suggestions from my view:

Option 1:
Advantages:

  • We can avoid that the orientation models needs to be initialized
  • Clear about what it does

Disadvantages:

  • Needs also modifications in demo and api
  • 2 additional ocr_predictor / kie_predictor args (could also be passed as kwargs maybe !?)

In this case disable_page_orientation has only an effect in combination with assume_straight_pages=False and or detect_orientation=True and or straighten_pages=True where it then can handle only small rotations in the range between ~ -45 and 45 degrees
And disable_crop_orientationwould have only an effect with assume_straight_pages=False so that it everytime results in a "prediction" of 0 and 1.0 as probability (or None)

model = ocr_predictor(
    pretrained=True,
    assume_straight_pages=False,
    straighten_pages=True,
    detect_orientation=True,
    disable_page_orientation=True, # maybe as kwarg ? Then can handle only small rotations
    disable_crop_orientation=True, # maybe as kwarg ? Then returns always 0 and as prob 1.0 or None (prefered None ?)
)

Option 2:
Advantages:

  • Encapsulated from ocr_predictor and cleaner handling by specific predictor

Disadvantages:

  • Needs also modifications in demo and api
  • Possible no way that the orientation models are loaded once into RAM before removal
predictor = ocr_predictor(
    pretrained=True,
    assume_straight_pages=False,
    straighten_pages=True,
    detect_orientation=True,
)

# Overwrite the orientation models - disable
predictor.crop_orientation_predictor = crop_orientation_predictor(disabled=True)
predictor.page_orientation_predictor = page_orientation_predictor(disabled=True)

pseudo code:

def _orientation_predictor(arch: Any, pretrained: bool, model_type: str, **kwargs: Any) -> OrientationPredictor:

if kwargs.get("disabled", False):
     return OrientationPredictor(None, None)

class OrientationPredictor(nn.Module):

class OrientationPredictor(nn.Module):

    def __init__(
        self,
        pre_processor: Optional[PreProcessor] = None,
        model: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.pre_processor = pre_processor
        self.model = model.eval() if model else model

    @torch.inference_mode()
    def forward(
        self,
        inputs: List[Union[np.ndarray, torch.Tensor]],
    ) -> List[Union[List[int], List[float]]]:
        # Dimension check
        if any(input.ndim != 3 for input in inputs):
            raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

        if model is None:
            in_length = len(inputs)
            return [[0] * in_length, [0] * in_length, [1.0] in_length]

        processed_batches = self.pre_processor(inputs)
        _params = next(self.model.parameters())
        self.model, processed_batches = set_device_and_dtype(
            self.model, processed_batches, _params.device, _params.dtype
        )
        predicted_batches = [self.model(batch) for batch in processed_batches]
        # confidence
        probs = [
            torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
        ]
        # Postprocess predictions
        predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

        class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
        confs = [round(float(p), 2) for prob in probs for p in prob]

        return [class_idxs, classes, confs]

Only quick and dirty to "visualize" the idea i have in mind 😅

If we could realize option 1 and use kwargs to pass these values i think i would orefer the first idea.

Additional in every case we need some entry in the documentation for the new logic.

@milosacimovic @odulcy-mindee wdyt ? 🤗

@felixdittrich92
Copy link
Contributor

felixdittrich92 commented Sep 13, 2024

So i totally agree with the feature but we should take care of both the crop and page orientation predictors and we should take care not to miss stuff:

😃

Converting to draft in the meanwhile 👍

@felixdittrich92 felixdittrich92 self-assigned this Sep 13, 2024
@felixdittrich92 felixdittrich92 added topic: documentation Improvements or additions to documentation module: models Related to doctr.models module: utils Related to doctr.utils ext: tests Related to tests folder ext: demo Related to demo folder ext: api Related to api folder framework: pytorch Related to PyTorch backend framework: tensorflow Related to TensorFlow backend topic: character classification Related to the task of character classification type: new feature New feature ext: docs Related to docs folder labels Sep 13, 2024
@felixdittrich92 felixdittrich92 added this to the 0.10.0 milestone Sep 13, 2024
@felixdittrich92 felixdittrich92 marked this pull request as draft September 13, 2024 06:47
@milosacimovic
Copy link
Contributor Author

Hi @felixdittrich92 ,
Thank you so much for considering my PR so quickly and for the immensely valuable feedback on the API changes.
I will look into the first option. However, what I would still like to get from you is your thoughts on the subtle differences between

def extract_dewarped_crops(
    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
) -> List[np.ndarray]:
    """Created cropped images from list of skewed/warped bounding boxes,
    but containing straight text

and currently used

def extract_rcrops(
    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
) -> List[np.ndarray]:
    """Created cropped images from list of rotated bounding boxes

From my experience the extract_rcrops has issues when extracting crops from slightly rotated documents (-45, 45) where it rotates the crops even though it should keep them straight.

This was actually my main complaint about the current implementation.

@felixdittrich92
Copy link
Contributor

Hi @felixdittrich92 , Thank you so much for considering my PR so quickly and for the immensely valuable feedback on the API changes. I will look into the first option. However, what I would still like to get from you is your thoughts on the subtle differences between

def extract_dewarped_crops(
    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
) -> List[np.ndarray]:
    """Created cropped images from list of skewed/warped bounding boxes,
    but containing straight text

and currently used

def extract_rcrops(
    img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True
) -> List[np.ndarray]:
    """Created cropped images from list of rotated bounding boxes

From my experience the extract_rcrops has issues when extracting crops from slightly rotated documents (-45, 45) where it rotates the crops even though it should keep them straight.

This was actually my main complaint about the current implementation.

I will take a look into asap 👍 But all the stuff points to the same issue so we can combine both in your PR 👍

@felixdittrich92
Copy link
Contributor

@milosacimovic Tested your function and yeah it works better for smaller rotated pages (between -45 and 45).
It's also a bit slower but not as much.
I think that's something to combine:

With disable_page_orientation=True (where we expect only small rotated pages) + your function
Otherwise: the current function

Wdyt ?

@felixdittrich92
Copy link
Contributor

felixdittrich92 commented Sep 19, 2024

Hi @milosacimovic 👋,

I quickly prototyped this feature.
Wdyt about the changes: main...felixdittrich92:doctr:disable-orient-prototype ?

Feel free to test and update your PR with my changes if everything works as expected then only the docs part (maybe some optimizations from a users view 😅) and maybe additional tests + mypy/format fixes would be open 🤗

@felixdittrich92
Copy link
Contributor

#1735

@felixdittrich92 felixdittrich92 linked an issue Sep 27, 2024 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ext: api Related to api folder ext: demo Related to demo folder ext: docs Related to docs folder ext: tests Related to tests folder framework: pytorch Related to PyTorch backend framework: tensorflow Related to TensorFlow backend module: models Related to doctr.models module: utils Related to doctr.utils topic: character classification Related to the task of character classification topic: documentation Improvements or additions to documentation type: new feature New feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants