Skip to content

Commit

Permalink
[misc] rename helper function for bf16 to float32 casting (#1347)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Oct 12, 2023
1 parent 74a0b1f commit e83c3ab
Show file tree
Hide file tree
Showing 17 changed files with 36 additions and 36 deletions.
4 changes: 2 additions & 2 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from doctr.file_utils import CLASS_NAME

from ...classification import mobilenet_v3_large
from ...utils import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils import _bf16_to_float32, load_pretrained_params
from .base import DBPostProcessor, _DBNet

__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"]
Expand Down Expand Up @@ -203,7 +203,7 @@ def forward(
return out

if return_model_output or target is None or return_preds:
prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits))
prob_map = _bf16_to_float32(torch.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tensorflow.keras.applications import ResNet50

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
from doctr.utils.repr import NestedObject

from ...classification import mobilenet_v3_large
Expand Down Expand Up @@ -241,7 +241,7 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))
prob_map = _bf16_to_float32(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50

from ...utils import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils import _bf16_to_float32, load_pretrained_params
from .base import LinkNetPostProcessor, _LinkNet

__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
Expand Down Expand Up @@ -175,7 +175,7 @@ def forward(
return out

if return_model_output or target is None or return_preds:
prob_map = _bf16_to_numpy_dtype(torch.sigmoid(logits))
prob_map = _bf16_to_float32(torch.sigmoid(logits))
if return_model_output:
out["out_map"] = prob_map

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_numpy_dtype, conv_sequence, load_pretrained_params
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
from doctr.utils.repr import NestedObject

from .base import LinkNetPostProcessor, _LinkNet
Expand Down Expand Up @@ -229,7 +229,7 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = _bf16_to_numpy_dtype(tf.math.sigmoid(logits))
prob_map = _bf16_to_float32(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/crnn/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.datasets import VOCABS

from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
Expand Down Expand Up @@ -199,7 +199,7 @@ def call(
w, h, c = transposed_feat.get_shape().as_list()[1:]
# B x W x H x C --> B x W x H * C
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
logits = _bf16_to_numpy_dtype(self.decoder(features_seq, **kwargs))
logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
Expand Down Expand Up @@ -195,7 +195,7 @@ def forward(
else:
logits = self.decode(encoded)

logits = _bf16_to_numpy_dtype(logits)
logits = _bf16_to_float32(logits)

if self.exportable:
out["logits"] = logits
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
Expand Down Expand Up @@ -183,7 +183,7 @@ def call(
else:
logits = self.decode(encoded, **kwargs)

logits = _bf16_to_numpy_dtype(logits)
logits = _bf16_to_float32(logits)

if self.exportable:
out["logits"] = logits
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

from ...classification import vit_s
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
Expand Down Expand Up @@ -362,7 +362,7 @@ def forward(
else:
logits = self.decode_autoregressive(features)

logits = _bf16_to_numpy_dtype(logits)
logits = _bf16_to_float32(logits)

out: Dict[str, Any] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

from ...classification import vit_s
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
Expand Down Expand Up @@ -390,7 +390,7 @@ def call(
else:
logits = self.decode_autoregressive(features, **kwargs)

logits = _bf16_to_numpy_dtype(logits)
logits = _bf16_to_float32(logits)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from doctr.datasets import VOCABS

from ...classification import resnet31
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]
Expand Down Expand Up @@ -249,7 +249,7 @@ def forward(
if self.training and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")

decoded_features = _bf16_to_numpy_dtype(self.decoder(features, encoded, gt=None if target is None else gt))
decoded_features = _bf16_to_float32(self.decoder(features, encoded, gt=None if target is None else gt))

out: Dict[str, Any] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.utils.repr import NestedObject

from ...classification import resnet31
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]
Expand Down Expand Up @@ -316,7 +316,7 @@ def call(
if kwargs.get("training", False) and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")

decoded_features = _bf16_to_numpy_dtype(
decoded_features = _bf16_to_float32(
self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
)

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.pytorch import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
Expand Down Expand Up @@ -95,7 +95,7 @@ def forward(
B, N, E = features.size()
features = features.reshape(B * N, E)
logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1)
decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token
decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token

out: Dict[str, Any] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.tensorflow import _bf16_to_numpy_dtype, load_pretrained_params
from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
Expand Down Expand Up @@ -131,7 +131,7 @@ def call(
logits = tf.reshape(
self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
) # (batch_size, max_length, vocab + 1)
decoded_features = _bf16_to_numpy_dtype(logits[:, 1:]) # remove cls_token
decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
"set_device_and_dtype",
"export_model_to_onnx",
"_copy_tensor",
"_bf16_to_numpy_dtype",
"_bf16_to_float32",
]


def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
return x.clone().detach()


def _bf16_to_numpy_dtype(x: torch.Tensor) -> torch.Tensor:
def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
# bfloat16 is not supported in .numpy(): torch/csrc/utils/tensor_numpy.cpp:aten_to_numpy_dtype
return x.float() if x.dtype == torch.bfloat16 else x

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
"IntermediateLayerGetter",
"export_model_to_onnx",
"_copy_tensor",
"_bf16_to_numpy_dtype",
"_bf16_to_float32",
]


def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
return tf.identity(x)


def _bf16_to_numpy_dtype(x: tf.Tensor) -> tf.Tensor:
def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor:
# Convert bfloat16 to float32 for numpy compatibility
return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x

Expand Down
6 changes: 3 additions & 3 deletions tests/pytorch/test_models_utils_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import nn

from doctr.models.utils import (
_bf16_to_numpy_dtype,
_bf16_to_float32,
_copy_tensor,
conv_sequence_pt,
load_pretrained_params,
Expand All @@ -19,9 +19,9 @@ def test_copy_tensor():
assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and torch.allclose(m, x)


def test_bf16_to_numpy_dtype():
def test_bf16_to_float32():
x = torch.randn([2, 2], dtype=torch.bfloat16)
converted_x = _bf16_to_numpy_dtype(x)
converted_x = _bf16_to_float32(x)
assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float())


Expand Down
6 changes: 3 additions & 3 deletions tests/tensorflow/test_models_utils_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_to_numpy_dtype,
_bf16_to_float32,
_copy_tensor,
conv_sequence,
load_pretrained_params,
Expand All @@ -20,9 +20,9 @@ def test_copy_tensor():
assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x))


def test_bf16_to_numpy_dtype():
def test_bf16_to_float32():
x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16)
m = _bf16_to_numpy_dtype(x)
m = _bf16_to_float32(x)
assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32)))


Expand Down

0 comments on commit e83c3ab

Please sign in to comment.