diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 7c454ffb3..221c2fbc9 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -46,6 +46,7 @@ add_subdirectory(host) set(context_srcs algorithms.cu array_ops.cu + connect.cu context.cu dtype.cu fsa.cu @@ -136,6 +137,7 @@ set(cuda_test_srcs algorithms_test.cu array_ops_test.cu array_test.cu + connect_test.cu dtype_test.cu fsa_algo_test.cu fsa_test.cu diff --git a/k2/csrc/connect.cu b/k2/csrc/connect.cu new file mode 100644 index 000000000..00cca1dd2 --- /dev/null +++ b/k2/csrc/connect.cu @@ -0,0 +1,516 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "k2/csrc/array_ops.h" +#include "k2/csrc/context.h" +#include "k2/csrc/device_guard.h" +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/csrc/thread_pool.h" + +namespace k2 { + +class Connector { + public: + /** + Connector object. You should call Connect() after + constructing it. Please see Connect() declaration in header for + high-level overview of the algorithm. + + @param [in] fsas A vector of FSAs; must have 3 axes. It is held by reference (see the `fsas_` member), so it must outlive this object.
+ */ + explicit Connector(FsaVec &fsas) : c_(fsas.Context()), fsas_(fsas) { + K2_CHECK_EQ(fsas_.NumAxes(), 3); + int32_t num_states = fsas_.shape.TotSize(1); + accessible_ = Array1(c_, num_states, 0); + } + + /* + Computes the next batch of states for the forward pass and marks every state in `cur_states` as accessible. + @param [in] cur_states Ragged array with 2 axes, with the shape + `[fsa][state]`, containing state-indexes (idx01) into fsas_. + @return Returns the destination states (as idx01, grouped by FSA) of arcs leaving `cur_states` that are not yet marked accessible, ignoring self-loops; returns nullptr if there are none. + */ + std::unique_ptr> GetNextBatch(Ragged &cur_states) { + NVTX_RANGE(K2_FUNC); + // Process arcs leaving all states in `cur_states` + + // First figure out how many arcs leave each state. + // And set accessibility for each state + Array1 num_arcs_per_state(c_, cur_states.NumElements() + 1); + const int32_t *fsas_row_splits2_data = fsas_.RowSplits(2).Data(), + *states_data = cur_states.values.Data(); + int32_t *num_arcs_per_state_data = num_arcs_per_state.Data(); + char *accessible_data = accessible_.Data(); + K2_EVAL( + c_, cur_states.NumElements(), lambda_set_arcs_and_accessible_per_state, + (int32_t states_idx01)->void { + int32_t idx01 = states_data[states_idx01], + num_arcs = fsas_row_splits2_data[idx01 + 1] - + fsas_row_splits2_data[idx01]; + num_arcs_per_state_data[states_idx01] = num_arcs; + // Set accessibility (first bit). NOTE(review): this is a plain, non-atomic read-modify-write on a byte whose second bit is updated by GetNextBatchBackward(), and Connect() submits both passes to the thread pool before waiting -- confirm that no update can be lost. + accessible_data[idx01] |= 1; + }); + ExclusiveSum(num_arcs_per_state, &num_arcs_per_state); + + // arcs_shape `[fsa][state][arc]` + RaggedShape arcs_shape = ComposeRaggedShapes( + cur_states.shape, RaggedShape2(&num_arcs_per_state, nullptr, -1)); + + // We'll be figuring out which of the states that these arcs lead to are not + // accessible yet (i.e. for which state_renumbering.Keep[i] == true).
+ int32_t total_states = fsas_.shape.TotSize(1); + Renumbering state_renumbering(c_, total_states, true); + + const int32_t *arcs_row_ids2_data = arcs_shape.RowIds(2).Data(), + *arcs_row_splits2_data = arcs_shape.RowSplits(2).Data(), + *dest_states_data = dest_states_.values.Data(); + char *keep_state_data = state_renumbering.Keep().Data(); + K2_EVAL( + c_, arcs_shape.NumElements(), lambda_set_state_renumbering, + (int32_t arcs_idx012)->void { + // note: the prefix `arcs_` means it is an idxXXX w.r.t. `arcs_shape`. + // the prefix `fsas_` means the variable is an idxXXX w.r.t. `fsas_`. + int32_t arcs_idx01 = arcs_row_ids2_data[arcs_idx012], + arcs_idx01x = arcs_row_splits2_data[arcs_idx01], + arcs_idx2 = arcs_idx012 - arcs_idx01x, + fsas_idx01 = states_data[arcs_idx01], // a state index + fsas_idx01x = fsas_row_splits2_data[fsas_idx01], + fsas_idx012 = fsas_idx01x + arcs_idx2, + fsas_dest_state_idx01 = dest_states_data[fsas_idx012]; + // 1. If this arc is a self-loop, just ignore this arc as we won't + // process the dest_state (current state) again. + // 2. If the state this arc points to is already accessible, skip it. + if (fsas_dest_state_idx01 == fsas_idx01 || + (accessible_data[fsas_dest_state_idx01] & 1)) { + return; + } + keep_state_data[fsas_dest_state_idx01] = 1; + }); + + Array1 new2old_map = state_renumbering.New2Old(); + if (new2old_map.Dim() == 0) { + // There are no new states. This means we terminated. + return nullptr; + } + int32_t num_states = new2old_map.Dim(); + Array1 temp(c_, 2 * num_states); + // `new_states` will contain state-ids which are idx01's into `fsas_`.
+ Array1 new_states = temp.Arange(0, num_states); + // `ans_row_ids` will map to FSA index + Array1 ans_row_ids = temp.Arange(num_states, 2 * num_states); + + const int32_t *new2old_map_data = new2old_map.Data(), + *fsas_row_ids1_data = fsas_.RowIds(1).Data(); + int32_t *ans_row_ids_data = ans_row_ids.Data(), + *new_states_data = new_states.Data(); + K2_EVAL( + c_, num_states, lambda_set_new_states_and_row_ids, + (int32_t state_idx)->void { + int32_t state_idx01 = new2old_map_data[state_idx], + fsa_idx0 = fsas_row_ids1_data[state_idx01]; + new_states_data[state_idx] = state_idx01; + ans_row_ids_data[state_idx] = fsa_idx0; + }); + + int32_t num_fsas = fsas_.Dim0(); + Array1 ans_row_splits(c_, num_fsas + 1); + RowIdsToRowSplits(ans_row_ids, &ans_row_splits); + + auto ans = std::make_unique>( + RaggedShape2(&ans_row_splits, &ans_row_ids, + new_states.Dim()), new_states); + return ans; + } + + /* + Computes the next batch of states in reverse order + @param [in] cur_states Ragged array with 2 axes, with the shape of + `[fsa][state]`, containing state-indexes (idx01) into fsas_. + @return Returns the source states (as idx01, grouped by FSA) of arcs entering `cur_states` that are not yet marked coaccessible, ignoring self-loops; returns nullptr if there are none. + */ + std::unique_ptr> GetNextBatchBackward( + Ragged &cur_states) { + NVTX_RANGE(K2_FUNC); + // Process arcs entering all states in `cur_states` + + // First figure out how many arcs enter each state.
+ // And set coaccessibility for each state + Array1 num_arcs_per_state(c_, cur_states.NumElements() + 1); + int32_t *num_arcs_per_state_data = num_arcs_per_state.Data(); + const int32_t *incoming_arcs_row_splits2_data = + incoming_arcs_.RowSplits(2).Data(), + *states_data = cur_states.values.Data(); + char *accessible_data = accessible_.Data(); + K2_EVAL( + c_, cur_states.NumElements(), + lambda_set_arcs_and_coaccessible_per_state, + (int32_t states_idx01)->void { + int32_t idx01 = states_data[states_idx01], + num_arcs = incoming_arcs_row_splits2_data[idx01 + 1] - + incoming_arcs_row_splits2_data[idx01]; + num_arcs_per_state_data[states_idx01] = num_arcs; + // Set coaccessibility (mark second bit). NOTE(review): non-atomic update of a byte that GetNextBatch() also writes (first bit) -- see the note there. + accessible_data[idx01] |= (1 << 1); + }); + ExclusiveSum(num_arcs_per_state, &num_arcs_per_state); + + // arcs_shape `[fsa][state][arc]` + RaggedShape arcs_shape = ComposeRaggedShapes( + cur_states.shape, RaggedShape2(&num_arcs_per_state, nullptr, -1)); + + // We'll be figuring out which of the states that these arcs come from are not + // coaccessible yet (i.e. for which state_renumbering.Keep[i] == true). + int32_t total_states = fsas_.shape.TotSize(1); + Renumbering state_renumbering(c_, total_states, true); + + const int32_t *arcs_row_ids2_data = arcs_shape.RowIds(2).Data(), + *arcs_row_splits2_data = arcs_shape.RowSplits(2).Data(), + *fsas_row_splits1_data = fsas_.RowSplits(1).Data(), + *fsas_row_ids1_data = fsas_.RowIds(1).Data(), + *incoming_arcs_data = incoming_arcs_.values.Data(); + const Arc *fsas_data = fsas_.values.Data(); + char *keep_state_data = state_renumbering.Keep().Data(); + K2_EVAL( + c_, arcs_shape.NumElements(), lambda_set_arc_renumbering, + (int32_t arcs_idx012)->void { + // note: the prefix `arcs_` means it is an idxXXX w.r.t. `arcs_shape`. + // the prefix `fsas_` means the variable is an idxXXX w.r.t. `fsas_`.
+ int32_t arcs_idx01 = arcs_row_ids2_data[arcs_idx012], + arcs_idx01x = arcs_row_splits2_data[arcs_idx01], + arcs_idx2 = arcs_idx012 - arcs_idx01x, + fsas_idx01 = states_data[arcs_idx01], // a state index + fsas_idx0 = fsas_row_ids1_data[fsas_idx01], + fsas_idx01x = incoming_arcs_row_splits2_data[fsas_idx01], + fsas_idx012 = fsas_idx01x + arcs_idx2, + fsas_src_state_idx1 = + fsas_data[incoming_arcs_data[fsas_idx012]].src_state, + fsas_src_state_idx0x = fsas_row_splits1_data[fsas_idx0], + fsas_src_state_idx01 = + fsas_src_state_idx0x + fsas_src_state_idx1; + // 1. If this arc is a self-loop, just ignore this arc as we won't + // process the src_state (current state) again. + // 2. If the src state of this arc is already coaccessible, skip it. + // 3. If more than one arc comes from the same state, we select only + // one arc arbitrarily. + if (fsas_src_state_idx01 == fsas_idx01 || + (accessible_data[fsas_src_state_idx01] & (1 << 1))) { + keep_state_data[fsas_src_state_idx01] = 0; + return; + } + keep_state_data[fsas_src_state_idx01] = 1; + }); + + Array1 new2old_map = state_renumbering.New2Old(); + if (new2old_map.Dim() == 0) { + // There are no new states. This means we terminated. + return nullptr; + } + int32_t num_states = new2old_map.Dim(); + Array1 temp(c_, 2 * num_states); + // `new_states` will contain state-ids which are idx01's into `fsas_`.
+ Array1 new_states = temp.Arange(0, num_states); + // `ans_row_ids` will map to FSA index + Array1 ans_row_ids = temp.Arange(num_states, 2 * num_states); + + const int32_t *new2old_map_data = new2old_map.Data(); + int32_t *ans_row_ids_data = ans_row_ids.Data(), + *new_states_data = new_states.Data(); + K2_EVAL( + c_, num_states, lambda_set_new_states_and_row_ids, + (int32_t state_idx)->void { + int32_t state_idx01 = new2old_map_data[state_idx], + fsa_idx0 = fsas_row_ids1_data[state_idx01]; + ans_row_ids_data[state_idx] = fsa_idx0; + new_states_data[state_idx] = state_idx01; + }); + + int32_t num_fsas = fsas_.Dim0(); + Array1 ans_row_splits(c_, num_fsas + 1); + RowIdsToRowSplits(ans_row_ids, &ans_row_splits); + + auto ans = std::make_unique>( + RaggedShape2(&ans_row_splits, &ans_row_ids, + num_states), new_states); + return ans; + } + + /* + Returns the start batch of states. This will include all start-states that + existed in the original FSAs; FSAs that contain no states contribute nothing. + */ + std::unique_ptr> GetStartBatch() { + NVTX_RANGE(K2_FUNC); + int32_t num_fsas = fsas_.Dim0(); + const int32_t *fsas_row_splits1_data = fsas_.RowSplits(1).Data(); + Array1 has_start_state(c_, num_fsas + 1); + int32_t *has_start_state_data = has_start_state.Data(); + K2_EVAL( + c_, num_fsas, lambda_set_has_start_state, (int32_t i)->void { + int32_t split = fsas_row_splits1_data[i], + next_split = fsas_row_splits1_data[i + 1]; + has_start_state_data[i] = (next_split > split); + }); + ExclusiveSum(has_start_state, &has_start_state); + + int32_t n = has_start_state[num_fsas]; + auto ans = std::make_unique>( + RaggedShape2(&has_start_state, nullptr, n), Array1(c_, n)); + int32_t *ans_data = ans->values.Data(); + const int32_t *ans_row_ids1_data = ans->RowIds(1).Data(); + K2_EVAL( + c_, n, lambda_set_start_state, (int32_t i)->void { + int32_t fsa_idx0 = ans_row_ids1_data[i], + start_state = fsas_row_splits1_data[fsa_idx0]; + // If the following fails, it likely means an input FSA was invalid + // (e.g.
had exactly one state, which is not allowed). Either that, + // or a code error. + K2_DCHECK_LT(start_state, fsas_row_splits1_data[fsa_idx0 + 1]); + ans_data[i] = start_state; + }); + return ans; + } + + + /* + Returns the final batch of states. This will include all final-states that + existed in the original FSAs. + */ + std::unique_ptr> GetFinalBatch() { + NVTX_RANGE(K2_FUNC); + int32_t num_fsas = fsas_.Dim0(); + const int32_t *fsas_row_splits1_data = fsas_.RowSplits(1).Data(); + Array1 has_final_state(c_, num_fsas + 1); + int32_t *has_final_state_data = has_final_state.Data(); + K2_EVAL( + c_, num_fsas, lambda_set_has_final_state, (int32_t i)->void { + int32_t split = fsas_row_splits1_data[i], + next_split = fsas_row_splits1_data[i + 1]; + has_final_state_data[i] = (next_split > split); + }); + ExclusiveSum(has_final_state, &has_final_state); + + int32_t n = has_final_state[num_fsas]; + auto ans = std::make_unique>( + RaggedShape2(&has_final_state, nullptr, n), Array1(c_, n)); + int32_t *ans_data = ans->values.Data(); + const int32_t *ans_row_ids1_data = ans->RowIds(1).Data(); + K2_EVAL( + c_, n, lambda_set_final_state, (int32_t i)->void { + int32_t fsa_idx0 = ans_row_ids1_data[i], + final_state = fsas_row_splits1_data[fsa_idx0 + 1] - 1; + // If the following fails, it likely means an input FSA was invalid + // (e.g. had exactly one state, which is not allowed). Either that, + // or a code error. + K2_DCHECK_GT(final_state, fsas_row_splits1_data[fsa_idx0]); + ans_data[i] = final_state; + }); + return ans; + } + + /* + Traverse the fsa from start states to mark the accessible states. + */ + static void ForwardPassStatic(Connector* c) { + // WARNING: this is run in a separate thread, so we have + // to reset its default device. Otherwise, it will throw later + // if the main thread is using a different device.
+ DeviceGuard guard(c->c_); + auto iter = c->GetStartBatch(); + while (iter != nullptr) + iter = c->GetNextBatch(*iter); + } + + /* + Traverse the fsa in reverse order (from final states) to mark the + coaccessible states. + */ + static void BackwardPassStatic(Connector* c) { + // WARNING: this is run in a separate thread, so we have + // to reset its default device. Otherwise, it will throw later + // if the main thread is using a different device. + DeviceGuard guard(c->c_); + auto riter = c->GetFinalBatch(); + while (riter != nullptr) + riter = c->GetNextBatchBackward(*riter); + } + + /* Does the main work of connecting and returns the resulting FSAs. + @param [out] arc_map if non-NULL, the map from (arcs in output) + to (corresponding arcs in input) is written to here. + @return Returns the connected FsaVec. + */ + FsaVec Connect(Array1 *arc_map) { + NVTX_RANGE(K2_FUNC); + Array1 dest_states_idx01 = GetDestStates(fsas_, true); + dest_states_ = Ragged(fsas_.shape, dest_states_idx01); + incoming_arcs_ = GetIncomingArcs(fsas_, dest_states_idx01); + + ThreadPool* pool = GetThreadPool(); + // Mark accessible states + pool->SubmitTask([this]() { ForwardPassStatic(this); }); + // Mark coaccessible states + pool->SubmitTask([this]() { BackwardPassStatic(this); }); + pool->WaitAllTasksFinished(); + + // Get remaining states and construct row_ids1/row_splits1 + int32_t num_states = fsas_.shape.TotSize(1); + const char *accessible_data = accessible_.Data(); + Renumbering states_renumbering(c_, num_states); + char* states_renumbering_data = states_renumbering.Keep().Data(); + K2_EVAL( + c_, num_states, lambda_set_states_renumbering, + (int32_t state_idx01)->void { + if (accessible_data[state_idx01] == 3) // 3 == 0b11: the accessible and coaccessible bits are both set + states_renumbering_data[state_idx01] = 1; + else + states_renumbering_data[state_idx01] = 0; + }); + Array1 new2old_map_states = states_renumbering.New2Old(); + Array1 old2new_map_states = states_renumbering.Old2New(); + Array1 new_row_ids1 =
fsas_.RowIds(1)[new2old_map_states]; + Array1 new_row_splits1(c_, fsas_.Dim0() + 1); + RowIdsToRowSplits(new_row_ids1, &new_row_splits1); + + // Get remaining arcs + int32_t num_arcs = fsas_.NumElements(); + Renumbering arcs_renumbering(c_, num_arcs); + char* arcs_renumbering_data = arcs_renumbering.Keep().Data(); + const Arc *fsas_data = fsas_.values.Data(); + const int32_t *fsas_row_ids1_data = fsas_.RowIds(1).Data(), + *fsas_row_ids2_data = fsas_.RowIds(2).Data(), + *fsas_row_splits1_data = fsas_.RowSplits(1).Data(); + K2_EVAL( + c_, num_arcs, lambda_set_arcs_renumbering, + (int32_t arc_idx012)->void { + Arc arc = fsas_data[arc_idx012]; + int32_t src_state_idx01 = fsas_row_ids2_data[arc_idx012], + dest_state_idx01 = + arc.dest_state - arc.src_state + src_state_idx01; + // 3 == 0b11: keep the arc only if both of its end states are accessible and coaccessible + if (accessible_data[src_state_idx01] == 3 && + accessible_data[dest_state_idx01] == 3) + arcs_renumbering_data[arc_idx012] = 1; + else + arcs_renumbering_data[arc_idx012] = 0; + }); + Array1 new2old_map_arcs = arcs_renumbering.New2Old(); + int32_t remaining_arcs_num = new2old_map_arcs.Dim(); + + // Construct row_ids2/row_splits2 + Array1 new_row_ids2(c_, remaining_arcs_num); + int32_t *new_row_ids2_data = new_row_ids2.Data(); + const int32_t *new2old_map_arcs_data = new2old_map_arcs.Data(), + *old2new_map_states_data = old2new_map_states.Data(); + K2_EVAL( + c_, remaining_arcs_num, lambda_set_new_row_ids2, + (int32_t arc_idx012)->void { + int32_t idx012 = new2old_map_arcs_data[arc_idx012], + state_idx01 = fsas_row_ids2_data[idx012]; + new_row_ids2_data[arc_idx012] = old2new_map_states_data[state_idx01]; + }); + + Array1 new_row_splits2(c_, new2old_map_states.Dim() + 1); + RowIdsToRowSplits(new_row_ids2, &new_row_splits2); + + // Update arcs to the renumbered states + const int32_t *new_row_ids1_data = new_row_ids1.Data(), + *new_row_splits1_data = new_row_splits1.Data(); + Array1 remaining_arcs(c_, remaining_arcs_num); + Arc *remaining_arcs_data =
remaining_arcs.Data(); + K2_EVAL( + c_, remaining_arcs_num, lambda_set_arcs, + (int32_t arc_idx012)->void { + // note: the prefix `old_` means it is an idxXXX w.r.t. the original + // fsas (`fsas_`). the prefix `new_` means the variable is an idxXXX + // w.r.t. the result fsas. + int32_t old_idx012 = new2old_map_arcs_data[arc_idx012], + old_idx01 = fsas_row_ids2_data[old_idx012], + old_idx0 = fsas_row_ids1_data[old_idx01], + old_idx0x = fsas_row_splits1_data[old_idx0]; + Arc arc = fsas_data[old_idx012]; + int32_t old_src_state_idx01 = old_idx0x + arc.src_state, + new_src_state_idx01 = + old2new_map_states_data[old_src_state_idx01], + new_src_fsa_idx0 = new_row_ids1_data[new_src_state_idx01], + new_src_state_idx0x = new_row_splits1_data[new_src_fsa_idx0], + new_src_state_idx1 = + new_src_state_idx01 - new_src_state_idx0x, + + old_dest_state_idx01 = old_idx0x + arc.dest_state, + new_dest_state_idx01 = + old2new_map_states_data[old_dest_state_idx01], + new_dest_fsa_idx0 = new_row_ids1_data[new_dest_state_idx01], + new_dest_state_idx0x = + new_row_splits1_data[new_dest_fsa_idx0], + new_dest_state_idx1 = + new_dest_state_idx01 - new_dest_state_idx0x; + Arc new_arc; + new_arc.src_state = new_src_state_idx1; + new_arc.dest_state = new_dest_state_idx1; + new_arc.score = arc.score; + new_arc.label = arc.label; + remaining_arcs_data[arc_idx012] = new_arc; + }); + + if (arc_map != nullptr) + *arc_map = std::move(new2old_map_arcs); + + return Ragged( + RaggedShape3(&new_row_splits1, &new_row_ids1, new2old_map_states.Dim(), + &new_row_splits2, &new_row_ids2, remaining_arcs_num), + remaining_arcs); + } + + ContextPtr c_; + FsaVec &fsas_; + + // For each arc in fsas_ (with same structure as fsas_), dest-state + // of that arc as an idx01. + Ragged dest_states_; + // For each state in fsas_ (with same structure as fsas_), incoming-arc + // of that state as an idx012.
+ Ragged incoming_arcs_; + // With the Dim() the same as num-states, to mark the state (as an idx01) to + // be accessible/coaccessible or not. For each element in this array the first + // bit is used to mark accessible and the second bit to mark coaccessible. + Array1 accessible_; +}; + +void Connect(FsaOrVec &src, FsaOrVec *dest, + Array1 *arc_map /* = nullptr */) { /* Accepts a single Fsa (2 axes) or an FsaVec (3 axes); a single Fsa is wrapped into an FsaVec, connected, and unwrapped again. */ + NVTX_RANGE(K2_FUNC); + K2_CHECK_GE(src.NumAxes(), 2); + K2_CHECK_LE(src.NumAxes(), 3); + if (src.NumAxes() == 2) { + // Turn single Fsa into FsaVec. + FsaVec src_vec = FsaToFsaVec(src), dest_vec; + // Recurse.. + Connect(src_vec, &dest_vec, arc_map); + *dest = GetFsaVecElement(dest_vec, 0); + return; + } + Connector connector(src); + *dest = connector.Connect(arc_map); +} + +} // namespace k2 diff --git a/k2/csrc/connect_test.cu b/k2/csrc/connect_test.cu new file mode 100644 index 000000000..1289d535c --- /dev/null +++ b/k2/csrc/connect_test.cu @@ -0,0 +1,212 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include + +#include +#include + +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/csrc/host_shim.h" +#include "k2/csrc/math.h" +#include "k2/csrc/test_utils.h" + +namespace k2 { + +TEST(Connect, SingleFsa) { /* Acyclic input in which state 4 has no outgoing arcs and state 5 has no incoming arcs; both must be removed, on CPU and on CUDA. */ + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + std::string s = R"(0 2 1 1 + 0 3 3 3 + 1 4 5 5 + 1 6 -1 0 + 2 1 2 2 + 3 1 4 4 + 5 3 6 6 + 6 + )"; + auto fsa = FsaFromString(s).To(c); + int32_t gt = kFsaPropertiesMaybeAccessible | + kFsaPropertiesMaybeCoaccessible; + int32_t p = GetFsaBasicProperties(fsa); + EXPECT_NE(p & gt, gt); + + Fsa connected; + Array1 arc_map; + Connect(fsa, &connected, &arc_map); + + Fsa ref = Fsa("[ [ 0 2 1 1 0 3 3 3 ] [ 1 4 -1 0 ] " + " [ 2 1 2 2 ] [ 3 1 4 4 ] [ ] ]").To(c); + Array1 arc_map_ref(c, "[ 0 1 3 4 5 ]"); + K2_CHECK(Equal(connected, ref)); + K2_CHECK(Equal(arc_map, arc_map_ref)); + p = GetFsaBasicProperties(connected); + EXPECT_EQ(p & gt, gt); + } +} + +TEST(Connect, CyclicFsa) { /* Input containing the cycle 1 -> 2 -> 3 -> 1; state 4 has no outgoing arcs and state 5 has no incoming arcs, so both must be removed while the cycle is kept. */ + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + std::string s = R"(0 1 1 1 + 0 2 2 2 + 1 2 3 3 + 2 3 4 4 + 2 4 5 5 + 3 1 6 6 + 3 6 -1 0 + 5 2 7 7 + 6 + )"; + auto fsa = FsaFromString(s).To(c); + + int32_t gt = kFsaPropertiesMaybeAccessible | + kFsaPropertiesMaybeCoaccessible; + int32_t p = GetFsaBasicProperties(fsa); + EXPECT_NE(p & gt, gt); + + Fsa connected; + Array1 arc_map; + Connect(fsa, &connected, &arc_map); + Fsa ref = Fsa("[ [ 0 1 1 1 0 2 2 2 ] [ 1 2 3 3 ] [ 2 3 4 4] " + " [ 3 1 6 6 3 4 -1 0 ] [ ] ]").To(c); + Array1 arc_map_ref(c, "[ 0 1 2 3 5 6 ]"); + K2_CHECK(Equal(connected, ref)); + K2_CHECK(Equal(arc_map, arc_map_ref)); + p = GetFsaBasicProperties(connected); + EXPECT_EQ(p & gt, gt); + } +} + +TEST(Connect, RandomSingleFsa) { /* On a random Fsa, the output must report both the MaybeAccessible and MaybeCoaccessible properties, and arc_map must pair every output arc with an input arc of equal score. */ + ContextPtr cpu = GetCpuContext(); + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + bool acyclic = RandInt(0, 1); + Fsa fsa = RandomFsa(acyclic).To(c); + int32_t gt = kFsaPropertiesMaybeAccessible | +
kFsaPropertiesMaybeCoaccessible; + + Fsa connected; + Array1 arc_map; + Connect(fsa, &connected, &arc_map); + int32_t p = GetFsaBasicProperties(connected); + EXPECT_EQ(p & gt, gt); + + Array1 arcs = connected.values.To(cpu), + fsa_arcs = fsa.values.To(cpu); + arc_map = arc_map.To(cpu); + int32_t num_arcs = arcs.Dim(); + for (int32_t i = 0; i != num_arcs; ++i) { + EXPECT_EQ(arcs[i].score, fsa_arcs[arc_map[i]].score); + } + } +} + +TEST(Connect, FsaVec) { /* Three FSAs, each with at least one state that has no outgoing arcs or no incoming arcs; after Connect() all three must equal the same two-arc FSA. */ + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + std::string s1 = R"(0 1 1 1 + 0 2 2 2 + 1 3 -1 0 + 3 + )"; + auto fsa1 = FsaFromString(s1); + + std::string s2 = R"(0 1 1 1 + 1 3 -1 0 + 2 1 2 2 + 3 + )"; + auto fsa2 = FsaFromString(s2); + + std::string s3 = R"(0 1 1 1 + 1 4 -1 0 + 1 3 3 3 + 2 1 2 2 + 4 + )"; + auto fsa3 = FsaFromString(s3); + + Fsa *fsa_array[] = {&fsa1, &fsa2, &fsa3}; + FsaVec fsa_vec = CreateFsaVec(3, &fsa_array[0]).To(c); + + int32_t gt = kFsaPropertiesMaybeAccessible | + kFsaPropertiesMaybeCoaccessible; + Array1 properties; + int32_t p; + GetFsaVecBasicProperties(fsa_vec, &properties, &p); + + EXPECT_NE(p & gt, gt); + EXPECT_NE(properties[0] & gt, gt); + EXPECT_NE(properties[1] & gt, gt); + EXPECT_NE(properties[2] & gt, gt); + + FsaVec connected; + Array1 arc_map; + Connect(fsa_vec, &connected, &arc_map); + FsaVec ref = FsaVec("[ [ [ 0 1 1 1 ] [ 1 2 -1 0 ] [ ] ] " + " [ [ 0 1 1 1 ] [ 1 2 -1 0 ] [ ] ] " + " [ [ 0 1 1 1 ] [ 1 2 -1 0 ] [ ] ] ]").To(c); + Array1 arc_map_ref(c, "[ 0 2 3 4 6 7 ]"); + K2_CHECK(Equal(connected, ref)); + K2_CHECK(Equal(arc_map, arc_map_ref)); + + GetFsaVecBasicProperties(connected, &properties, &p); + EXPECT_EQ(p & gt, gt); + EXPECT_EQ(properties[0] & gt, gt); + EXPECT_EQ(properties[1] & gt, gt); + EXPECT_EQ(properties[2] & gt, gt); + } +} + +TEST(Connect, RandomFsaVec) { /* Same checks as RandomSingleFsa, applied per FSA to a random FsaVec. */ + ContextPtr cpu = GetCpuContext(); + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + bool acyclic = RandInt(0, 1); + + FsaVec fsa_vec = RandomFsaVec(1, 100, acyclic); +
fsa_vec = fsa_vec.To(c); + + int32_t gt = kFsaPropertiesMaybeAccessible | + kFsaPropertiesMaybeCoaccessible; + Array1 properties; + int32_t p; + + FsaVec connected; + Array1 arc_map; + Connect(fsa_vec, &connected, &arc_map); + + GetFsaVecBasicProperties(connected, &properties, &p); + + EXPECT_EQ(p & gt, gt); + properties = properties.To(cpu); + int32_t num_fsas = fsa_vec.Dim0(); + for (int32_t i = 0; i != num_fsas; ++i) { + EXPECT_EQ(properties[i] & gt, gt); + } + + Array1 arcs = connected.values.To(cpu), + fsa_arcs = fsa_vec.values.To(cpu); + arc_map = arc_map.To(cpu); + + int32_t num_arcs = connected.TotSize(2); + for (int32_t i = 0; i != num_arcs; ++i) { + EXPECT_EQ(arcs[i].score, fsa_arcs[arc_map[i]].score); + } + } +} + +} // namespace k2 diff --git a/k2/csrc/fsa_algo.cu b/k2/csrc/fsa_algo.cu index 331c08bee..cdaa19cbf 100644 --- a/k2/csrc/fsa_algo.cu +++ b/k2/csrc/fsa_algo.cu @@ -69,13 +69,13 @@ bool RecursionWrapper(bool (*f)(Fsa &, Fsa *, Array1 *), Fsa &src, return true; } -bool Connect(Fsa &src, Fsa *dest, Array1 *arc_map /*=nullptr*/) { +bool ConnectHost(Fsa &src, Fsa *dest, Array1 *arc_map /*=nullptr*/) { NVTX_RANGE(K2_FUNC); int32_t num_axes = src.NumAxes(); if (num_axes < 2 || num_axes > 3) { K2_LOG(FATAL) << "Input has bad num-axes " << num_axes; } else if (num_axes == 3) { - return RecursionWrapper(Connect, src, dest, arc_map); + return RecursionWrapper(ConnectHost, src, dest, arc_map); } k2host::Fsa host_fsa = FsaToHostFsa(src); diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index bad562cbf..743dd8b99 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -35,13 +35,26 @@ namespace k2 { kFsaPropertiesMaybeAccessible @param [out,optional] arc_map For each arc in `dest`, gives the index of the corresponding arc in `src` that it corresponds to. - @return Returns true on success (which basically means the input did not - have cycles, so the algorithm could not succeed). Success - does not imply that `dest` is nonempty.
+ */ +void Connect(FsaOrVec &src, FsaOrVec *dest, Array1 *arc_map = nullptr); - CAUTION: for now this only works for CPU. +/* + This version of Connect() is just a wrapper of `Connection` in host/connect.h + and only works on CPU. We will delete it at some point; users should never call this + function in production code. Instead, you should call the version above. + @param [in] src Source FSA + @param [out] dest Destination; at exit will be equivalent to `src` + but will have no states that are unreachable or which + can't reach the final-state, i.e. its Properties() will + contain kFsaPropertiesMaybeCoaccessible and + kFsaPropertiesMaybeAccessible + @param [out,optional] arc_map For each arc in `dest`, gives the index of + the corresponding arc in `src` that it corresponds to. + @return Returns true on success (which basically means the input did not + have cycles, so the algorithm could not succeed). Success + does not imply that `dest` is nonempty. NOTE(review): the parenthetical above reads as self-contradictory (no cycles, yet could not succeed); confirm the intended wording. */ -bool Connect(Fsa &src, Fsa *dest, Array1 *arc_map = nullptr); +bool ConnectHost(Fsa &src, Fsa *dest, Array1 *arc_map = nullptr); /* Sort arcs of an Fsa or FsaVec in-place (this version of the function does not diff --git a/k2/csrc/fsa_utils.cu b/k2/csrc/fsa_utils.cu index 1230d5b3c..e1c961fb0 100644 --- a/k2/csrc/fsa_utils.cu +++ b/k2/csrc/fsa_utils.cu @@ -2773,6 +2773,93 @@ void FixFinalLabels(FsaOrVec &fsas, } } +FsaVec RenumberFsaVec(FsaVec &fsas, const Array1 &order, + Array1 *arc_map) { // Reorders the states of `fsas` so that new state i (an idx01) is old state order[i]; states not listed in `order` are dropped. If `arc_map` is non-null it receives, for each output arc, the idx012 of the corresponding arc in `fsas`. + NVTX_RANGE(K2_FUNC); + K2_CHECK_EQ(fsas.NumAxes(), 3); + ContextPtr &c = fsas.Context(); + K2_CHECK_LE(order.Dim(), fsas.TotSize(1)); + Array1 old2new_map(c, fsas.TotSize(1)); + if (order.Dim() != fsas.TotSize(1)) { + old2new_map = -1; + } + int32_t new_num_states = order.Dim(), num_fsas = fsas.Dim0(); + Array1 num_arcs(c, new_num_states + 1); + const int32_t *order_data = order.Data(), + *fsas_row_splits1_data = fsas.RowSplits(1).Data(), + *fsas_row_splits2_data = fsas.RowSplits(2).Data(); + int32_t *old2new_data =
old2new_map.Data(), *num_arcs_data = num_arcs.Data(); + K2_EVAL( + c, new_num_states, lambda_set_old2new_and_num_arcs, + (int32_t new_state_idx01)->void { + int32_t old_state_idx01 = order_data[new_state_idx01]; + old2new_data[old_state_idx01] = new_state_idx01; + int32_t num_arcs = fsas_row_splits2_data[old_state_idx01 + 1] - + fsas_row_splits2_data[old_state_idx01]; + num_arcs_data[new_state_idx01] = num_arcs; + }); + + Array1 new_row_splits1, new_row_ids1; + if (order.Dim() == fsas.TotSize(1)) { + new_row_splits1 = fsas.RowSplits(1); + new_row_ids1 = fsas.RowIds(1); + } else { + new_row_ids1 = fsas.RowIds(1)[order]; + new_row_splits1 = Array1(c, num_fsas + 1); + RowIdsToRowSplits(new_row_ids1, &new_row_splits1); + } + ExclusiveSum(num_arcs, &num_arcs); + RaggedShape ans_shape = + RaggedShape3(&new_row_splits1, &new_row_ids1, -1, &num_arcs, nullptr, -1); + const int32_t *ans_row_ids2_data = ans_shape.RowIds(2).Data(), + *ans_row_ids1_data = ans_shape.RowIds(1).Data(), + *ans_row_splits1_data = ans_shape.RowSplits(1).Data(), + *ans_row_splits2_data = ans_shape.RowSplits(2).Data(); + int32_t ans_num_arcs = ans_shape.NumElements(); + Array1 ans_arcs(c, ans_num_arcs); + int32_t *arc_map_data; + if (arc_map != nullptr) { + *arc_map = Array1(c, ans_num_arcs); + arc_map_data = arc_map->Data(); + } else { + arc_map_data = nullptr; + } + + const Arc *fsas_arcs = fsas.values.Data(); + Arc *ans_arcs_data = ans_arcs.Data(); + // if the dest state of any arc from any src kept state is not kept, the + // program will abort with an error.
+ Array1 all_dest_states_kept(c, 1, 1); + int32_t *all_dest_states_kept_data = all_dest_states_kept.Data(); + K2_EVAL( + c, ans_num_arcs, lambda_set_arcs, (int32_t ans_idx012)->void { + int32_t ans_idx01 = ans_row_ids2_data[ans_idx012], // state index + ans_idx01x = ans_row_splits2_data[ans_idx01], + ans_idx0 = ans_row_ids1_data[ans_idx01], // FSA index + ans_idx0x = ans_row_splits1_data[ans_idx0], + ans_idx1 = ans_idx01 - ans_idx0x, + ans_idx2 = ans_idx012 - ans_idx01x, + fsas_idx01 = order_data[ans_idx01], + fsas_idx01x = fsas_row_splits2_data[fsas_idx01], + fsas_idx012 = fsas_idx01x + ans_idx2; + Arc arc = fsas_arcs[fsas_idx012]; + int32_t fsas_src_idx1 = arc.src_state, fsas_dest_idx1 = arc.dest_state, + fsas_idx0x = fsas_row_splits1_data[ans_idx0], + fsas_src_idx01 = fsas_idx0x + fsas_src_idx1, + fsas_dest_idx01 = fsas_idx0x + fsas_dest_idx1; + K2_CHECK_EQ(old2new_data[fsas_src_idx01], ans_idx01); + int32_t ans_dest_idx01 = old2new_data[fsas_dest_idx01]; + int32_t ans_dest_idx1 = ans_dest_idx01 - ans_idx0x; + arc.src_state = ans_idx1; + arc.dest_state = ans_dest_idx1; + ans_arcs_data[ans_idx012] = arc; + if (arc_map_data != nullptr) arc_map_data[ans_idx012] = fsas_idx012; + if (ans_dest_idx01 == -1) all_dest_states_kept_data[0] = 0; + }); + K2_CHECK_EQ(all_dest_states_kept[0], 1) + << "The dest_state of an arc from a kept state is not present in `order`"; + return FsaVec(ans_shape, ans_arcs); +} } // namespace k2 diff --git a/k2/csrc/fsa_utils_test.cu b/k2/csrc/fsa_utils_test.cu index a25a6a908..c1dc79dc7 100644 --- a/k2/csrc/fsa_utils_test.cu +++ b/k2/csrc/fsa_utils_test.cu @@ -1169,8 +1169,7 @@ TEST_F(StatesBatchSuiteTest, TestBackpropForwardScores) { // make the fsa connected for easy testing for tropical version, the // algorithm should work for non-connected fsa as well.
FsaVec connected; - bool status = Connect(random_fsas, &connected); - ASSERT_TRUE(status); + Connect(random_fsas, &connected); TestBackpropGetForwardScores(connected); TestBackpropGetForwardScores(connected); } @@ -1310,8 +1309,7 @@ TEST_F(StatesBatchSuiteTest, TestBackpropBackwardScores) { // make the fsa connected for easy testing for tropical version, the // algorithm should work for non-connected fsa as well. FsaVec connected; - bool status = Connect(random_fsas, &connected); - ASSERT_TRUE(status); + Connect(random_fsas, &connected); TestBackpropGetBackwardScores(connected); TestBackpropGetBackwardScores(connected); } @@ -1461,8 +1459,7 @@ TEST_F(StatesBatchSuiteTest, TestRandomPaths) { // make the fsa connected for easy testing for tropical version, the // algorithm should work for non-connected fsa as well. FsaVec connected; - bool status = Connect(random_fsas, &connected); - ASSERT_TRUE(status); + Connect(random_fsas, &connected); TestRandomPaths(connected); TestRandomPaths(connected); } diff --git a/k2/csrc/top_sort.cu b/k2/csrc/top_sort.cu index 2c0a925c3..4589ed770 100644 --- a/k2/csrc/top_sort.cu +++ b/k2/csrc/top_sort.cu @@ -28,96 +28,6 @@ namespace k2 { // Caution: this is really a .cu file. It contains mixed host and device code. 
-// See declaration in fsa_util.h -FsaVec RenumberFsaVec(FsaVec &fsas, const Array1 &order, - Array1 *arc_map) { - NVTX_RANGE(K2_FUNC); - K2_CHECK_EQ(fsas.NumAxes(), 3); - ContextPtr &c = fsas.Context(); - K2_CHECK_LE(order.Dim(), fsas.TotSize(1)); - Array1 old2new_map(c, fsas.TotSize(1)); - if (order.Dim() != fsas.TotSize(1)) { - old2new_map = -1; - } - int32_t new_num_states = order.Dim(), num_fsas = fsas.Dim0(); - Array1 num_arcs(c, new_num_states + 1); - const int32_t *order_data = order.Data(), - *fsas_row_splits1_data = fsas.RowSplits(1).Data(), - *fsas_row_splits2_data = fsas.RowSplits(2).Data(); - int32_t *old2new_data = old2new_map.Data(), *num_arcs_data = num_arcs.Data(); - K2_EVAL( - c, new_num_states, lambda_set_old2new_and_num_arcs, - (int32_t new_state_idx01)->void { - int32_t old_state_idx01 = order_data[new_state_idx01]; - old2new_data[old_state_idx01] = new_state_idx01; - int32_t num_arcs = fsas_row_splits2_data[old_state_idx01 + 1] - - fsas_row_splits2_data[old_state_idx01]; - num_arcs_data[new_state_idx01] = num_arcs; - }); - - Array1 new_row_splits1, new_row_ids1; - if (order.Dim() == fsas.TotSize(1)) { - new_row_splits1 = fsas.RowSplits(1); - new_row_ids1 = fsas.RowIds(1); - } else { - new_row_ids1 = fsas.RowIds(1)[order]; - new_row_splits1 = Array1(c, num_fsas + 1); - RowIdsToRowSplits(new_row_ids1, &new_row_splits1); - } - - ExclusiveSum(num_arcs, &num_arcs); - RaggedShape ans_shape = - RaggedShape3(&new_row_splits1, &new_row_ids1, -1, &num_arcs, nullptr, -1); - const int32_t *ans_row_ids2_data = ans_shape.RowIds(2).Data(), - *ans_row_ids1_data = ans_shape.RowIds(1).Data(), - *ans_row_splits1_data = ans_shape.RowSplits(1).Data(), - *ans_row_splits2_data = ans_shape.RowSplits(2).Data(); - int32_t ans_num_arcs = ans_shape.NumElements(); - Array1 ans_arcs(c, ans_num_arcs); - int32_t *arc_map_data; - if (arc_map != nullptr) { - *arc_map = Array1(c, ans_num_arcs); - arc_map_data = arc_map->Data(); - } else { - arc_map_data = nullptr; - } - - const 
Arc *fsas_arcs = fsas.values.Data(); - Arc *ans_arcs_data = ans_arcs.Data(); - // if the dest state of any arc from any src kept state is not kept, the - // program will abort with an error. - Array1 all_dest_states_kept(c, 1, 1); - int32_t *all_dest_states_kept_data = all_dest_states_kept.Data(); - K2_EVAL( - c, ans_num_arcs, lambda_set_arcs, (int32_t ans_idx012)->void { - int32_t ans_idx01 = ans_row_ids2_data[ans_idx012], // state index - ans_idx01x = ans_row_splits2_data[ans_idx01], - ans_idx0 = ans_row_ids1_data[ans_idx01], // FSA index - ans_idx0x = ans_row_splits1_data[ans_idx0], - ans_idx1 = ans_idx01 - ans_idx0x, - ans_idx2 = ans_idx012 - ans_idx01x, - fsas_idx01 = order_data[ans_idx01], - fsas_idx01x = fsas_row_splits2_data[fsas_idx01], - fsas_idx012 = fsas_idx01x + ans_idx2; - Arc arc = fsas_arcs[fsas_idx012]; - int32_t fsas_src_idx1 = arc.src_state, fsas_dest_idx1 = arc.dest_state, - fsas_idx0x = fsas_row_splits1_data[ans_idx0], - fsas_src_idx01 = fsas_idx0x + fsas_src_idx1, - fsas_dest_idx01 = fsas_idx0x + fsas_dest_idx1; - K2_CHECK_EQ(old2new_data[fsas_src_idx01], ans_idx01); - int32_t ans_dest_idx01 = old2new_data[fsas_dest_idx01]; - int32_t ans_dest_idx1 = ans_dest_idx01 - ans_idx0x; - arc.src_state = ans_idx1; - arc.dest_state = ans_dest_idx1; - ans_arcs_data[ans_idx012] = arc; - if (arc_map_data != nullptr) arc_map_data[ans_idx012] = fsas_idx012; - if (ans_dest_idx01 == -1) all_dest_states_kept_data[0] = 0; - }); - K2_CHECK_EQ(all_dest_states_kept[0], 1) - << "The dest_state of an arc from a kept state is not present in `order`"; - return FsaVec(ans_shape, ans_arcs); -} - class TopSorter { public: /** @@ -131,8 +41,6 @@ class TopSorter { K2_CHECK_EQ(fsas_.NumAxes(), 3); } - int32_t NumFsas() const { return fsas_.Dim0(); } - /* Return the ragged array containing the states active on the 1st iteration of the algorithm. 
These just correspond to the start-states of all @@ -177,14 +85,14 @@ class TopSorter { /* Computes the next batch of states - @param [in] cur_states Ragged array with 2 axes, containing - state-indexes (idx01) into fsas_. These are states which already have - in-degree 0 + @param [in] cur_states Ragged array with 2 axes, with the shape of + `[fsas][states]`, containing state-indexes (idx01) into fsas_. + These are states which already have in-degree 0 @return Returns the states which, after processing. */ std::unique_ptr> GetNextBatch(Ragged &cur_states) { NVTX_RANGE(K2_FUNC); - // Process arcs leaving all states in `cur` + // Process arcs leaving all states in `cur_states` // First figure out how many arcs leave each state. Array1 num_arcs_per_state(c_, cur_states.NumElements() + 1); @@ -194,45 +102,52 @@ class TopSorter { K2_EVAL( c_, cur_states.NumElements(), lambda_set_arcs_per_state, (int32_t states_idx01)->void { - int32_t fsas_idx01 = states_data[states_idx01], - num_arcs = fsas_row_splits2_data[fsas_idx01 + 1] - - fsas_row_splits2_data[fsas_idx01]; + int32_t idx01 = states_data[states_idx01], + num_arcs = fsas_row_splits2_data[idx01 + 1] - + fsas_row_splits2_data[idx01]; num_arcs_per_state_data[states_idx01] = num_arcs; }); ExclusiveSum(num_arcs_per_state, &num_arcs_per_state); + // arcs_shape `[fsas][states in-degree 0][arcs]` RaggedShape arcs_shape = ComposeRaggedShapes( cur_states.shape, RaggedShape2(&num_arcs_per_state, nullptr, -1)); // Each arc that generates a new state (i.e. for which // arc_renumbering.Keep[i] == true) will write the state-id to here (as an // idx01 into fsas_). Other elements will be undefined. - Array1 next_iter_states(c_, arcs_shape.NumElements()); + // We will also write the row-id (which fsa the state belongs to) for each + // new state.
+ int32_t num_arcs = arcs_shape.NumElements(); + Array1 temp(c_, 2 * num_arcs); + Array1 next_iter_states = temp.Arange(0, num_arcs); + Array1 new_state_row_ids = temp.Arange(num_arcs, 2 * num_arcs); // We'll be figuring out which of these arcs leads to a state that now has // in-degree 0. (If >1 arc goes to such a state, only one will 'win', // arbitrarily). - Renumbering arc_renumbering(c_, arcs_shape.NumElements()); + Renumbering arc_renumbering(c_, num_arcs); const int32_t *arcs_row_ids1_data = arcs_shape.RowIds(1).Data(), *arcs_row_ids2_data = arcs_shape.RowIds(2).Data(), - *arcs_row_splits1_data = arcs_shape.RowSplits(1).Data(), *arcs_row_splits2_data = arcs_shape.RowSplits(2).Data(), *fsas_row_splits1_data = fsas_.RowSplits(1).Data(), *dest_states_data = dest_states_.values.Data(); char *keep_arc_data = arc_renumbering.Keep().Data(); int32_t *state_in_degree_data = state_in_degree_.Data(), - *next_iter_states_data = next_iter_states.Data(); + *next_iter_states_data = next_iter_states.Data(), + *new_state_row_ids_data = new_state_row_ids.Data(); K2_EVAL( - c_, arcs_shape.NumElements(), lambda_set_arc_renumbering, + c_, num_arcs, lambda_set_arc_renumbering, (int32_t arcs_idx012)->void { // note: the prefix `arcs_` means it is an idxXXX w.r.t. `arcs_shape`. // the prefix `fsas_` means the variable is an idxXXX w.r.t. `fsas_`. 
int32_t arcs_idx01 = arcs_row_ids2_data[arcs_idx012], + arcs_idx0 = arcs_row_ids1_data[arcs_idx01], arcs_idx01x = arcs_row_splits2_data[arcs_idx01], arcs_idx2 = arcs_idx012 - arcs_idx01x, fsas_idx01 = states_data[arcs_idx01], // a state index - fsas_idx01x = fsas_row_splits2_data[fsas_idx01], + fsas_idx01x = fsas_row_splits2_data[fsas_idx01], fsas_idx012 = fsas_idx01x + arcs_idx2, fsas_dest_state_idx01 = dest_states_data[fsas_idx012]; // if this arc is a self-loop, just ignore this arc as we have @@ -244,6 +159,7 @@ class TopSorter { if ((keep_arc_data[arcs_idx012] = AtomicDecAndCompareZero( state_in_degree_data + fsas_dest_state_idx01))) { next_iter_states_data[arcs_idx012] = fsas_dest_state_idx01; + new_state_row_ids_data[arcs_idx012] = arcs_idx0; } }); @@ -253,30 +169,33 @@ class TopSorter { // calling code that we processed all arcs. return nullptr; } + int32_t num_states = new2old_map.Dim(); + Array1 temp2(c_, 2 * num_states); // `new_states` will contain state-ids which are idx01's into `fsas_`. 
- Array1 new_states = next_iter_states[new2old_map]; - Array1 new_states_row_ids(c_, new_states.Dim()); // will map to - // FSA index + Array1 new_states = temp2.Arange(0, num_states); + // `ans_row_ids` will map to FSA index + Array1 ans_row_ids = temp2.Arange(num_states, 2 * num_states); + const int32_t *new2old_map_data = new2old_map.Data(); - int32_t *new_states_row_ids_data = new_states_row_ids.Data(); + int32_t *ans_row_ids_data = ans_row_ids.Data(), + *new_states_data = new_states.Data(); K2_EVAL( - c_, new_states.Dim(), lambda_set_row_ids, + c_, num_states, lambda_set_new_states_and_row_ids, (int32_t new_state_idx)->void { - int32_t arcs_idx012 = new2old_map_data[new_state_idx], - arcs_idx01 = arcs_row_ids2_data[arcs_idx012], // state index - arcs_idx0 = arcs_row_ids1_data[arcs_idx01]; // FSA index - new_states_row_ids_data[new_state_idx] = arcs_idx0; + int32_t arcs_idx012 = new2old_map_data[new_state_idx]; + new_states_data[new_state_idx] = next_iter_states_data[arcs_idx012]; + ans_row_ids_data[new_state_idx] = new_state_row_ids_data[arcs_idx012]; }); int32_t num_fsas = fsas_.Dim0(); - Array1 new_states_row_splits(c_, num_fsas + 1); - RowIdsToRowSplits(new_states_row_ids, &new_states_row_splits); + Array1 ans_row_splits(c_, num_fsas + 1); + RowIdsToRowSplits(ans_row_ids, &ans_row_splits); - std::unique_ptr> ans = std::make_unique>( - RaggedShape2(&new_states_row_splits, &new_states_row_ids, -1), + auto ans = std::make_unique>( + RaggedShape2(&ans_row_splits, &ans_row_ids, num_states), new_states); // The following will ensure the answer has deterministic numbering - SortSublists(ans.get(), nullptr); + SortSublists(ans.get()); return ans; } @@ -289,7 +208,7 @@ class TopSorter { */ std::unique_ptr> GetFinalBatch() { NVTX_RANGE(K2_FUNC); - int32_t num_fsas = NumFsas(); + int32_t num_fsas = fsas_.Dim0(); const int32_t *fsas_row_splits1_data = fsas_.RowSplits(1).Data(); Array1 has_final_state(c_, num_fsas + 1); int32_t *has_final_state_data = 
has_final_state.Data(); @@ -302,7 +221,7 @@ class TopSorter { ExclusiveSum(has_final_state, &has_final_state); int32_t n = has_final_state[num_fsas]; - std::unique_ptr> ans = std::make_unique>( + auto ans = std::make_unique>( RaggedShape2(&has_final_state, nullptr, n), Array1(c_, n)); int32_t *ans_data = ans->values.Data(); const int32_t *ans_row_ids1_data = ans->RowIds(1).Data(); @@ -383,6 +302,7 @@ class TopSorter { *first_batch_row_splits1_data = first_batch->RowSplits(1).Data(), *fsas_row_splits1_data = fsas_.RowSplits(1).Data(); + // Act as a flag Array1 start_state_present(c_, 1, 1); int32_t *start_state_present_data = start_state_present.Data(); K2_EVAL( @@ -441,8 +361,7 @@ void TopSort(FsaVec &src, FsaVec *dest, Array1 *arc_map) { K2_CHECK_LE(src.NumAxes(), 3); if (src.NumAxes() == 2) { // Turn single Fsa into FsaVec. - Fsa *srcs = &src; - FsaVec src_vec = CreateFsaVec(1, &srcs), dest_vec; + FsaVec src_vec = FsaToFsaVec(src), dest_vec; // Recurse.. TopSort(src_vec, &dest_vec, arc_map); *dest = GetFsaVecElement(dest_vec, 0); diff --git a/k2/csrc/utils.h b/k2/csrc/utils.h index a3767e104..d2d09a9c3 100644 --- a/k2/csrc/utils.h +++ b/k2/csrc/utils.h @@ -107,9 +107,9 @@ namespace k2 { In a ragged tensor t with n axes (say, 3) the actual elements will be written in a linear array we'll have various levels of indexes that allow us to look up an element given the hierarchical indexes and vice versa. A 3-d - ragged tensor will have t.RowIds(1), t.RowSplits(1), t.RowIds(2), t.RowSplits(2), - and the actual elements. We have a naming scheme that expresses what - information is packed into a single integer. + ragged tensor will have t.RowIds(1), t.RowSplits(1), t.RowIds(2), + t.RowSplits(2), and the actual elements. We have a naming scheme that + expresses what information is packed into a single integer. 
Some entry-level facts about the naming scheme are: @@ -170,6 +170,11 @@ namespace k2 { is intuitively obvious and any mismatches will tend to be obvious in an individual line of code once you have understood the naming scheme and its rules. + + Note: Variable names follow the index naming scheme above as well. For + a Ragged array indexed as `[fsa][state][arc]` (say, an FsaVec), we usually + name `idx0` `fsa_idx0`, `idx01` `state_idx01`, and `idx012` `arc_idx012`, + that is, `theThingWeIndex_idx0[123]`. */ /** @@ -353,7 +358,7 @@ __host__ __device__ __forceinline__ bool AtomicDecAndCompareZero(int32_t *i) { CAUTION: For host code, we assume single-threaded for now. @param [inout] address The memory address. - @param in] value The value to be added. + @param [in] value The value to be added. */ template __host__ __device__ __forceinline__ void AtomicAdd(T *address, T value) { @@ -370,7 +375,7 @@ __host__ __device__ __forceinline__ void AtomicAdd(T *address, T value) { // The following implementation is copied from // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions __host__ __device__ __forceinline__ void AtomicAdd(double *address, - double value) { + double value) { #if __CUDA_ARCH__ >= 600 atomicAdd(address, value); #elif defined(__CUDA_ARCH__) diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py index 83634d57e..c84b89b9e 100644 --- a/k2/python/k2/fsa_algo.py +++ b/k2/python/k2/fsa_algo.py @@ -443,8 +443,6 @@ def connect(fsa: Fsa) -> Fsa: Removes states that are neither accessible nor co-accessible. - It works only on CPU. - Note: A state is not accessible if it is not reachable from the start state. A state is not co-accessible if it cannot reach the final state.
@@ -464,11 +462,8 @@ def connect(fsa: Fsa) -> Fsa: fsa.properties & fsa_properties.COACCESSIBLE != 0: return fsa - assert fsa.is_cpu() - need_arc_map = True ragged_arc, arc_map = _k2.connect(fsa.arcs, need_arc_map=need_arc_map) - out_fsa = k2.utils.fsa_from_unary_function_tensor(fsa, ragged_arc, arc_map) return out_fsa diff --git a/k2/python/tests/connect_test.py b/k2/python/tests/connect_test.py index 5f510a926..47bc2ff35 100644 --- a/k2/python/tests/connect_test.py +++ b/k2/python/tests/connect_test.py @@ -28,6 +28,15 @@ class TestConnect(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.devices = [torch.device('cpu')] + if torch.cuda.is_available() and k2.with_cuda: + cls.devices.append(torch.device('cuda', 0)) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(1) + cls.devices.append(torch.device('cuda', 1)) + def test(self): s = ''' 0 1 1 0.1 @@ -36,17 +45,21 @@ def test(self): 3 4 -1 0.4 4 ''' - fsa = k2.Fsa.from_str(s) - fsa.requires_grad_(True) - expected_str = '\n'.join(['0 1 1 0.1', '1 2 -1 0.3', '2']) - connected_fsa = k2.connect(fsa) - actual_str = k2.to_str_simple(connected_fsa) - assert actual_str.strip() == expected_str - - loss = connected_fsa.scores.sum() - loss.backward() - assert torch.allclose(fsa.scores.grad, - torch.tensor([1, 0, 1, 0], dtype=torch.float32)) + for device in self.devices: + fsa = k2.Fsa.from_str(s).to(device) + fsa.requires_grad_(True) + expected_str = '\n'.join(['0 1 1 0.1', '1 2 -1 0.3', '2']) + connected_fsa = k2.connect(fsa) + + loss = connected_fsa.scores.sum() + loss.backward() + assert torch.allclose(fsa.scores.grad, + torch.tensor([1, 0, 1, 0], + dtype=torch.float32, + device=device)) + + actual_str = k2.to_str_simple(connected_fsa.to('cpu')) + assert actual_str.strip() == expected_str if __name__ == '__main__':