Skip to content

Commit

Permalink
[Fix] Prob comp in vitstr and parseq for empty words (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixT2K authored Oct 12, 2023
1 parent 56c8356 commit 7374e89
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
4 changes: 3 additions & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ def __call__(
for encoded_seq in out_idxs.cpu().numpy()
]
# compute probabilties for each word up to the EOS token
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)]
probs = [
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))

Expand Down
5 changes: 4 additions & 1 deletion doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,10 @@ def __call__(
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

# compute probabilties for each word up to the EOS token
probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)]
probs = [
preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))

Expand Down
4 changes: 3 additions & 1 deletion doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def __call__(
for encoded_seq in out_idxs.cpu().numpy()
]
# compute probabilties for each word up to the EOS token
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)]
probs = [
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))

Expand Down
5 changes: 4 additions & 1 deletion doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ def __call__(
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]

# compute probabilties for each word up to the EOS token
probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)]
probs = [
preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))

Expand Down

0 comments on commit 7374e89

Please sign in to comment.