fix prob computation for parseq and vitstr models
felixdittrich92 committed Sep 22, 2023
1 parent 8245706 commit 0a226a8
Showing 4 changed files with 18 additions and 20 deletions.
9 changes: 4 additions & 5 deletions doctr/models/recognition/parseq/pytorch.py
@@ -393,18 +393,17 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
-        # N x L
-        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-        # Take the minimum confidence of the sequence
-        probs = probs.min(dim=1).values.detach().cpu()
+        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
 
         # Manual decoding
         word_values = [
             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
             for encoded_seq in out_idxs.cpu().numpy()
         ]
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs))


def _parseq(
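For context, a minimal standalone sketch of what the new PyTorch post-processing does (illustrative only; the toy shapes N, L, V and the word_values list are assumptions, not values from the commit). The old code took the minimum per-position confidence over the whole padded sequence; the new code averages the argmax probabilities over just the decoded characters, i.e. the positions before the <eos> token.

import torch

# toy batch: N sequences of length L over a vocabulary of size V (assumed shapes)
N, L, V = 2, 6, 10
logits = torch.randn(N, L, V)
word_values = ["abc", "de"]  # already-decoded words, cut at <eos> (assumed)

# per-position probability of the predicted (argmax) class -> shape N x L
preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]

# mean confidence over only the characters that belong to each word
probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
print(list(zip(word_values, probs)))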
10 changes: 5 additions & 5 deletions doctr/models/recognition/parseq/tensorflow.py
@@ -421,10 +421,7 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
-        # N x L
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
+        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
 
         # decode raw output of the model with tf_label_to_idx
         out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -434,7 +431,10 @@ def __call__(
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
+
+        return list(zip(word_values, probs))


def _parseq(
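The same idea in TensorFlow, again as a hedged sketch with assumed toy inputs rather than the post-processor's real tensors: take the per-position max softmax probability, then average it over the decoded characters of each word.

import tensorflow as tf

# toy batch: N sequences of length L over a vocabulary of size V (assumed shapes)
N, L, V = 2, 6, 10
logits = tf.random.normal((N, L, V))
word_values = ["abc", "de"]  # already-decoded words, cut at <eos> (assumed)

# per-position probability of the predicted (argmax) class -> shape N x L
preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

# mean confidence over only the characters that belong to each word
probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
print(list(zip(word_values, probs)))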
9 changes: 4 additions & 5 deletions doctr/models/recognition/vitstr/pytorch.py
@@ -159,18 +159,17 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
-        # N x L
-        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-        # Take the minimum confidence of the sequence
-        probs = probs.min(dim=1).values.detach().cpu()
+        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
 
         # Manual decoding
         word_values = [
             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
             for encoded_seq in out_idxs.cpu().numpy()
         ]
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs))


def _vitstr(
10 changes: 5 additions & 5 deletions doctr/models/recognition/vitstr/tensorflow.py
@@ -164,10 +164,7 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
-        # N x L
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
+        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
 
         # decode raw output of the model with tf_label_to_idx
         out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -177,7 +174,10 @@ def __call__(
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
+
+        return list(zip(word_values, probs))


def _vitstr(
