diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py index a1b1d6299c..8d27c8fa4c 100644 --- a/doctr/datasets/datasets/base.py +++ b/doctr/datasets/datasets/base.py @@ -55,7 +55,12 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: img = self.img_transforms(img) if self.sample_transforms is not None: - if isinstance(target, dict) and all([isinstance(item, np.ndarray) for item in target.values()]): + # Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks. + if ( + isinstance(target, dict) + and all([isinstance(item, np.ndarray) for item in target.values()]) + and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target + ): img_transformed = copy_tensor(img) for class_name, bboxes in target.items(): img_transformed, target[class_name] = self.sample_transforms(img, bboxes)