Skip to content

Commit

Permalink
fix: 🐛 fix bug when training object detection (#1254)
Browse files Browse the repository at this point in the history
* fix: 🐛 fix bug when training object detection

* comment: add comments to make it clearer
  • Loading branch information
aminemindee authored Jul 20, 2023
1 parent 021d4f5 commit 408854a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion doctr/datasets/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = self.img_transforms(img)

if self.sample_transforms is not None:
if isinstance(target, dict) and all([isinstance(item, np.ndarray) for item in target.values()]):
# Conditions to assess it is detection model with multiple classes and avoid confusion with other tasks.
if (
isinstance(target, dict)
and all([isinstance(item, np.ndarray) for item in target.values()])
and set(target.keys()) != {"boxes", "labels"} # avoid confusion with obj detection target
):
img_transformed = copy_tensor(img)
for class_name, bboxes in target.items():
img_transformed, target[class_name] = self.sample_transforms(img, bboxes)
Expand Down

0 comments on commit 408854a

Please sign in to comment.