Skip to content

Commit

Permalink
Implements Connect for device (#771)
Browse files Browse the repository at this point in the history
* implements Connect operation for device

* fix errors when the input is FsaVec

* renumber states and add more test case

* small fix

* fix typo

* Apply suggestions from code review

Co-authored-by: Fangjun Kuang <[email protected]>

* Apply suggestions from code review

Co-authored-by: Fangjun Kuang <[email protected]>

* fix code review suggestions

* fix syntax error

* fix code style & make forward backward traverse run parallel

* save one kernel for forward&backward procedure, fix naming convention issue

* save one kernel for topsorter's GetNextBatch()

* change GetRandFsa() to RandomFsa()

* Update k2/csrc/connect.cu

Co-authored-by: Fangjun Kuang <[email protected]>

* Add the same state only once during traversing to avoid out-of-memory issue

* fix code style

* use atomicCAS to save 4 Byte memory for each state

* remove atomic functions

Co-authored-by: pkufool <[email protected]>
Co-authored-by: Fangjun Kuang <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2021
1 parent 9207b79 commit c414ca6
Show file tree
Hide file tree
Showing 11 changed files with 916 additions and 157 deletions.
2 changes: 2 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_subdirectory(host)
set(context_srcs
algorithms.cu
array_ops.cu
connect.cu
context.cu
dtype.cu
fsa.cu
Expand Down Expand Up @@ -136,6 +137,7 @@ set(cuda_test_srcs
algorithms_test.cu
array_ops_test.cu
array_test.cu
connect_test.cu
dtype_test.cu
fsa_algo_test.cu
fsa_test.cu
Expand Down
516 changes: 516 additions & 0 deletions k2/csrc/connect.cu

Large diffs are not rendered by default.

212 changes: 212 additions & 0 deletions k2/csrc/connect_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/**
* Copyright 2021 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#include <algorithm>
#include <random>

#include "k2/csrc/fsa_algo.h"
#include "k2/csrc/fsa_utils.h"
#include "k2/csrc/host_shim.h"
#include "k2/csrc/math.h"
#include "k2/csrc/test_utils.h"

namespace k2 {

// Checks Connect() on one hand-written acyclic FSA, on both CPU and CUDA.
//
// In the input, state 4 is a dead end (no arc leaves it, and it is not the
// final state) and state 5 is not reachable from state 0.  Connect() is
// therefore expected to drop arc 2 (1 -> 4) and arc 6 (5 -> 3), keeping arcs
// 0, 1, 3, 4, 5, and to renumber the final state from 6 to 4.
TEST(Connect, SingleFsa) {
  for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
    std::string s = R"(0 2 1 1
0 3 3 3
1 4 5 5
1 6 -1 0
2 1 2 2
3 1 4 4
5 3 6 6
6
)";
    auto fsa = FsaFromString(s).To(c);
    int32_t gt = kFsaPropertiesMaybeAccessible |
                 kFsaPropertiesMaybeCoaccessible;
    // The input must NOT already carry both properties, otherwise the test
    // would not exercise anything.
    int32_t p = GetFsaBasicProperties(fsa);
    EXPECT_NE(p & gt, gt);

    Fsa connected;
    Array1<int32_t> arc_map;
    Connect(fsa, &connected, &arc_map);

    Fsa ref = Fsa("[ [ 0 2 1 1 0 3 3 3 ] [ 1 4 -1 0 ] "
                  " [ 2 1 2 2 ] [ 3 1 4 4 ] [ ] ]").To(c);
    Array1<int32_t> arc_map_ref(c, "[ 0 1 3 4 5 ]");
    // Use non-fatal gtest assertions so that a mismatch is reported as a test
    // failure instead of aborting the whole test binary.
    EXPECT_TRUE(Equal(connected, ref));
    EXPECT_TRUE(Equal(arc_map, arc_map_ref));
    p = GetFsaBasicProperties(connected);
    EXPECT_EQ(p & gt, gt);
  }
}

// Checks Connect() on one hand-written FSA containing a cycle
// (1 -> 2 -> 3 -> 1), on both CPU and CUDA.
//
// In the input, state 4 is a dead end and state 5 is not reachable from
// state 0.  Connect() is therefore expected to drop arc 4 (2 -> 4) and arc 7
// (5 -> 2), keeping arcs 0, 1, 2, 3, 5, 6, and to renumber the final state
// from 6 to 4.
TEST(Connect, CyclicFsa) {
  for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
    std::string s = R"(0 1 1 1
0 2 2 2
1 2 3 3
2 3 4 4
2 4 5 5
3 1 6 6
3 6 -1 0
5 2 7 7
6
)";
    auto fsa = FsaFromString(s).To(c);

    int32_t gt = kFsaPropertiesMaybeAccessible |
                 kFsaPropertiesMaybeCoaccessible;
    // The input must NOT already carry both properties.
    int32_t p = GetFsaBasicProperties(fsa);
    EXPECT_NE(p & gt, gt);

    Fsa connected;
    Array1<int32_t> arc_map;
    Connect(fsa, &connected, &arc_map);
    Fsa ref = Fsa("[ [ 0 1 1 1 0 2 2 2 ] [ 1 2 3 3 ] [ 2 3 4 4] "
                  " [ 3 1 6 6 3 4 -1 0 ] [ ] ]").To(c);
    Array1<int32_t> arc_map_ref(c, "[ 0 1 2 3 5 6 ]");
    // Use non-fatal gtest assertions so that a mismatch is reported as a test
    // failure instead of aborting the whole test binary.
    EXPECT_TRUE(Equal(connected, ref));
    EXPECT_TRUE(Equal(arc_map, arc_map_ref));
    p = GetFsaBasicProperties(connected);
    EXPECT_EQ(p & gt, gt);
  }
}

// Connects one randomly generated Fsa per context (CPU, then CUDA) and checks
// that (a) the result carries both the "maybe accessible" and "maybe
// coaccessible" property bits and (b) every output arc has the same score as
// the input arc that `arc_map` points at.
TEST(Connect, RandomSingleFsa) {
  ContextPtr cpu = GetCpuContext();
  for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
    bool acyclic = RandInt(0, 1);
    Fsa src = RandomFsa(acyclic).To(c);

    Fsa dest;
    Array1<int32_t> arc_map;
    Connect(src, &dest, &arc_map);

    const int32_t expected = kFsaPropertiesMaybeAccessible |
                             kFsaPropertiesMaybeCoaccessible;
    const int32_t props = GetFsaBasicProperties(dest);
    EXPECT_EQ(props & expected, expected);

    // Bring everything to the CPU before element-wise inspection.
    Array1<Arc> dest_arcs = dest.values.To(cpu);
    Array1<Arc> src_arcs = src.values.To(cpu);
    Array1<int32_t> cpu_arc_map = arc_map.To(cpu);

    const int32_t num_arcs = dest_arcs.Dim();
    for (int32_t arc_idx = 0; arc_idx < num_arcs; ++arc_idx) {
      const int32_t src_arc_idx = cpu_arc_map[arc_idx];
      EXPECT_EQ(dest_arcs[arc_idx].score, src_arcs[src_arc_idx].score);
    }
  }
}

// Checks Connect() on an FsaVec built from three hand-written FSAs, on both
// CPU and CUDA.
//
// Arc indexes below are global (idx012) over the whole FsaVec:
//   fsa1 (arcs 0..2): state 2 is a dead end          -> arc 1 is dropped.
//   fsa2 (arcs 3..5): state 2 is unreachable         -> arc 5 is dropped.
//   fsa3 (arcs 6..9): state 3 is a dead end, state 2
//                     is unreachable                 -> arcs 8, 9 are dropped.
// Hence the expected arc_map is [ 0 2 3 4 6 7 ] and every connected FSA is
// the same 3-state chain.
TEST(Connect, FsaVec) {
  for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
    std::string s1 = R"(0 1 1 1
0 2 2 2
1 3 -1 0
3
)";
    auto fsa1 = FsaFromString(s1);

    std::string s2 = R"(0 1 1 1
1 3 -1 0
2 1 2 2
3
)";
    auto fsa2 = FsaFromString(s2);

    std::string s3 = R"(0 1 1 1
1 4 -1 0
1 3 3 3
2 1 2 2
4
)";
    auto fsa3 = FsaFromString(s3);

    Fsa *fsa_array[] = {&fsa1, &fsa2, &fsa3};
    FsaVec fsa_vec = CreateFsaVec(3, &fsa_array[0]).To(c);

    int32_t gt = kFsaPropertiesMaybeAccessible |
                 kFsaPropertiesMaybeCoaccessible;
    Array1<int32_t> properties;
    int32_t p;
    GetFsaVecBasicProperties(fsa_vec, &properties, &p);

    // None of the inputs may already carry both properties.
    EXPECT_NE(p & gt, gt);
    EXPECT_NE(properties[0] & gt, gt);
    EXPECT_NE(properties[1] & gt, gt);
    EXPECT_NE(properties[2] & gt, gt);

    FsaVec connected;
    Array1<int32_t> arc_map;
    Connect(fsa_vec, &connected, &arc_map);
    FsaVec ref = FsaVec("[ [ [ 0 1 1 1 ] [ 1 2 -1 0 ] [ ] ] "
                        " [ [ 0 1 1 1 ] [ 1 2 -1 0 ] [ ] ] "
                        " [ [ 0 1 1 1 ] [ 1 2 -1 0 ] [ ] ] ]").To(c);
    Array1<int32_t> arc_map_ref(c, "[ 0 2 3 4 6 7 ]");
    // Use non-fatal gtest assertions so that a mismatch is reported as a test
    // failure instead of aborting the whole test binary.
    EXPECT_TRUE(Equal(connected, ref));
    EXPECT_TRUE(Equal(arc_map, arc_map_ref));

    GetFsaVecBasicProperties(connected, &properties, &p);
    EXPECT_EQ(p & gt, gt);
    EXPECT_EQ(properties[0] & gt, gt);
    EXPECT_EQ(properties[1] & gt, gt);
    EXPECT_EQ(properties[2] & gt, gt);
  }
}

// Connects one randomly generated FsaVec per context (CPU, then CUDA) and
// checks that (a) the overall and the per-FSA properties of the result carry
// both the "maybe accessible" and "maybe coaccessible" bits and (b) every
// output arc has the same score as the input arc that `arc_map` points at.
TEST(Connect, RandomFsaVec) {
  ContextPtr cpu = GetCpuContext();
  for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) {
    bool acyclic = RandInt(0, 1);
    FsaVec src = RandomFsaVec(1, 100, acyclic).To(c);

    FsaVec dest;
    Array1<int32_t> arc_map;
    Connect(src, &dest, &arc_map);

    const int32_t expected = kFsaPropertiesMaybeAccessible |
                             kFsaPropertiesMaybeCoaccessible;
    Array1<int32_t> properties;
    int32_t tot_properties;
    GetFsaVecBasicProperties(dest, &properties, &tot_properties);
    EXPECT_EQ(tot_properties & expected, expected);

    // Per-FSA properties are inspected on the CPU.
    Array1<int32_t> cpu_properties = properties.To(cpu);
    const int32_t num_fsas = src.Dim0();
    for (int32_t fsa_idx = 0; fsa_idx < num_fsas; ++fsa_idx) {
      EXPECT_EQ(cpu_properties[fsa_idx] & expected, expected);
    }

    // Bring arcs and the arc map to the CPU before element-wise inspection.
    Array1<Arc> dest_arcs = dest.values.To(cpu);
    Array1<Arc> src_arcs = src.values.To(cpu);
    Array1<int32_t> cpu_arc_map = arc_map.To(cpu);

    const int32_t num_arcs = dest.TotSize(2);
    for (int32_t arc_idx = 0; arc_idx < num_arcs; ++arc_idx) {
      const int32_t src_arc_idx = cpu_arc_map[arc_idx];
      EXPECT_EQ(dest_arcs[arc_idx].score, src_arcs[src_arc_idx].score);
    }
  }
}

} // namespace k2
4 changes: 2 additions & 2 deletions k2/csrc/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ bool RecursionWrapper(bool (*f)(Fsa &, Fsa *, Array1<int32_t> *), Fsa &src,
return true;
}

bool Connect(Fsa &src, Fsa *dest, Array1<int32_t> *arc_map /*=nullptr*/) {
bool ConnectHost(Fsa &src, Fsa *dest, Array1<int32_t> *arc_map /*=nullptr*/) {
NVTX_RANGE(K2_FUNC);
int32_t num_axes = src.NumAxes();
if (num_axes < 2 || num_axes > 3) {
K2_LOG(FATAL) << "Input has bad num-axes " << num_axes;
} else if (num_axes == 3) {
return RecursionWrapper(Connect, src, dest, arc_map);
return RecursionWrapper(ConnectHost, src, dest, arc_map);
}

k2host::Fsa host_fsa = FsaToHostFsa(src);
Expand Down
23 changes: 18 additions & 5 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,26 @@ namespace k2 {
kFsaPropertiesMaybeAccessible
@param [out,optional] arc_map For each arc in `dest`, gives the index of
the corresponding arc in `src` that it corresponds to.
@return Returns true on success (which basically means the input did not
have cycles, so the algorithm could not succeed). Success
does not imply that `dest` is nonempty.
*/
void Connect(FsaOrVec &src, FsaOrVec *dest, Array1<int32_t> *arc_map = nullptr);

CAUTION: for now this only works for CPU.
/*
This version of Connect() is just a wrapper of `Connection` in host/connect.h
and only works on CPU. We will delete it at some point; users should never call
this function in production code. Instead, you should call the version above.
@param [in] src Source FSA
@param [out] dest Destination; at exit will be equivalent to `src`
but will have no states that are unreachable or which
can't reach the final-state, i.e. its Properties() will
contain kFsaPropertiesMaybeCoaccessible and
kFsaPropertiesMaybeAccessible
@param [out,optional] arc_map For each arc in `dest`, gives the index of
the corresponding arc in `src` that it corresponds to.
@return Returns true on success (which basically means the input did not
have cycles, so the algorithm could not succeed). Success
does not imply that `dest` is nonempty.
*/
bool Connect(Fsa &src, Fsa *dest, Array1<int32_t> *arc_map = nullptr);
bool ConnectHost(Fsa &src, Fsa *dest, Array1<int32_t> *arc_map = nullptr);

/*
Sort arcs of an Fsa or FsaVec in-place (this version of the function does not
Expand Down
87 changes: 87 additions & 0 deletions k2/csrc/fsa_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2773,6 +2773,93 @@ void FixFinalLabels(FsaOrVec &fsas,
}
}

// Builds a new FsaVec that contains the states of `fsas` listed in `order`,
// in that order, together with all arcs leaving those states.  The src_state
// and dest_state of every copied arc are rewritten to the new (per-FSA)
// state numbering.
//
//   @param [in] fsas     Input FsaVec; must have exactly 3 axes.
//   @param [in] order    State indexes (idx01) into `fsas`; new state i is
//                        old state order[i].  order.Dim() may be smaller than
//                        the number of states in `fsas`, in which case the
//                        states not listed are dropped.
//                        NOTE(review): the new per-FSA row_splits are rebuilt
//                        from fsas.RowIds(1)[order], so states of the same FSA
//                        are presumably expected to be contiguous and FSAs in
//                        increasing order -- confirm against callers.
//   @param [out] arc_map If not nullptr, is assigned a new array giving, for
//                        each arc in the result, the idx012 of the
//                        corresponding arc in `fsas`.
//   @return  The renumbered FsaVec.
//
// The function aborts (K2_CHECK) if any arc leaving a kept state points to a
// state that is not present in `order`.
FsaVec RenumberFsaVec(FsaVec &fsas, const Array1<int32_t> &order,
Array1<int32_t> *arc_map) {
NVTX_RANGE(K2_FUNC);
K2_CHECK_EQ(fsas.NumAxes(), 3);
ContextPtr &c = fsas.Context();
K2_CHECK_LE(order.Dim(), fsas.TotSize(1));
// old2new_map[old_state_idx01] = new_state_idx01.  It is pre-filled with -1
// only when some states are dropped; when `order` lists as many states as
// `fsas` has, the fill is skipped.
// NOTE(review): in that case `order` is presumably a permutation, so every
// entry gets written by the kernel below -- confirm.
Array1<int32_t> old2new_map(c, fsas.TotSize(1));
if (order.Dim() != fsas.TotSize(1)) {
old2new_map = -1;
}
int32_t new_num_states = order.Dim(), num_fsas = fsas.Dim0();
// One extra element so that, after the ExclusiveSum below, this array can be
// handed to RaggedShape3 (presumably as the row_splits of axis 2).
Array1<int32_t> num_arcs(c, new_num_states + 1);
const int32_t *order_data = order.Data(),
*fsas_row_splits1_data = fsas.RowSplits(1).Data(),
*fsas_row_splits2_data = fsas.RowSplits(2).Data();
int32_t *old2new_data = old2new_map.Data(), *num_arcs_data = num_arcs.Data();
// For each new state: record the old->new mapping and the number of arcs
// leaving the corresponding old state.
K2_EVAL(
c, new_num_states, lambda_set_old2new_and_num_arcs,
(int32_t new_state_idx01)->void {
int32_t old_state_idx01 = order_data[new_state_idx01];
old2new_data[old_state_idx01] = new_state_idx01;
int32_t num_arcs = fsas_row_splits2_data[old_state_idx01 + 1] -
fsas_row_splits2_data[old_state_idx01];
num_arcs_data[new_state_idx01] = num_arcs;
});

// Axis-1 structure (FSA -> states) of the answer.
Array1<int32_t> new_row_splits1, new_row_ids1;
if (order.Dim() == fsas.TotSize(1)) {
// No state is dropped: the number of states per FSA is unchanged, so the
// input's row_splits1/row_ids1 are reused as they are.
new_row_splits1 = fsas.RowSplits(1);
new_row_ids1 = fsas.RowIds(1);
} else {
// Some states are dropped: look up the FSA index of each kept state and
// rebuild the row_splits from those row_ids.
new_row_ids1 = fsas.RowIds(1)[order];
new_row_splits1 = Array1<int32_t>(c, num_fsas + 1);
RowIdsToRowSplits(new_row_ids1, &new_row_splits1);
}

// In-place: per-state arc counts -> exclusive prefix sum.
ExclusiveSum(num_arcs, &num_arcs);
RaggedShape ans_shape =
RaggedShape3(&new_row_splits1, &new_row_ids1, -1, &num_arcs, nullptr, -1);
const int32_t *ans_row_ids2_data = ans_shape.RowIds(2).Data(),
*ans_row_ids1_data = ans_shape.RowIds(1).Data(),
*ans_row_splits1_data = ans_shape.RowSplits(1).Data(),
*ans_row_splits2_data = ans_shape.RowSplits(2).Data();
int32_t ans_num_arcs = ans_shape.NumElements();
Array1<Arc> ans_arcs(c, ans_num_arcs);
// arc_map is optional; arc_map_data stays nullptr when the caller does not
// want it, and the kernel below checks for that.
int32_t *arc_map_data;
if (arc_map != nullptr) {
*arc_map = Array1<int32_t>(c, ans_num_arcs);
arc_map_data = arc_map->Data();
} else {
arc_map_data = nullptr;
}

const Arc *fsas_arcs = fsas.values.Data();
Arc *ans_arcs_data = ans_arcs.Data();
// if the dest state of any arc from any src kept state is not kept, the
// program will abort with an error.
// One-element flag that starts at 1 and is cleared to 0 by the kernel when
// such an arc is found; it is checked after the kernel.
Array1<int32_t> all_dest_states_kept(c, 1, 1);
int32_t *all_dest_states_kept_data = all_dest_states_kept.Data();
// For each arc of the answer: locate the matching arc in `fsas` (same
// position within its source state), copy it, and rewrite its src/dest
// states to the new per-FSA numbering.
K2_EVAL(
c, ans_num_arcs, lambda_set_arcs, (int32_t ans_idx012)->void {
int32_t ans_idx01 = ans_row_ids2_data[ans_idx012], // state index
ans_idx01x = ans_row_splits2_data[ans_idx01],
ans_idx0 = ans_row_ids1_data[ans_idx01], // FSA index
ans_idx0x = ans_row_splits1_data[ans_idx0],
ans_idx1 = ans_idx01 - ans_idx0x,
ans_idx2 = ans_idx012 - ans_idx01x,
fsas_idx01 = order_data[ans_idx01],
fsas_idx01x = fsas_row_splits2_data[fsas_idx01],
fsas_idx012 = fsas_idx01x + ans_idx2;
Arc arc = fsas_arcs[fsas_idx012];
// The FSA index is the same in the input and in the answer, so ans_idx0
// is used to index the input's row_splits1 as well.
int32_t fsas_src_idx1 = arc.src_state, fsas_dest_idx1 = arc.dest_state,
fsas_idx0x = fsas_row_splits1_data[ans_idx0],
fsas_src_idx01 = fsas_idx0x + fsas_src_idx1,
fsas_dest_idx01 = fsas_idx0x + fsas_dest_idx1;
// Sanity check: the arc's source state must map to the state we are
// currently filling in.
K2_CHECK_EQ(old2new_data[fsas_src_idx01], ans_idx01);
int32_t ans_dest_idx01 = old2new_data[fsas_dest_idx01];
int32_t ans_dest_idx1 = ans_dest_idx01 - ans_idx0x;
arc.src_state = ans_idx1;
arc.dest_state = ans_dest_idx1;
ans_arcs_data[ans_idx012] = arc;
if (arc_map_data != nullptr) arc_map_data[ans_idx012] = fsas_idx012;
// -1 means the destination state is not listed in `order`; report it via
// the flag (the arc written above is meaningless in that case).
if (ans_dest_idx01 == -1) all_dest_states_kept_data[0] = 0;
});
K2_CHECK_EQ(all_dest_states_kept[0], 1)
<< "The dest_state of an arc from a kept state is not present in `order`";
return FsaVec(ans_shape, ans_arcs);
}

} // namespace k2
Loading

0 comments on commit c414ca6

Please sign in to comment.