diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index df47f34d08..e7236556da 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -16,7 +16,7 @@ from doctr.file_utils import CLASS_NAME from ...classification import mobilenet_v3_large -from ...utils import load_pretrained_params +from ...utils import _bf16_to_numpy_dtype, load_pretrained_params from .base import DBPostProcessor, _DBNet __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"] @@ -203,7 +203,7 @@ def forward( return out if return_model_output or target is None or return_preds: - prob_map = torch.sigmoid(logits) + prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits)) if return_model_output: out["out_map"] = prob_map diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index f40c4018f6..79965b640a 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -14,7 +14,7 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from ...utils import load_pretrained_params +from ...utils import _bf16_to_numpy_dtype, load_pretrained_params from .base import LinkNetPostProcessor, _LinkNet __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] @@ -175,7 +175,7 @@ def forward( return out if return_model_output or target is None or return_preds: - prob_map = torch.sigmoid(logits) + prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits)) if return_model_output: out["out_map"] = prob_map diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py index 02ab8a1e47..eb385932c6 100644 --- a/doctr/models/recognition/master/pytorch.py +++ b/doctr/models/recognition/master/pytorch.py @@ -15,7 +15,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -195,6 +195,8 @@ def forward( else: logits = self.decode(encoded) + logits = _bf16_to_numpy_dtype(logits) + if self.exportable: out["logits"] = logits return out diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index 96679a0e23..25efc23252 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -18,7 +18,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -362,6 +362,8 @@ def forward( else: logits = self.decode_autoregressive(features) + logits = _bf16_to_numpy_dtype(logits) + out: Dict[str, Any] = {} if self.exportable: out["logits"] = logits diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py index ee7946f02e..beccea0221 100644 --- a/doctr/models/recognition/sar/pytorch.py +++ b/doctr/models/recognition/sar/pytorch.py @@ -14,7 +14,7 @@ from doctr.datasets import VOCABS from ...classification import resnet31 -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -249,7 +249,7 @@ def forward( if self.training and target is None: raise ValueError("Need to provide labels during training for teacher forcing") - decoded_features = self.decoder(features, encoded, gt=None if target is None else gt) + decoded_features = _bf16_to_numpy_dtype(self.decoder(features, encoded, gt=None if target is None else gt)) out: Dict[str, Any] = {} if self.exportable: diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index c805ec0d17..0ff05f826a 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -14,7 +14,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.pytorch import load_pretrained_params +from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -95,7 +95,7 @@ def forward( B, N, E = features.size() features = features.reshape(B * N, E) logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1) - decoded_features = logits[:, 1:] # remove cls_token + decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token out: Dict[str, Any] = {} if self.exportable: diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index b24030ca17..4e15fa628a 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -11,7 +11,14 @@ from doctr.utils.data import download_from_url -__all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor"] +__all__ = [ + "load_pretrained_params", + "conv_sequence_pt", + "set_device_and_dtype", + "export_model_to_onnx", + "_copy_tensor", + "_bf16_to_numpy_dtype", +] def _copy_tensor(x: torch.Tensor) -> torch.Tensor: @@ -150,3 +157,8 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T ) logging.info(f"Model exported to {model_name}.onnx") return f"{model_name}.onnx" + + +def _bf16_to_numpy_dtype(x): + # bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype + return x.float() if x.dtype == torch.bfloat16 else x diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py index 12b644a1b7..4ab0ddc747 100644 --- a/tests/pytorch/test_models_utils_pt.py +++ b/tests/pytorch/test_models_utils_pt.py @@ -4,7 +4,13 @@ import torch from torch import nn -from doctr.models.utils import _copy_tensor, conv_sequence_pt, load_pretrained_params, set_device_and_dtype +from doctr.models.utils import ( + _bf16_to_numpy_dtype, + _copy_tensor, + conv_sequence_pt, + load_pretrained_params, + set_device_and_dtype, +) def test_copy_tensor(): @@ -52,3 +58,9 @@ def test_set_device_and_dtype(): model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16) assert model[0].weight.dtype == torch.float16 assert batches[0].dtype == torch.float16 + + +def test_bf16_to_numpy_dtype(): + x = torch.randn([2, 2], dtype=torch.bfloat16) + converted_x = _bf16_to_numpy_dtype(x) + assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float())