Skip to content

Commit

Permalink
[orientation] Enable usage of custom trained orientation models (#1708)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Aug 29, 2024
1 parent 4434213 commit 9045dcf
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 14 deletions.
73 changes: 72 additions & 1 deletion docs/source/using_doctr/custom_models_training.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Train your own model
====================

If the pretrained models don't meet your specific needs, you have the option to train your own model using the doctr library.
If the pretrained models don't meet your specific needs, you have the option to train your own model using the docTR library.
For details on the training process and the necessary data and data format, refer to the following links:

- `detection <https://github.com/mindee/doctr/tree/main/references/detection#readme>`_
Expand Down Expand Up @@ -203,3 +203,74 @@ Load a model with customized Preprocessor:
)
predictor = OCRPredictor(det_predictor, reco_predictor)
Custom orientation classification models
----------------------------------------

If you work with rotated documents and make use of the orientation classification feature by passing one of the following arguments:

* `assume_straight_pages=False`
* `detect_orientation=True`
* `straigten_pages=True`

You can train your own orientation classification model using the docTR library. For details on the training process and the necessary data and data format, refer to the following link:

- `orientation <https://github.com/mindee/doctr/blob/main/references/classification/README.md#usage-orientation-classification>`_

**NOTE**: Currently we support only `mobilenet_v3_small` models for crop and page orientation classification.

Loading your custom trained orientation classification model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. tabs::

.. tab:: TensorFlow

.. code:: python3
from doctr.io import DocumentFile
from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation
from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False)
custom_page_orientation_model.load_weights("<path_to_checkpoint>/weights")
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False)
custom_crop_orientation_model.load_weights("<path_to_checkpoint>/weights")
predictor = ocr_predictor(
pretrained=True,
assume_straight_pages=False,
straighten_pages=True,
detect_orientation=True,
)
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)
.. tab:: PyTorch

.. code:: python3
import torch
from doctr.io import DocumentFile
from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation
from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False)
page_params = torch.load('<path_to_pt>', map_location="cpu")
custom_page_orientation_model.load_state_dict(page_params)
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False)
crop_params = torch.load('<path_to_pt>', map_location="cpu")
custom_crop_orientation_model.load_state_dict(crop_params)
predictor = ocr_predictor(
pretrained=True,
assume_straight_pages=False,
straighten_pages=True,
detect_orientation=True,
)
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)
4 changes: 3 additions & 1 deletion doctr/datasets/vocabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@
VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"]
VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"]
VOCABS["ukrainian"] = VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴"
VOCABS["ukrainian"] = (
VOCABS["generic_cyrillic_letters"] + VOCABS["digits"] + VOCABS["punctuation"] + VOCABS["currency"] + "ґіїєҐІЇЄ₴"
)
VOCABS["multilingual"] = "".join(
dict.fromkeys(
VOCABS["french"]
Expand Down
2 changes: 2 additions & 0 deletions doctr/models/classification/mobilenet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from typing import Any, Dict, List, Optional

from torchvision.models import mobilenetv3
from torchvision.models.mobilenetv3 import MobileNetV3

from doctr.datasets import VOCABS

from ...utils import load_pretrained_params

__all__ = [
"MobileNetV3",
"mobilenet_v3_small",
"mobilenet_v3_small_r",
"mobilenet_v3_large",
Expand Down
26 changes: 16 additions & 10 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,21 @@
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
def _orientation_predictor(arch: Any, pretrained: bool, model_type: str, **kwargs: Any) -> OrientationPredictor:
if isinstance(arch, str):
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

# Load directly classifier from backbone
_model = classification.__dict__[arch](pretrained=pretrained)
else:
if not isinstance(arch, classification.MobileNetV3):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

# Load directly classifier from backbone
_model = classification.__dict__[arch](pretrained=pretrained)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
Expand All @@ -51,7 +57,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient


def crop_orientation_predictor(
arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Crop orientation classification architecture.
Expand All @@ -71,11 +77,11 @@ def crop_orientation_predictor(
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, **kwargs)
return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)


def page_orientation_predictor(
arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Page orientation classification architecture.
Expand All @@ -95,4 +101,4 @@ def page_orientation_predictor(
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, **kwargs)
return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
4 changes: 2 additions & 2 deletions doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


AVAILABLE_ARCHS = {
"classification": models.classification.zoo.ARCHS,
"classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
"detection": models.detection.zoo.ARCHS,
"recognition": models.recognition.zoo.ARCHS,
}
Expand Down Expand Up @@ -174,7 +174,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #

local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)
repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)

with repo.commit(commit_message):
_save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
Expand Down
18 changes: 18 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ def test_crop_orientation_model(mock_text_box):
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])

# Test custom model loading
classifier = classification.crop_orientation_predictor(
classification.mobilenet_v3_small_crop_orientation(pretrained=True)
)
assert isinstance(classifier, OrientationPredictor)

with pytest.raises(ValueError):
_ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True))


def test_page_orientation_model(mock_payslip):
text_box_0 = cv2.imread(mock_payslip)
Expand All @@ -147,6 +156,15 @@ def test_page_orientation_model(mock_payslip):
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])

# Test custom model loading
classifier = classification.page_orientation_predictor(
classification.mobilenet_v3_small_page_orientation(pretrained=True)
)
assert isinstance(classifier, OrientationPredictor)

with pytest.raises(ValueError):
_ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True))


@pytest.mark.parametrize(
"arch_name, input_shape, output_size",
Expand Down
38 changes: 38 additions & 0 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from doctr.io import Document, DocumentFile
from doctr.io.elements import KIEDocument
from doctr.models import detection, recognition
from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation
from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.detection.zoo import detection_predictor
from doctr.models.kie_predictor import KIEPredictor
Expand Down Expand Up @@ -85,6 +87,24 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
orientation = 0
assert out.pages[0].orientation["value"] == orientation

# Test with custom orientation models
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True)
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True)

if assume_straight_pages:
if predictor.detect_orientation or predictor.straighten_pages:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)
else:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)

out = predictor(doc)
orientation = 0
assert out.pages[0].orientation["value"] == orientation


def test_trained_ocr_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)
Expand Down Expand Up @@ -209,6 +229,24 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
orientation = 0
assert out.pages[0].orientation["value"] == orientation

# Test with custom orientation models
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True)
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True)

if assume_straight_pages:
if predictor.detect_orientation or predictor.straighten_pages:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)
else:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)

out = predictor(doc)
orientation = 0
assert out.pages[0].orientation["value"] == orientation


def test_trained_kie_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)
Expand Down
18 changes: 18 additions & 0 deletions tests/tensorflow/test_models_classification_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def test_crop_orientation_model(mock_text_box):
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])

# Test custom model loading
classifier = classification.crop_orientation_predictor(
classification.mobilenet_v3_small_crop_orientation(pretrained=True)
)
assert isinstance(classifier, OrientationPredictor)

with pytest.raises(ValueError):
_ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True))


def test_page_orientation_model(mock_payslip):
text_box_0 = cv2.imread(mock_payslip)
Expand All @@ -126,6 +135,15 @@ def test_page_orientation_model(mock_payslip):
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])

# Test custom model loading
classifier = classification.page_orientation_predictor(
classification.mobilenet_v3_small_page_orientation(pretrained=True)
)
assert isinstance(classifier, OrientationPredictor)

with pytest.raises(ValueError):
_ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True))


# temporarily fix to avoid killing the CI (tf2onnx v1.14 memory leak issue)
# ref.: https://github.com/mindee/doctr/pull/1201
Expand Down
38 changes: 38 additions & 0 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from doctr.io import Document, DocumentFile
from doctr.io.elements import KIEDocument
from doctr.models import detection, recognition
from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation
from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.detection.zoo import detection_predictor
from doctr.models.kie_predictor import KIEPredictor
Expand Down Expand Up @@ -84,6 +86,24 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
language = "unknown"
assert out.pages[0].language["value"] == language

# Test with custom orientation models
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True)
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True)

if assume_straight_pages:
if predictor.detect_orientation or predictor.straighten_pages:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)
else:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)

out = predictor(doc)
orientation = 0
assert out.pages[0].orientation["value"] == orientation


def test_trained_ocr_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)
Expand Down Expand Up @@ -207,6 +227,24 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
language = "unknown"
assert out.pages[0].language["value"] == language

# Test with custom orientation models
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=True)
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=True)

if assume_straight_pages:
if predictor.detect_orientation or predictor.straighten_pages:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)
else:
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)

out = predictor(doc)
orientation = 0
assert out.pages[0].orientation["value"] == orientation


def test_trained_kie_predictor(mock_payslip):
doc = DocumentFile.from_images(mock_payslip)
Expand Down

0 comments on commit 9045dcf

Please sign in to comment.