Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] PT - convert BF16 tensor to float before calling .numpy() #1342

Merged
merged 7 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from doctr.file_utils import CLASS_NAME

from ...classification import mobilenet_v3_large
from ...utils import load_pretrained_params
from ...utils import _bf16_to_numpy_dtype, load_pretrained_params
from .base import DBPostProcessor, _DBNet

__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"]
Expand Down Expand Up @@ -203,7 +203,7 @@ def forward(
return out

if return_model_output or target is None or return_preds:
prob_map = torch.sigmoid(logits)
prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50

from ...utils import load_pretrained_params
from ...utils import _bf16_to_numpy_dtype, load_pretrained_params
from .base import LinkNetPostProcessor, _LinkNet

__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
Expand Down Expand Up @@ -175,7 +175,7 @@ def forward(
return out

if return_model_output or target is None or return_preds:
prob_map = torch.sigmoid(logits)
prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits))
if return_model_output:
out["out_map"] = prob_map

Expand Down
4 changes: 3 additions & 1 deletion doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.pytorch import load_pretrained_params
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
Expand Down Expand Up @@ -195,6 +195,8 @@ def forward(
else:
logits = self.decode(encoded)

logits = _bf16_to_numpy_dtype(logits)

if self.exportable:
out["logits"] = logits
return out
Expand Down
4 changes: 3 additions & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

from ...classification import vit_s
from ...utils.pytorch import load_pretrained_params
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
Expand Down Expand Up @@ -362,6 +362,8 @@ def forward(
else:
logits = self.decode_autoregressive(features)

logits = _bf16_to_numpy_dtype(logits)

out: Dict[str, Any] = {}
if self.exportable:
out["logits"] = logits
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from doctr.datasets import VOCABS

from ...classification import resnet31
from ...utils.pytorch import load_pretrained_params
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]
Expand Down Expand Up @@ -249,7 +249,7 @@ def forward(
if self.training and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")

decoded_features = self.decoder(features, encoded, gt=None if target is None else gt)
decoded_features = _bf16_to_numpy_dtype(self.decoder(features, encoded, gt=None if target is None else gt))

out: Dict[str, Any] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.pytorch import load_pretrained_params
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
Expand Down Expand Up @@ -95,7 +95,7 @@ def forward(
B, N, E = features.size()
features = features.reshape(B * N, E)
logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1)
decoded_features = logits[:, 1:] # remove cls_token
decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token

out: Dict[str, Any] = {}
if self.exportable:
Expand Down
14 changes: 13 additions & 1 deletion doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@

from doctr.utils.data import download_from_url

__all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor"]
__all__ = [
"load_pretrained_params",
"conv_sequence_pt",
"set_device_and_dtype",
"export_model_to_onnx",
"_copy_tensor",
"_bf16_to_numpy_dtype",
]


def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -150,3 +157,8 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
)
logging.info(f"Model exported to {model_name}.onnx")
return f"{model_name}.onnx"


def _bf16_to_numpy_dtype(x):
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x
14 changes: 13 additions & 1 deletion tests/pytorch/test_models_utils_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import torch
from torch import nn

from doctr.models.utils import _copy_tensor, conv_sequence_pt, load_pretrained_params, set_device_and_dtype
from doctr.models.utils import (
_bf16_to_numpy_dtype,
_copy_tensor,
conv_sequence_pt,
load_pretrained_params,
set_device_and_dtype,
)


def test_copy_tensor():
Expand Down Expand Up @@ -52,3 +58,9 @@ def test_set_device_and_dtype():
model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16)
assert model[0].weight.dtype == torch.float16
assert batches[0].dtype == torch.float16


def test_bf16_to_numpy_dtype():
x = torch.randn([2, 2], dtype=torch.bfloat16)
converted_x = _bf16_to_numpy_dtype(x)
assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float())
Loading