Skip to content

Commit

Permalink
Rework rescoring for faster and more accurate results
Browse files Browse the repository at this point in the history
  • Loading branch information
nshmyrev committed Aug 31, 2021
1 parent abff8a4 commit e2af710
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 76 deletions.
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_tag(self):

setuptools.setup(
name="vosk",
version="0.3.30",
version="0.3.31",
author="Alpha Cephei Inc",
author_email="[email protected]",
description="Offline open source speech recognition API based on Kaldi and Vosk",
Expand Down
94 changes: 42 additions & 52 deletions src/kaldi_recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,14 @@ KaldiRecognizer::~KaldiRecognizer() {
delete g_fst_;
delete decode_fst_;
delete spk_feature_;
delete lm_fst_;

delete info;
delete lm_to_subtract_det_backoff;
delete lm_to_subtract_det_scale;
delete lm_to_add_orig;
delete lm_to_add;
delete rnnlm_info_;
delete lm_to_subtract_;
delete lm_to_subtract_scale_;
delete carpa_to_add_;
delete carpa_to_add_scale_;
delete rnnlm_to_add_;
delete rnnlm_to_add_scale_;

model_->Unref();
if (spk_model_)
Expand All @@ -166,21 +167,18 @@ void KaldiRecognizer::InitState()

// Allocates the on-demand language-model FST wrappers that GetResult() later
// composes with the decoded lattice. Everything created here with `new` is
// released by the matching `delete` statements in ~KaldiRecognizer().
//
// NOTE(review): this is a rendered diff hunk without +/- markers, so the
// statements of two revisions of the function appear one after another and the
// braces do not balance as displayed. Presumably the rnnlm_lm_fst_/std_lm_fst_
// branches are the removed revision and the graph_lm_fst_ branch is the added
// one — confirm against the repository before relying on this.
void KaldiRecognizer::InitRescoring()
{
// Branch taken when the model carries rnnlm_lm_fst_: the same lm_scale (0.5)
// is applied negated to the backoff FST and positive to the RNNLM FST.
if (model_->rnnlm_lm_fst_) {
float lm_scale = 0.5;
int lm_order = 4;

info = new kaldi::rnnlm::RnnlmComputeStateInfo(model_->rnnlm_compute_opts, model_->rnnlm, model_->word_embedding_mat);
lm_to_subtract_det_backoff = new fst::BackoffDeterministicOnDemandFst<fst::StdArc>(*model_->rnnlm_lm_fst_);
lm_to_subtract_det_scale = new fst::ScaleDeterministicOnDemandFst(-lm_scale, lm_to_subtract_det_backoff);
lm_to_add_orig = new kaldi::rnnlm::KaldiRnnlmDeterministicFst(lm_order, *info);
lm_to_add = new fst::ScaleDeterministicOnDemandFst(lm_scale, lm_to_add_orig);

// Branch taken when the model carries std_lm_fst_: wraps it in a cached
// ArcMapFst that maps StdArc to LatticeArc.
} else if (model_->std_lm_fst_) {
fst::CacheOptions cache_opts(true, 50000);
fst::ArcMapFstOptions mapfst_opts(cache_opts);
fst::StdToLatticeMapper<kaldi::BaseFloat> mapper;
lm_fst_ = new fst::ArcMapFst<fst::StdArc, kaldi::LatticeArc, fst::StdToLatticeMapper<kaldi::BaseFloat> >(*model_->std_lm_fst_, mapper, mapfst_opts);
// Nothing below is allocated unless the model supplies graph_lm_fst_.
if (model_->graph_lm_fst_) {
// Graph LM wrapped as a backoff on-demand FST, then scaled by -1.0.
lm_to_subtract_ = new fst::BackoffDeterministicOnDemandFst<StdArc>(*model_->graph_lm_fst_);
lm_to_subtract_scale_ = new fst::ScaleDeterministicOnDemandFst(-1.0, lm_to_subtract_);
// ConstArpa LM wrapper built from model_->const_arpa_ (unscaled).
carpa_to_add_ = new ConstArpaLmDeterministicFst(model_->const_arpa_);

// RNNLM objects exist only when model_->rnnlm_enabled_ is set; GetResult()
// tests rnnlm_to_add_scale_ to decide whether to run the extra pass.
if (model_->rnnlm_enabled_) {
int lm_order = 4;
rnnlm_info_ = new kaldi::rnnlm::RnnlmComputeStateInfo(model_->rnnlm_compute_opts, model_->rnnlm, model_->word_embedding_mat);
rnnlm_to_add_ = new kaldi::rnnlm::KaldiRnnlmDeterministicFst(lm_order, *rnnlm_info_);
// RNNLM weighted by 0.5 and the ConstArpa LM by -0.5.
rnnlm_to_add_scale_ = new fst::ScaleDeterministicOnDemandFst(0.5, rnnlm_to_add_);
carpa_to_add_scale_ = new fst::ScaleDeterministicOnDemandFst(-0.5, carpa_to_add_);
}
}
}

Expand Down Expand Up @@ -592,38 +590,30 @@ const char* KaldiRecognizer::GetResult()
kaldi::CompactLattice rlat;
decoder_->GetLattice(true, &clat);

if (model_->rnnlm_lm_fst_) {
kaldi::ComposeLatticePrunedOptions compose_opts;
compose_opts.lattice_compose_beam = 3.0;
compose_opts.max_arcs = 3000;

if (lm_to_subtract_scale_ && carpa_to_add_) {
TopSortCompactLatticeIfNeeded(&clat);
fst::ComposeDeterministicOnDemandFst<fst::StdArc> combined_lms(lm_to_subtract_det_scale, lm_to_add);
CompactLattice composed_clat;
ComposeCompactLatticePruned(compose_opts, clat,
&combined_lms, &rlat);
lm_to_add_orig->Clear();
} else if (model_->std_lm_fst_) {
Lattice lat1;

ConvertLattice(clat, &lat1);
fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &lat1);
fst::ArcSort(&lat1, fst::OLabelCompare<kaldi::LatticeArc>());
kaldi::Lattice composed_lat;
fst::Compose(lat1, *lm_fst_, &composed_lat);
fst::Invert(&composed_lat);
kaldi::CompactLattice determinized_lat;
DeterminizeLattice(composed_lat, &determinized_lat);
fst::ScaleLattice(fst::GraphLatticeScale(-1), &determinized_lat);
fst::ArcSort(&determinized_lat, fst::OLabelCompare<kaldi::CompactLatticeArc>());

kaldi::ConstArpaLmDeterministicFst const_arpa_fst(model_->const_arpa_);
kaldi::CompactLattice composed_clat;
kaldi::ComposeCompactLatticeDeterministic(determinized_lat, &const_arpa_fst, &composed_clat);
kaldi::Lattice composed_lat1;
ConvertLattice(composed_clat, &composed_lat1);
fst::Invert(&composed_lat1);
DeterminizeLattice(composed_lat1, &rlat);
CompactLattice tlat;
fst::ComposeDeterministicOnDemandFst<StdArc> combined_lm(lm_to_subtract_scale_, carpa_to_add_);
ComposeCompactLatticeDeterministic(clat, &combined_lm, &tlat);

if (rnnlm_to_add_scale_) {
ComposeLatticePrunedOptions compose_opts;
compose_opts.lattice_compose_beam = 3.0;
compose_opts.max_arcs = 3000;
TopSortCompactLatticeIfNeeded(&tlat);
fst::ComposeDeterministicOnDemandFst<StdArc> combined_rnnlm(carpa_to_add_scale_, rnnlm_to_add_scale_);
ComposeCompactLatticePruned(compose_opts, tlat,
&combined_rnnlm, &rlat);
rnnlm_to_add_->Clear();
} else {
rlat = tlat;
}

kaldi::Lattice slat;
ConvertLattice(rlat, &slat);
fst::Invert(&slat);
DeterminizeLattice(slat, &rlat);

} else {
rlat = clat;
}
Expand Down
15 changes: 9 additions & 6 deletions src/kaldi_recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,18 @@ class KaldiRecognizer {
OnlineBaseFeature *spk_feature_ = nullptr;

// Rescoring
fst::ArcMapFst<fst::StdArc, kaldi::LatticeArc, fst::StdToLatticeMapper<kaldi::BaseFloat> > *lm_fst_ = nullptr;
fst::BackoffDeterministicOnDemandFst<fst::StdArc> *lm_to_subtract_ = nullptr;
fst::ScaleDeterministicOnDemandFst *lm_to_subtract_scale_ = nullptr;
kaldi::ConstArpaLmDeterministicFst *carpa_to_add_ = nullptr;
fst::ScaleDeterministicOnDemandFst *carpa_to_add_scale_ = nullptr;

// RNNLM rescoring
kaldi::rnnlm::RnnlmComputeStateInfo *info = nullptr;
fst::ScaleDeterministicOnDemandFst *lm_to_subtract_det_scale = nullptr;
fst::BackoffDeterministicOnDemandFst<fst::StdArc> *lm_to_subtract_det_backoff = nullptr;
kaldi::rnnlm::KaldiRnnlmDeterministicFst* lm_to_add_orig = nullptr;
fst::DeterministicOnDemandFst<fst::StdArc> *lm_to_add = nullptr;
kaldi::rnnlm::KaldiRnnlmDeterministicFst* rnnlm_to_add_ = nullptr;
fst::DeterministicOnDemandFst<fst::StdArc> *rnnlm_to_add_scale_ = nullptr;
kaldi::rnnlm::RnnlmComputeStateInfo *rnnlm_info_ = nullptr;


// Other
int max_alternatives_ = 0; // Disable alternatives by default
bool words_ = false;

Expand Down
25 changes: 11 additions & 14 deletions src/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ void Model::ConfigureV1()
rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat";
rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf";
rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw";
rnnlm_lm_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst";
}

void Model::ConfigureV2()
Expand Down Expand Up @@ -203,7 +202,6 @@ void Model::ConfigureV2()
rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat";
rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf";
rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw";
rnnlm_lm_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst";
}

void Model::ReadDataFiles()
Expand Down Expand Up @@ -296,12 +294,19 @@ void Model::ReadDataFiles()
winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_);
}

if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) {

KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_;
graph_lm_fst_ = fst::ReadAndPrepareLmFst(std_fst_rxfilename_);
KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_;
ReadKaldiObject(carpa_rxfilename_, &const_arpa_);
}

// RNNLM Rescoring
if (stat(rnnlm_lm_rxfilename_.c_str(), &buffer) == 0) {
KALDI_LOG << "Loading RNNLM model from " << rnnlm_lm_rxfilename_;

ReadKaldiObject(rnnlm_lm_rxfilename_, &rnnlm);
rnnlm_lm_fst_ = fst::ReadAndPrepareLmFst(rnnlm_lm_fst_rxfilename_);
Matrix<BaseFloat> feature_embedding_mat;
ReadKaldiObject(rnnlm_feat_embedding_rxfilename_, &feature_embedding_mat);
SparseMatrix<BaseFloat> word_feature_mat;
Expand All @@ -319,17 +324,9 @@ void Model::ReadDataFiles()

ReadConfigFromFile(rnnlm_config_rxfilename_, &rnnlm_compute_opts);

} else if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) {

KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_;
std_lm_fst_ = fst::ReadFstKaldi(std_fst_rxfilename_);
fst::Project(std_lm_fst_, fst::ProjectType::OUTPUT);
if (std_lm_fst_->Properties(fst::kILabelSorted, true) == 0) {
fst::ILabelCompare<fst::StdArc> ilabel_comp;
fst::ArcSort(std_lm_fst_, ilabel_comp);
}
ReadKaldiObject(carpa_rxfilename_, &const_arpa_);
rnnlm_enabled_ = true;
}

}

void Model::Ref()
Expand Down Expand Up @@ -363,5 +360,5 @@ Model::~Model() {
delete hclg_fst_;
delete hcl_fst_;
delete g_fst_;
delete std_lm_fst_;
delete graph_lm_fst_;
}
5 changes: 2 additions & 3 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class Model {
string rnnlm_word_feats_rxfilename_;
string rnnlm_feat_embedding_rxfilename_;
string rnnlm_config_rxfilename_;
string rnnlm_lm_fst_rxfilename_;
string rnnlm_lm_rxfilename_;

kaldi::OnlineEndpointConfig endpoint_config_;
Expand All @@ -92,13 +91,13 @@ class Model {
fst::Fst<fst::StdArc> *hcl_fst_ = nullptr;
fst::Fst<fst::StdArc> *g_fst_ = nullptr;

fst::VectorFst<fst::StdArc> *std_lm_fst_ = nullptr;
fst::VectorFst<fst::StdArc> *graph_lm_fst_ = nullptr;
kaldi::ConstArpaLm const_arpa_;

kaldi::rnnlm::RnnlmComputeStateComputationOptions rnnlm_compute_opts;
CuMatrix<BaseFloat> word_embedding_mat;
fst::VectorFst<fst::StdArc> *rnnlm_lm_fst_ = NULL;
kaldi::nnet3::Nnet rnnlm;
bool rnnlm_enabled_ = false;

std::atomic<int> ref_cnt_;
};
Expand Down

0 comments on commit e2af710

Please sign in to comment.