Skip to content

Commit

Permalink
update random apply to work also with targets
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Sep 27, 2023
1 parent 56395ba commit 3443def
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
8 changes: 4 additions & 4 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import math
import random
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -206,10 +206,10 @@ def __init__(self, transform: Callable[[Any], Any], p: float = 0.5) -> None:
def extra_repr(self) -> str:
return f"transform={self.transform}, p={self.p}"

def __call__(self, img: Any) -> Any:
def __call__(self, img: Any, target: Optional[np.ndarray] = None) -> Union[Any, Tuple[Any, np.ndarray]]:
if random.random() < self.p:
return self.transform(img)
return img
return self.transform(img) if target is None else self.transform(img, target)

Check warning on line 211 in doctr/transforms/modules/base.py

View check run for this annotation

Codecov / codecov/patch

doctr/transforms/modules/base.py#L211

Added line #L211 was not covered by tests
return img if target is None else (img, target)


class RandomRotate(NestedObject):
Expand Down
4 changes: 2 additions & 2 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def main(args):
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
T.RandomRotate(90, expand=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation and not args.eval_straight
Expand Down Expand Up @@ -286,7 +286,7 @@ def main(args):
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True),
T.RandomRotate(90, expand=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation
Expand Down
4 changes: 2 additions & 2 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def main(args):
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
T.RandomRotate(90, expand=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation and not args.eval_straight
Expand Down Expand Up @@ -240,7 +240,7 @@ def main(args):
+ (
[
T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad
T.RandomRotate(90, expand=True),
T.RandomApply(T.RandomRotate(90, expand=True), 0.5),
T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True),
]
if args.rotation
Expand Down
3 changes: 3 additions & 0 deletions tests/common/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def test_randomapply():
transfo = T.RandomApply(lambda x: 1 - x)
out = transfo(1)
assert out == 0 or out == 1
transfo = T.RandomApply(lambda x, y: (1 - x, 2 * y))
out = transfo(1, np.array([2]))
assert out == (0, 4) or out == (1, 2) and isinstance(out[1], np.ndarray)
assert repr(transfo).endswith(", p=0.5)")


Expand Down

0 comments on commit 3443def

Please sign in to comment.