Skip to content

Commit

Permalink
Check that label values are valid in k2.intersect_dense_pruned. (#774)
Browse files Browse the repository at this point in the history
* Check that label values are valid in k2.intersect_dense_pruned.

* Fix CI test failures.
  • Loading branch information
csukuangfj authored Jul 5, 2021
1 parent c414ca6 commit 1cea103
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion k2/csrc/tensor_ops_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ template <typename Real>
void TestDiscountedCumSum() {
for (int32_t i = 0; i < 4; i++) {
int32_t M = RandInt(0, 1000),
T = RandInt(0, 2000); // TODO: increase.
T = RandInt(1, 2000); // TODO: increase.
while (M * T > 10000) { // don't want test to take too long.
M /= 2;
T /= 2;
Expand Down
30 changes: 30 additions & 0 deletions k2/python/k2/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,21 @@ def intersect_dense_pruned(a_fsas: Fsa,
Returns:
The result of the intersection.
'''
# Possible values for _k2.build_type are [Release, Debug]
if _k2.version.build_type == 'Debug':
# This check is to guarantee that all labels are in a valid range.
# If not, unpredictable errors will occur.
#
# One such situation is that someone imports a graph from Kaldi,
# whose labels are transition IDs. When the neural network output
# units are pdf IDs, this additional check will detect the mismatch.
#
assert a_fsas.labels.min() >= -1
# The first column of b_fsas.scores is -inf,
# so we use b_fsas.scores.shape[1] - 1 here
# (-1 is to exclude the column with -inf)
assert a_fsas.labels.max() < b_fsas.scores.shape[1] - 1

out_fsa = [0]

# the following return value is discarded since it is already contained
Expand Down Expand Up @@ -773,6 +788,21 @@ def intersect_dense(a_fsas: Fsa,
The result of the intersection (pruned to `output_beam`; this pruning
is exact, it uses forward and backward scores.
'''
# Possible values for _k2.build_type are [Release, Debug]
if _k2.version.build_type == 'Debug':
# This check is to guarantee that all labels are in a valid range.
# If not, unpredictable errors will occur.
#
# One such situation is that someone imports a graph from Kaldi,
# whose labels are transition IDs. When the neural network output
# units are pdf IDs, this additional check will detect the mismatch.
#
assert a_fsas.labels.min() >= -1
# The first column of b_fsas.scores is -inf,
# so we use b_fsas.scores.shape[1] - 1 here
# (-1 is to exclude the column with -inf)
assert a_fsas.labels.max() < b_fsas.scores.shape[1] - 1

out_fsa = [0]

# the following return value is discarded since it is already contained
Expand Down

0 comments on commit 1cea103

Please sign in to comment.