Skip to content

Commit

Permalink
Prune with max_arcs in IntersectDense (#820)
Browse files Browse the repository at this point in the history
* Add checking for array constructor

* Prune with max arcs

* Minor fix

* Fix typo

* Fix review comments

* Fix typo
  • Loading branch information
pkufool authored Sep 14, 2021
1 parent bbe0ded commit 2c28070
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 48 deletions.
4 changes: 4 additions & 0 deletions k2/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ class Array1 {
Dtype dtype = DtypeOf<T>::dtype)
: dim_(dim), dtype_(dtype), byte_offset_(byte_offset), region_(region) {
K2_CHECK(K2_TYPE_IS_ANY(T) || dtype == DtypeOf<T>::dtype);
K2_CHECK_GE(dim_, 0) << "Array dim MUST be greater than or equal to 0, "
<< "given :" << dim;
}

Array1(ContextPtr ctx, int32_t size, T elem,
Expand Down Expand Up @@ -496,6 +498,8 @@ ToType(int64_t, Long)

void Init(ContextPtr context, int32_t size, Dtype dtype) {
K2_CHECK(K2_TYPE_IS_ANY(T) || dtype == DtypeOf<T>::dtype);
K2_CHECK_GE(size, 0) << "Array size MUST be greater than or equal to 0, "
<< "given :" << size;
dtype_ = dtype;
region_ = NewRegion(context, static_cast<size_t>(size) * ElementSize());
dim_ = size;
Expand Down
14 changes: 12 additions & 2 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,20 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas,
`IsMonotonic(*a_to_b_map)` (this requirement
is related to the length-sorting requirement of
b_fsas).
@param[in] output_beam Beam with which we prune the output (analogous
@param[in] output_beam Beam with which we prune the output (analogous
to lattice-beam in Kaldi), e.g. 8. We discard arcs in
the output that are not on a path that's within
`output_beam` of the best path of the composed output.
@param[in] max_states The max number of states with which we prune the
output, mainly to avoid out-of-memory and numerical overflow.
If number of states exceeds max_states, we'll decrease
output_beam to prune out more states, until the number of
states is less than max_states.
@param[in] max_arcs The max number of arcs with which we prune the
output, mainly to avoid out-of-memory and numerical overflow.
If number of arcs exceeds max_arcs, we'll decrease
output_beam to prune out more states, until the number of
arcs is less than max_arcs.
@param[out] out Output vector of composed, pruned FSAs, with same
Dim0() as a_fsas. Elements of it may be empty if the
composed result was empty. All states in the output will be
Expand All @@ -239,7 +249,7 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas,
*/
void IntersectDense(FsaVec &a_fsas, DenseFsaVec &b_fsas,
const Array1<int32_t> *a_to_b_map,
float output_beam,
float output_beam, int32_t max_states, int32_t max_arcs,
FsaVec *out, Array1<int32_t> *arc_map_a,
Array1<int32_t> *arc_map_b);

Expand Down
93 changes: 63 additions & 30 deletions k2/csrc/intersect_dense.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ class MultiGraphDenseIntersect {
*/
MultiGraphDenseIntersect(FsaVec &a_fsas, DenseFsaVec &b_fsas,
const Array1<int32_t> &a_to_b_map,
float output_beam)
float output_beam, int32_t max_states,
int32_t max_arcs)
: a_fsas_(a_fsas), b_fsas_(b_fsas), a_to_b_map_(a_to_b_map),
output_beam_(output_beam) {
output_beam_(output_beam), max_states_(max_states),
max_arcs_(max_arcs) {
NVTX_RANGE(K2_FUNC);
c_ = GetContext(a_fsas.shape, b_fsas.shape, a_to_b_map);

Expand Down Expand Up @@ -214,13 +216,10 @@ class MultiGraphDenseIntersect {
int32_t product = ((size_t)(T_ + 1) * (size_t)num_states);
Renumbering renumber_states;
int32_t T = T_;
const int32_t *a_fsas_row_ids1_data = a_fsas_.RowIds(1).Data();
const int32_t *a_fsas_row_ids1_data = a_fsas_.RowIds(1).Data(),
*a_fsas_row_splits2_data = a_fsas_.RowSplits(2).Data();
FsaInfo *fsa_info_data = fsa_info_.Data();

// 15 million is max_states... this is to avoid out-of-memory conditions
// Eventually we can make this an option.
int32_t max_states = 15000000;

while (1) {
// This code is in a loop in case we get too many states and have to
// retry. The limit `max_states` is to reduce the likelihood of
Expand All @@ -230,6 +229,8 @@ class MultiGraphDenseIntersect {
score_cutoffs = GetScoreCutoffs();
score_cutoffs_data = score_cutoffs.Data();
float **state_scores_data = state_scores_.Data();
Array1<int64_t> state_arcs(c_, product);
int64_t *state_arcs_data = state_arcs.Data();

// We'll do exclusive-sum on the following array, after setting its
// elements to 1 if the corresponding state was not pruned away. The
Expand All @@ -255,7 +256,10 @@ class MultiGraphDenseIntersect {

int32_t idx_within_fsa = i - (T + 1) * fsa_info.state_offset,
t = idx_within_fsa / fsa_info.num_states,
state_idx1 = idx_within_fsa % fsa_info.num_states;
state_idx1 = idx_within_fsa % fsa_info.num_states,
state_idx01 = fsa_info.state_offset + state_idx1,
num_arcs = a_fsas_row_splits2_data[state_idx01 + 1] -
a_fsas_row_splits2_data[state_idx01];
// In the state_scores arrays, there are 2 copies of each FSA's
// states, for backward and forward.
int32_t backward_state_idx =
Expand All @@ -273,19 +277,25 @@ class MultiGraphDenseIntersect {
if (forward_score + backward_score > cutoff) keep = 1;
}
keep_state_data[i] = keep;
state_arcs_data[i] = keep * num_arcs;
});
int32_t tot_states = renumber_states.New2Old().Dim();
if (tot_states > max_states) {
float cur_beam = output_beam_,
next_beam = cur_beam * sqrt(max_states * 1.0 / tot_states);
if (next_beam < cur_beam * 0.25)
next_beam = cur_beam * 0.25;
if (next_beam > cur_beam * 0.75)
next_beam = cur_beam * 0.75;
if (tot_states > max_states_) {
float cur_beam = output_beam_;
DecreaseBeam(max_states_, tot_states);
K2_LOG(INFO) << "Num-states " << tot_states << " exceeds limit "
<< max_states << ", decreasing beam from " << cur_beam
<< " to " << next_beam;
output_beam_ = next_beam;
<< max_states_ << ", decreasing beam from " << cur_beam
<< " to " << output_beam_;
continue;
}

int64_t tot_arcs = Sum(state_arcs);
if (tot_arcs > max_arcs_) {
float cur_beam = output_beam_;
DecreaseBeam(max_arcs_, tot_arcs);
K2_LOG(INFO) << "Num-arcs " << tot_arcs << " exceeds limit "
<< max_arcs_ << ", decreasing beam from " << cur_beam
<< " to " << output_beam_;
} else {
break;
}
Expand Down Expand Up @@ -328,7 +338,6 @@ class MultiGraphDenseIntersect {
// the answer.
Array1<int32_t> ans_state_idx01(c_, ans_tot_num_states);
int32_t *ans_state_idx01_data = ans_state_idx01.Data();
const int32_t *a_fsas_row_splits2_data = a_fsas_.RowSplits(2).Data();

// set ans_row_ids2_data, which contains an ans_idx01 that combines
// FSA-index and time-index.
Expand Down Expand Up @@ -562,9 +571,10 @@ class MultiGraphDenseIntersect {
// subsample the output shape, removing arcs that weren't kept
// TODO: make this more efficient, avoid constructing ans_row_ids3.
RaggedShape ans_shape = RaggedShape4(
&ans_row_splits1, &ans_row_ids1, -1,
&ans_row_splits2, &ans_row_ids2, -1,
&ans_row_splits3_subsampled, &ans_row_ids3_subsampled, -1);
&ans_row_splits1, &ans_row_ids1, ans_row_ids1.Dim(),
&ans_row_splits2, &ans_row_ids2, ans_row_ids2.Dim(),
&ans_row_splits3_subsampled, &ans_row_ids3_subsampled,
ans_row_ids3_subsampled.Dim());

// .. remove the 't' axis
return Ragged<Arc>(RemoveAxis(ans_shape, 1), arcs);
Expand Down Expand Up @@ -805,6 +815,21 @@ class MultiGraphDenseIntersect {
&step.state_scores);
}

/*
  Shrink output_beam_ after a pruning pass kept more states or arcs than
  allowed.

  The beam is scaled by sqrt(limit / total); the scaled value is then bounded
  below by 0.25 times the current beam and above by 0.75 times the current
  beam, and the result is stored back into output_beam_.

    @param [in] limit  The cap that was exceeded (max_states or max_arcs).
    @param [in] total  The count actually observed (total states or total
                       arcs); the visible callers only invoke this when
                       total > limit.
*/
void DecreaseBeam(int64_t limit, int64_t total) {
  const float old_beam = output_beam_;
  // Bounds are kept in double so the comparisons below are carried out
  // at the same precision as the mixed float/double expressions they replace.
  const double lower_bound = old_beam * 0.25,
               upper_bound = old_beam * 0.75;
  float new_beam = old_beam * sqrt(limit * 1.0 / total);
  // The two checks are deliberately sequential (not an if/else chain): the
  // upper-bound check always sees the result of the lower-bound check.
  if (new_beam < lower_bound) new_beam = lower_bound;
  if (new_beam > upper_bound) new_beam = upper_bound;
  output_beam_ = new_beam;
}

/*
Called after DoStep() is done for all time steps, returns the total scores
minus output_beam_. (This is what it does in the absence of roundoff error
Expand All @@ -825,8 +850,10 @@ class MultiGraphDenseIntersect {
float **state_scores_data = state_scores_.Data();

FsaInfo *fsa_info_data = fsa_info_.Data();
Array1<float> score_cutoffs(c_, num_fsas_);
float *score_cutoffs_data = score_cutoffs.Data();
Array1<float> score_cutoffs(c_, num_fsas_),
score_diff(c_, num_fsas_);
float *score_cutoffs_data = score_cutoffs.Data(),
*score_diff_data = score_diff.Data();
float output_beam = output_beam_;
const float minus_inf = -std::numeric_limits<float>::infinity();
K2_EVAL(
Expand Down Expand Up @@ -855,12 +882,13 @@ class MultiGraphDenseIntersect {
tot_score_min =
(tot_score_start < tot_score_end ? tot_score_start
: tot_score_end);
K2_CHECK(tot_score_end == tot_score_start ||
fabs(tot_score_end - tot_score_start) < 1.0)
<< tot_score_end << " vs "
<< tot_score_start; // TODO: remove this
score_cutoffs_data[fsa_idx0] = tot_score_min - output_beam;
score_diff_data[fsa_idx0] = fabs(tot_score_end - tot_score_start);
});
float max_diff = MaxValue(score_diff);
if (max_diff >= 1.0)
K2_LOG(WARNING) << "The difference between forward score and backward"
<< " score exceeds 1.0, the value is : " << max_diff;
return score_cutoffs;
}

Expand Down Expand Up @@ -978,11 +1006,14 @@ class MultiGraphDenseIntersect {
float output_beam_;

int32_t T_; // == b_fsas_.MaxSize(1)

int32_t max_states_; // number of max states to avoid out-of-memory
int32_t max_arcs_; // number of max arcs to avoid out-of-memory
};

void IntersectDense(FsaVec &a_fsas, DenseFsaVec &b_fsas,
const Array1<int32_t> *a_to_b_map,
float output_beam,
float output_beam, int32_t max_states, int32_t max_arcs,
FsaVec *out, Array1<int32_t> *arc_map_a,
Array1<int32_t> *arc_map_b) {
NVTX_RANGE("IntersectDense");
Expand All @@ -1001,7 +1032,9 @@ void IntersectDense(FsaVec &a_fsas, DenseFsaVec &b_fsas,

MultiGraphDenseIntersect intersector(a_fsas, b_fsas,
*a_to_b_map,
output_beam);
output_beam,
max_states,
max_arcs);

intersector.Intersect();
FsaVec ret = intersector.FormatOutput(arc_map_a, arc_map_b);
Expand Down
21 changes: 12 additions & 9 deletions k2/csrc/intersect_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ TEST(Intersect, Simple) {

fsa = FsaToFsaVec(fsa);
float output_beam = 1000;
int32_t max_states = 15000000,
max_arcs = 1 << 30;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDense(fsa, dfsavec, nullptr,
output_beam, &out_fsas, &arc_map_a,
&arc_map_b);
IntersectDense(fsa, dfsavec, nullptr, output_beam, max_states, max_arcs,
&out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -232,9 +233,10 @@ TEST(Intersect, RandomSingle) {

FsaVec out_fsas;
float output_beam = 1000.0;
IntersectDense(fsa, dfsavec, nullptr,
output_beam, &out_fsas, &arc_map_a,
&arc_map_b);
int32_t max_states = 15000000,
max_arcs = 1 << 30;
IntersectDense(fsa, dfsavec, nullptr, output_beam, max_states, max_arcs,
&out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b;

FsaVec fsas_b = ConvertDenseToFsaVec(dfsavec);
Expand Down Expand Up @@ -305,9 +307,10 @@ TEST(Intersect, RandomFsaVec) {

FsaVec out_fsas;
float output_beam = 100000.0; // TODO(Dan) ...
IntersectDense(fsavec, dfsavec, nullptr,
output_beam, &out_fsas, &arc_map_a,
&arc_map_b);
int32_t max_states = 15000000,
max_arcs = 1 << 30;
IntersectDense(fsavec, dfsavec, nullptr, output_beam, max_states, max_arcs,
&out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas
<< ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;
Expand Down
10 changes: 6 additions & 4 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ static void PybindIntersectDense(py::module &m) {
m.def(
"intersect_dense",
[](FsaVec &a_fsas, DenseFsaVec &b_fsas,
torch::optional<torch::Tensor> a_to_b_map, float output_beam)
torch::optional<torch::Tensor> a_to_b_map, float output_beam,
int32_t max_states, int32_t max_arcs)
-> std::tuple<FsaVec, torch::Tensor, torch::Tensor> {
DeviceGuard guard(a_fsas.Context());
Array1<int32_t> arc_map_a;
Expand All @@ -260,12 +261,13 @@ static void PybindIntersectDense(py::module &m) {
} else {
a_to_b_map_array = Arange(a_fsa_vec.Context(), 0, a_fsa_vec.Dim0());
}
IntersectDense(a_fsa_vec, b_fsas, &a_to_b_map_array, output_beam, &out,
&arc_map_a, &arc_map_b);
IntersectDense(a_fsa_vec, b_fsas, &a_to_b_map_array, output_beam,
max_states, max_arcs, &out, &arc_map_a, &arc_map_b);
return std::make_tuple(out, ToTorch(arc_map_a), ToTorch(arc_map_b));
},
py::arg("a_fsas"), py::arg("b_fsas"), py::arg("a_to_b_map"),
py::arg("output_beam"));
py::arg("output_beam"), py::arg("max_states") = 15000000,
py::arg("max_arcs") = 1073741824 /* 2^30 */);
}

static void PybindConnect(py::module &m) {
Expand Down
21 changes: 18 additions & 3 deletions k2/python/k2/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ def forward(ctx,
b_fsas: DenseFsaVec,
out_fsa: List[Fsa],
output_beam: float,
max_states: int,
max_arcs: int,
unused_scores_a: torch.Tensor,
unused_scores_b: torch.Tensor,
a_to_b_map: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -560,7 +562,9 @@ def forward(ctx,
a_fsas=a_fsas.arcs,
b_fsas=b_fsas.dense_fsa_vec,
a_to_b_map=a_to_b_map,
output_beam=output_beam)
output_beam=output_beam,
max_states=max_states,
max_arcs=max_arcs)

out_fsa[0] = Fsa(ragged_arc)

Expand Down Expand Up @@ -631,6 +635,8 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \
None, # b_fsas
None, # out_fsa
None, # output_beam
None, # max_states
None, # max_arcs
grad_a, # unused_scores_a
grad_b, # unused_scores_b
None, # a_to_b_map
Expand Down Expand Up @@ -766,6 +772,8 @@ def intersect_dense_pruned(a_fsas: Fsa,
def intersect_dense(a_fsas: Fsa,
b_fsas: DenseFsaVec,
output_beam: float,
max_states: int = 15000000,
max_arcs: int = 1073741824,
a_to_b_map: Optional[torch.Tensor] = None,
seqframe_idx_name: Optional[str] = None,
frame_idx_name: Optional[str] = None) -> Fsa:
Expand All @@ -783,8 +791,14 @@ def intersect_dense(a_fsas: Fsa,
b_fsas:
Input FSAs that correspond to neural network output.
output_beam:
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
max_states:
The max number of states to prune the output, mainly to avoid
out-of-memory and numerical overflow, default 15,000,000.
max_arcs:
The max number of arcs to prune the output, mainly to avoid
out-of-memory and numerical overflow, default 1073741824 (2^30).
a_to_b_map:
Maps from FSA-index in a to FSA-index in b to use for it.
If None, then we expect the number of FSAs in a_fsas to equal
Expand Down Expand Up @@ -825,6 +839,7 @@ def intersect_dense(a_fsas: Fsa,
# the following return value is discarded since it is already contained
# in `out_fsa[0].scores`
_IntersectDenseFunction.apply(a_fsas, b_fsas, out_fsa, output_beam,
max_states, max_arcs,
a_fsas.scores, b_fsas.scores, a_to_b_map,
seqframe_idx_name, frame_idx_name)
return out_fsa[0]
Expand Down

0 comments on commit 2c28070

Please sign in to comment.