From 683904c5672f2dc93d3c7b31978170b56270b1df Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 22 Mar 2024 12:10:40 +0100 Subject: [PATCH] bench + reparam --- docs/source/using_doctr/using_models.rst | 2 +- doctr/models/detection/fast/base.py | 2 +- doctr/models/detection/fast/pytorch.py | 2 +- doctr/models/detection/fast/tensorflow.py | 2 +- doctr/models/detection/zoo.py | 4 ++++ 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 850c7725d6..5086b19769 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -67,7 +67,7 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl +----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+ | PyTorch | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 81.78 | 82.47 | 87.29 | 85.54 | 1.0 | +----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+ -| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M (8.5M) | | | | | 0.7 (0.4) | +| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M (8.5M) | 84.90 | 85.04 | 93.73 | 76.26 | 0.7 (0.4) | +----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+ | PyTorch | fast_small | (1024, 1024, 3) | 14.7 M (9.7M) | | | | | 0.7 (0.5) | +----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+ diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py index 10ce7d3029..868c3eadec 100644 --- a/doctr/models/detection/fast/base.py +++ b/doctr/models/detection/fast/base.py @@ -31,7 +31,7 @@ class FASTPostProcessor(DetectionPostProcessor): def __init__( self, - bin_thresh: float = 0.3, + bin_thresh: float = 0.1, box_thresh: float = 0.1, assume_straight_pages: bool = True, ) -> None: diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py index 1108e74a89..c07977474d 100644 --- a/doctr/models/detection/fast/pytorch.py +++ b/doctr/models/detection/fast/pytorch.py @@ -119,7 +119,7 @@ class FAST(_FAST, nn.Module): def __init__( self, feat_extractor: IntermediateLayerGetter, - bin_thresh: float = 0.3, + bin_thresh: float = 0.1, box_thresh: float = 0.1, dropout_prob: float = 0.1, pooling_size: int = 4, # different from paper performs better on close text-rich images diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 2186a984fc..f0934e99c2 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -122,7 +122,7 @@ class FAST(_FAST, keras.Model, NestedObject): def __init__( self, feature_extractor: IntermediateLayerGetter, - bin_thresh: float = 0.3, + bin_thresh: float = 0.1, box_thresh: float = 0.1, dropout_prob: float = 0.1, pooling_size: int = 4, # different from paper performs better on close text-rich images diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 2097be2f0c..45cbc1adc5 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -8,6 +8,7 @@ from doctr.file_utils import is_tf_available, is_torch_available from .. import detection +from ..detection.fast import reparameterize from ..preprocessor import PreProcessor from .predictor import DetectionPredictor @@ -51,6 +52,9 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, pretrained_backbone=kwargs.get("pretrained_backbone", True), assume_straight_pages=assume_straight_pages, ) + # Reparameterize FAST models by default to lower inference latency and memory usage + if isinstance(_model, detection.FAST): + _model = reparameterize(_model) else: if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)): raise ValueError(f"unknown architecture: {type(arch)}")