Skip to content

Commit

Permalink
replace rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Jul 17, 2023
1 parent d0b67e5 commit b258f74
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 22 deletions.
13 changes: 12 additions & 1 deletion doctr/models/classification/vit/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import tensorflow as tf
from tensorflow.keras import Sequential, layers
from tensorflow_addons.layers import GELU
from tensorflow.keras.activations import gelu

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
Expand Down Expand Up @@ -38,6 +38,17 @@
}


class GELU(layers.Layer):
    """Gaussian Error Linear Unit activation function, wrapped as a Keras layer.

    Thin layer wrapper around `tensorflow.keras.activations.gelu` so the activation
    can be placed inside `Sequential` / layer stacks.

    Args:
        approximate: forwarded as the `approximate` argument of `gelu`
        **kwargs: forwarded to the `layers.Layer` constructor
    """

    def __init__(self, approximate: bool = False, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.approximate = approximate

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Apply GELU to the input tensor."""
        return gelu(x, approximate=self.approximate)

    def get_config(self) -> dict:
        """Return the layer configuration, including `approximate`.

        Without this override the constructor argument would not be part of the
        serialized config, so a layer rebuilt from its config would silently fall
        back to the default value of `approximate`.
        """
        config = super().get_config()
        config.update({"approximate": self.approximate})
        return config


class ClassifierHead(layers.Layer, NestedObject):
"""Classifier head for Vision Transformer
Expand Down
105 changes: 95 additions & 10 deletions doctr/transforms/functional/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import math
from copy import deepcopy
from typing import Tuple
from typing import Iterable, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

from doctr.utils.geometry import compute_expanded_shape, rotate_abs_geoms

Expand Down Expand Up @@ -54,8 +53,28 @@ def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf
h_crop, w_crop = int(round(max(exp_img.shape[0] - exp_h, 0))), int(round(min(exp_img.shape[1] - exp_w, 0)))
else:
exp_img = img
# Rotate the padded image
rotated_img = tfa.image.rotate(exp_img, angle * math.pi / 180) # Interpolation NEAREST by default

# Compute the rotation matrix
height, width = tf.cast(tf.shape(exp_img)[0], tf.float32), tf.cast(tf.shape(exp_img)[1], tf.float32)
cos_angle, sin_angle = tf.math.cos(angle * math.pi / 180.0), tf.math.sin(angle * math.pi / 180.0)
x_offset = ((width - 1) - (cos_angle * (width - 1) - sin_angle * (height - 1))) / 2.0
y_offset = ((height - 1) - (sin_angle * (width - 1) + cos_angle * (height - 1))) / 2.0

rotation_matrix = tf.convert_to_tensor(
[cos_angle, -sin_angle, x_offset, sin_angle, cos_angle, y_offset, 0.0, 0.0],
dtype=tf.float32,
)
# Rotate the image
rotated_img = tf.squeeze(
tf.raw_ops.ImageProjectiveTransformV3(
images=exp_img[None], # Add a batch dimension for compatibility with the function
transforms=rotation_matrix[None], # Add a batch dimension for compatibility with the function
output_shape=tf.shape(exp_img)[:2],
interpolation="NEAREST",
fill_mode="CONSTANT",
fill_value=tf.constant(0.0, dtype=tf.float32),
)
)
# Crop the rest
if h_crop > 0 or w_crop > 0:
h_slice = slice(h_crop // 2, -h_crop // 2) if h_crop > 0 else slice(rotated_img.shape[0])
Expand Down Expand Up @@ -133,6 +152,70 @@ def crop_detection(
return cropped_img, boxes


def _gaussian_filter(
    img: tf.Tensor,
    kernel_size: Union[int, Iterable[int]],
    sigma: float,
    mode: Optional[str] = None,
    pad_value: Optional[int] = 0,
) -> tf.Tensor:
    """Apply Gaussian filter to image.
    Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py

    Args:
        img: image to filter of shape (N, H, W, C)
        kernel_size: kernel size of the filter (a single value is broadcast to both spatial axes)
        sigma: standard deviation of the Gaussian filter (broadcast to both spatial axes)
        mode: padding mode, one of "CONSTANT", "REFLECT", "SYMMETRIC" (case-insensitive);
            None falls back to "CONSTANT"
        pad_value: value to pad the image with; None is treated as 0

    Returns:
        A tensor of shape (N, H, W, C)

    Raises:
        ValueError: if `mode` is not one of the supported padding modes
    """
    # Normalize the mode BEFORE validating it, so that the default (None) and
    # lowercase spellings are accepted instead of being rejected up front.
    mode = "CONSTANT" if mode is None else mode.upper()
    if mode not in ("CONSTANT", "REFLECT", "SYMMETRIC"):
        raise ValueError("mode should be one of 'CONSTANT', 'REFLECT', 'SYMMETRIC'")

    ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32)
    sigma = tf.convert_to_tensor(tf.broadcast_to(sigma, [2]), dtype=img.dtype)
    constant_values = (
        tf.zeros([], dtype=img.dtype) if pad_value is None else tf.convert_to_tensor(pad_value, dtype=img.dtype)
    )

    def kernel1d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType) -> tf.Tensor:
        # Sample positions centered on the kernel; even sizes are shifted by 0.5
        x = tf.range(ksize, dtype=dtype)
        x = x - tf.cast(tf.math.floordiv(ksize, 2), dtype=dtype)
        x = x + tf.where(tf.math.equal(tf.math.mod(ksize, 2), 0), tf.cast(0.5, dtype), 0)
        g = tf.math.exp(-(tf.math.pow(x, 2) / (2 * tf.math.pow(sigma, 2))))
        # Normalize so that the weights sum to 1
        g = g / tf.reduce_sum(g)
        return g

    def kernel2d(ksize: tf.Tensor, sigma: tf.Tensor, dtype: tf.DType) -> tf.Tensor:
        # Outer product of the two 1D kernels
        kernel_x = kernel1d(ksize[0], sigma[0], dtype)
        kernel_y = kernel1d(ksize[1], sigma[1], dtype)
        return tf.matmul(
            tf.expand_dims(kernel_x, axis=-1),
            tf.transpose(tf.expand_dims(kernel_y, axis=-1)),
        )

    g = kernel2d(ksize, sigma, img.dtype)
    # Pad the image so that the "VALID" convolution below returns the input spatial size
    height, width = ksize[0], ksize[1]
    paddings = [
        [0, 0],
        [(height - 1) // 2, height - 1 - (height - 1) // 2],
        [(width - 1) // 2, width - 1 - (width - 1) // 2],
        [0, 0],
    ]
    img = tf.pad(img, paddings, mode=mode, constant_values=constant_values)

    # Replicate the 2D kernel over every channel: (kh, kw, C, 1) depthwise filter
    channel = tf.shape(img)[-1]
    shape = tf.concat([ksize, tf.constant([1, 1], ksize.dtype)], axis=0)
    g = tf.reshape(g, shape)
    shape = tf.concat([ksize, [channel], tf.constant([1], ksize.dtype)], axis=0)
    g = tf.broadcast_to(g, shape)
    return tf.nn.depthwise_conv2d(img, g, [1, 1, 1, 1], padding="VALID", data_format="NHWC")


def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs) -> tf.Tensor:
"""Apply a random shadow to a given image
Expand All @@ -141,7 +224,7 @@ def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs)
opacity_range: the minimum and maximum desired opacity of the shadow
Returns:
shaded image
shadowed image
"""

shadow_mask = create_shadow_mask(img.shape[:2], **kwargs)
Expand All @@ -151,10 +234,12 @@ def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs)

# Add some blur to make it believable
k = 7 + int(2 * 4 * np.random.rand(1))
shadow_tensor = tfa.image.gaussian_filter2d(
shadow_tensor,
filter_shape=k,
sigma=np.random.uniform(0.5, 5.0),
sigma = np.random.uniform(0.5, 5.0)
shadow_tensor = _gaussian_filter(
shadow_tensor[tf.newaxis, ...],
kernel_size=k,
sigma=sigma,
mode="REFLECT",
)

return opacity * shadow_tensor * img + (1 - opacity) * img
return tf.squeeze(opacity * shadow_tensor * img + (1 - opacity) * img, axis=0)
16 changes: 9 additions & 7 deletions doctr/transforms/modules/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

from doctr.utils.repr import NestedObject

from ..functional.tensorflow import random_shadow
from ..functional.tensorflow import _gaussian_filter, random_shadow

__all__ = [
"Compose",
Expand Down Expand Up @@ -385,11 +384,14 @@ def extra_repr(self) -> str:

@tf.function
def __call__(self, img: tf.Tensor) -> tf.Tensor:
sigma = random.uniform(self.std[0], self.std[1])
return tfa.image.gaussian_filter2d(
img,
filter_shape=self.kernel_shape,
sigma=sigma,
return tf.squeeze(
_gaussian_filter(
img[None],
kernel_size=self.kernel_shape,
sigma=random.uniform(self.std[0], self.std[1]),

Check warning on line 391 in doctr/transforms/modules/tensorflow.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/transforms/modules/tensorflow.py#L391

Standard pseudo-random generators are not suitable for security/cryptographic purposes.
mode="REFLECT",
),
axis=0,
)


Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ dependencies = [
[project.optional-dependencies]
tf = [
"tensorflow>=2.11.0,<3.0.0", # cf. https://github.com/mindee/doctr/pull/1182
"tensorflow-addons>=0.17.1",
"tf2onnx>=1.14.0,<2.0.0",
]
torch = [
Expand Down Expand Up @@ -92,7 +91,6 @@ docs = [
dev = [
# Tensorflow
"tensorflow>=2.11.0,<3.0.0", # cf. https://github.com/mindee/doctr/pull/1182
"tensorflow-addons>=0.17.1",
"tf2onnx>=1.14.0,<2.0.0",
# PyTorch
"torch>=1.12.0,<3.0.0",
Expand Down Expand Up @@ -155,7 +153,6 @@ module = [
"cv2.*",
"h5py.*",
"matplotlib.*",
"tensorflow_addons.*",
"pyclipper.*",
"shapely.*",
"tf2onnx.*",
Expand Down
1 change: 0 additions & 1 deletion tests/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ def test_rotate_sample():
expected_img = tf.ones((100, 200, 3), dtype=tf.float32)
expected_polys = np.array([[0, 1], [0, 0], [1, 0], [1, 1]], dtype=np.float32)[None, ...]
rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, True)
# import ipdb; ipdb.set_trace()
assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys)
rotated_img, rotated_geoms = rotate_sample(img, polys, 90, True)
assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys)
Expand Down

0 comments on commit b258f74

Please sign in to comment.