Skip to content

Commit

Permalink
[FIX] clip overflowing probs (#1335)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Sep 29, 2023
1 parent f865bf8 commit 8e54989
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def __call__(
for encoded_seq in out_idxs.cpu().numpy()
]

return list(zip(word_values, probs.numpy().tolist()))
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


def _master(
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def __call__(
decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

return list(zip(word_values, probs.numpy().tolist()))
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool = True, **kwargs: Any) -> MASTER:
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def __call__(
for encoded_seq in out_idxs.cpu().numpy()
]
# compute probabilities for each word up to the EOS token
probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)]

return list(zip(word_values, probs))

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def __call__(
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

# compute probabilities for each word up to the EOS token
probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)]

return list(zip(word_values, probs))

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def __call__(
for encoded_seq in out_idxs.detach().cpu().numpy()
]

return list(zip(word_values, probs.numpy().tolist()))
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


def _sar(
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __call__(
decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

return list(zip(word_values, probs.numpy().tolist()))
return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))


def _sar(
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __call__(
for encoded_seq in out_idxs.cpu().numpy()
]
# compute probabilities for each word up to the EOS token
probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)]

return list(zip(word_values, probs))

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __call__(
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

# compute probabilities for each word up to the EOS token
probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)]

return list(zip(word_values, probs))

Expand Down

0 comments on commit 8e54989

Please sign in to comment.