From 1d1ab413c8651e16b843ad641bf922434cfcad95 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Tue, 10 Oct 2023 09:09:25 +0800 Subject: [PATCH 1/7] convert tensor to float before calling .numpy() --- .../detection/differentiable_binarization/pytorch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index df47f34d08..708e597838 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -208,11 +208,16 @@ def forward( if return_model_output: out["out_map"] = prob_map + def need_conversion_to_float(dtype): + # pytorch: torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + return dtype in [torch.bfloat16] + + numpy_dtype_converter = lambda x: x.float() if need_conversion_to_float(x.dtype) else x if target is None or return_preds: # Post-process boxes (keep only text predictions) out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + for preds in self.postprocessor(numpy_dtype_converter(prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy()) ] if target is not None: From effeb061326713bfdf90d393c6e9ca865cc28e9b Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Tue, 10 Oct 2023 14:08:41 +0800 Subject: [PATCH 2/7] simplify dtype check --- .../detection/differentiable_binarization/pytorch.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 708e597838..0e423928fb 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -208,11 +208,8 @@ def forward( if return_model_output: out["out_map"] = prob_map - def need_conversion_to_float(dtype): - # pytorch: torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype - return dtype in [torch.bfloat16] - - numpy_dtype_converter = lambda x: x.float() if need_conversion_to_float(x.dtype) else x + # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + numpy_dtype_converter = lambda x: x.float() if x.dtype in [torch.bfloat16] else x if target is None or return_preds: # Post-process boxes (keep only text predictions) out["preds"] = [ From 25e60bbcaf59b111e1e0376feaec3be91680d3af Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Wed, 11 Oct 2023 20:38:20 +0800 Subject: [PATCH 3/7] make numpy_dtype_converter a common util func --- .../differentiable_binarization/pytorch.py | 8 ++++---- doctr/models/utils/pytorch.py | 14 +++++++++++++- tests/pytorch/test_models_utils_pt.py | 14 +++++++++++++- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 0e423928fb..596790fba8 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -16,7 +16,7 @@ from doctr.file_utils import CLASS_NAME from ...classification import mobilenet_v3_large -from ...utils import load_pretrained_params +from ...utils import load_pretrained_params, numpy_dtype_converter from .base import DBPostProcessor, _DBNet __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"] @@ -208,13 +208,13 @@ def forward( if return_model_output: out["out_map"] = prob_map - # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype - numpy_dtype_converter = lambda x: x.float() if x.dtype in [torch.bfloat16] else x if target is None or return_preds: # Post-process boxes (keep only text predictions) out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor(numpy_dtype_converter(prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy()) + for preds in self.postprocessor( + numpy_dtype_converter(prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy() + ) ] if target is not None: diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index b24030ca17..a71f1b8e14 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -11,7 +11,14 @@ from doctr.utils.data import download_from_url -__all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor"] +__all__ = [ + "load_pretrained_params", + "conv_sequence_pt", + "set_device_and_dtype", + "export_model_to_onnx", + "_copy_tensor", + "numpy_dtype_converter", +] def _copy_tensor(x: torch.Tensor) -> torch.Tensor: @@ -150,3 +157,8 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T ) logging.info(f"Model exported to {model_name}.onnx") return f"{model_name}.onnx" + + +def numpy_dtype_converter(input): + # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + return input.float() if input.dtype in [torch.bfloat16] else input diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py index 12b644a1b7..94f6c68cf0 100644 --- a/tests/pytorch/test_models_utils_pt.py +++ b/tests/pytorch/test_models_utils_pt.py @@ -4,7 +4,13 @@ import torch from torch import nn -from doctr.models.utils import _copy_tensor, conv_sequence_pt, load_pretrained_params, set_device_and_dtype +from doctr.models.utils import ( + _copy_tensor, + conv_sequence_pt, + load_pretrained_params, + numpy_dtype_converter, + set_device_and_dtype, +) def test_copy_tensor(): @@ -52,3 +58,9 @@ def test_set_device_and_dtype(): model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16) assert model[0].weight.dtype == torch.float16 assert batches[0].dtype == torch.float16 + + +def test_numpy_dtype_converter(): + input = torch.randn([2, 2], dtype=torch.bfloat16) + converted_input = numpy_dtype_converter(input) + assert converted_input.dtype == torch.float32 From 7b49c20780a7f27b6a34e19c5cfba48c2b3b46b4 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Wed, 11 Oct 2023 20:42:10 +0800 Subject: [PATCH 4/7] add fix for detection/linknet --- doctr/models/detection/linknet/pytorch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index f40c4018f6..c31a02ecd3 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -14,7 +14,7 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from ...utils import load_pretrained_params +from ...utils import load_pretrained_params, numpy_dtype_converter from .base import LinkNetPostProcessor, _LinkNet __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] @@ -183,7 +183,9 @@ def forward( # Post-process boxes out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + for preds in self.postprocessor( + numpy_dtype_converter(prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy() + ) ] if target is not None: From 1545c17d06aa6905c65d2194d4f37fd8d198a3ee Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Oct 2023 13:24:02 +0800 Subject: [PATCH 5/7] refine the code following #1344 --- .../detection/differentiable_binarization/pytorch.py | 8 +++----- doctr/models/detection/linknet/pytorch.py | 8 +++----- doctr/models/utils/pytorch.py | 6 +++--- tests/pytorch/test_models_utils_pt.py | 10 +++++----- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 596790fba8..a115734bf8 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -16,7 +16,7 @@ from doctr.file_utils import CLASS_NAME from ...classification import mobilenet_v3_large -from ...utils import load_pretrained_params, numpy_dtype_converter +from ...utils import _bf16_to_numpy_dtype, load_pretrained_params from .base import DBPostProcessor, _DBNet __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"] @@ -203,7 +203,7 @@ def forward( return out if return_model_output or target is None or return_preds: - prob_map = torch.sigmoid(logits) + prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits)) if return_model_output: out["out_map"] = prob_map @@ -212,9 +212,7 @@ def forward( # Post-process boxes (keep only text predictions) out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor( - numpy_dtype_converter(prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy() - ) + for preds in self.postprocessor((prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy()) ] if target is not None: diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index c31a02ecd3..c3ad822c6f 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -14,7 +14,7 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from ...utils import load_pretrained_params, numpy_dtype_converter +from ...utils import _bf16_to_numpy_dtype, load_pretrained_params from .base import LinkNetPostProcessor, _LinkNet __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] @@ -175,7 +175,7 @@ def forward( return out if return_model_output or target is None or return_preds: - prob_map = torch.sigmoid(logits) + prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits)) if return_model_output: out["out_map"] = prob_map @@ -183,9 +183,7 @@ def forward( # Post-process boxes out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor( - numpy_dtype_converter(prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy() - ) + for preds in self.postprocessor((prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy()) ] if target is not None: diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index a71f1b8e14..4e15fa628a 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -17,7 +17,7 @@ "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor", - "numpy_dtype_converter", + "_bf16_to_numpy_dtype", ] @@ -159,6 +159,6 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T return f"{model_name}.onnx" -def numpy_dtype_converter(input): +def _bf16_to_numpy_dtype(x): # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype - return input.float() if input.dtype in [torch.bfloat16] else input + return x.float() if x.dtype == torch.bfloat16 else x diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py index 94f6c68cf0..4ab0ddc747 100644 --- a/tests/pytorch/test_models_utils_pt.py +++ b/tests/pytorch/test_models_utils_pt.py @@ -5,10 +5,10 @@ from torch import nn from doctr.models.utils import ( + _bf16_to_numpy_dtype, _copy_tensor, conv_sequence_pt, load_pretrained_params, - numpy_dtype_converter, set_device_and_dtype, ) @@ -60,7 +60,7 @@ def test_set_device_and_dtype(): assert batches[0].dtype == torch.float16 -def test_numpy_dtype_converter(): - input = torch.randn([2, 2], dtype=torch.bfloat16) - converted_input = numpy_dtype_converter(input) - assert converted_input.dtype == torch.float32 +def test_bf16_to_numpy_dtype(): + x = torch.randn([2, 2], dtype=torch.bfloat16) + converted_x = _bf16_to_numpy_dtype(x) + assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float()) From 1ef6c33e32dceed943dda90763ac32b41e520948 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Oct 2023 13:26:31 +0800 Subject: [PATCH 6/7] remove redundant parentheses --- doctr/models/detection/differentiable_binarization/pytorch.py | 2 +- doctr/models/detection/linknet/pytorch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index a115734bf8..e7236556da 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -212,7 +212,7 @@ def forward( # Post-process boxes (keep only text predictions) out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor((prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy()) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) ] if target is not None: diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index c3ad822c6f..79965b640a 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -183,7 +183,7 @@ def forward( # Post-process boxes out["preds"] = [ dict(zip(self.class_names, preds)) - for preds in self.postprocessor((prob_map.detach().cpu().permute((0, 2, 3, 1))).numpy()) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) ] if target is not None: From 11d7ea01cc37c78dfbd68b90b1aef263aae6223d Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Oct 2023 13:34:58 +0800 Subject: [PATCH 7/7] add fix for recognition models --- doctr/models/recognition/master/pytorch.py | 4 +++- doctr/models/recognition/parseq/pytorch.py | 4 +++- doctr/models/recognition/sar/pytorch.py | 4 ++-- doctr/models/recognition/vitstr/pytorch.py | 4 ++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py index 02ab8a1e47..eb385932c6 100644 --- a/doctr/models/recognition/master/pytorch.py +++ b/doctr/models/recognition/master/pytorch.py @@ -15,7 +15,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -195,6 +195,8 @@ def forward( else: logits = self.decode(encoded) + logits = _bf16_to_numpy_dtype(logits) + if self.exportable: out["logits"] = logits return out diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index 96679a0e23..25efc23252 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -18,7 +18,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -362,6 +362,8 @@ def forward( else: logits = self.decode_autoregressive(features) + logits = _bf16_to_numpy_dtype(logits) + out: Dict[str, Any] = {} if self.exportable: out["logits"] = logits diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py index ee7946f02e..beccea0221 100644 --- a/doctr/models/recognition/sar/pytorch.py +++ b/doctr/models/recognition/sar/pytorch.py @@ -14,7 +14,7 @@ from doctr.datasets import VOCABS from ...classification import resnet31 -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -249,7 +249,7 @@ def forward( if self.training and target is None: raise ValueError("Need to provide labels during training for teacher forcing") - decoded_features = self.decoder(features, encoded, gt=None if target is None else gt) + decoded_features = _bf16_to_numpy_dtype(self.decoder(features, encoded, gt=None if target is None else gt)) out: Dict[str, Any] = {} if self.exportable: diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index c805ec0d17..0ff05f826a 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -14,7 +14,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -95,7 +95,7 @@ def forward( B, N, E = features.size() features = features.reshape(B * N, E) logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1) - decoded_features = logits[:, 1:] # remove cls_token + decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token out: Dict[str, Any] = {} if self.exportable: