From 63a8867bae77e302539358d57a48823847ad62aa Mon Sep 17 00:00:00 2001 From: uni-manjunath-ke <123362348+uni-manjunath-ke@users.noreply.github.com> Date: Fri, 25 Aug 2023 16:09:24 +0530 Subject: [PATCH 1/5] icefall feat extraction settings added (#465) --- sherpa/cpp_api/feature-config.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sherpa/cpp_api/feature-config.cc b/sherpa/cpp_api/feature-config.cc index 8aa895cb7..c6449acac 100644 --- a/sherpa/cpp_api/feature-config.cc +++ b/sherpa/cpp_api/feature-config.cc @@ -40,6 +40,12 @@ void FeatureConfig::Register(ParseOptions *po) { fbank_opts.mel_opts.num_bins = 80; RegisterMelBanksOptions(po, &fbank_opts.mel_opts); + fbank_opts.mel_opts.high_freq = -400; + fbank_opts.frame_opts.remove_dc_offset = true; + fbank_opts.frame_opts.round_to_power_of_two = true; + fbank_opts.energy_floor = 1e-10; + fbank_opts.frame_opts.snip_edges = false; + fbank_opts.frame_opts.samp_freq = 16000; po->Register("normalize-samples", &normalize_samples, "true to use samples in the range [-1, 1]. " "false to use samples in the range [-32768, 32767]. " From 1ad6943a1a812c4c1b1934b15c5090b18d3d8238 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 25 Aug 2023 11:40:06 +0100 Subject: [PATCH 2/5] Fix initial tokens for decoding. (#464) See also https://github.com/k2-fsa/icefall/issues/1206 and https://github.com/k2-fsa/icefall/pull/1208 --- sherpa/csrc/offline-transducer-greedy-search-decoder.cc | 9 +++++++-- .../offline-transducer-modified-beam-search-decoder.cc | 4 +++- sherpa/csrc/online-transducer-greedy-search-decoder.cc | 3 ++- .../online-transducer-modified-beam-search-decoder.cc | 8 ++++++-- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/sherpa/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa/csrc/offline-transducer-greedy-search-decoder.cc index 287c42f1a..9b1ca5f06 100644 --- a/sherpa/csrc/offline-transducer-greedy-search-decoder.cc +++ b/sherpa/csrc/offline-transducer-greedy-search-decoder.cc @@ -68,17 +68,22 @@ OfflineTransducerGreedySearchDecoder::Decode(torch::Tensor encoder_out, std::vector results(N); - std::vector padding(context_size, blank_id); + std::vector padding(context_size, -1); + padding.back() = blank_id; + for (auto &r : results) { // We will remove the padding at the end r.tokens = padding; } auto decoder_input = - torch::full({N, context_size}, blank_id, + torch::full({N, context_size}, -1, torch::dtype(torch::kLong) .memory_format(torch::MemoryFormat::Contiguous)); + // set the last column to blank_id, i.e., decoder_input[:, -1] = blank_id + decoder_input.index({torch::indexing::Slice(), -1}) = blank_id; + // its shape is (N, 1, joiner_dim) auto decoder_out = model_->RunDecoder(decoder_input.to(device)); diff --git a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc index 63004d879..f075ab289 100644 --- a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -100,7 +100,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( if (ss != nullptr) SHERPA_CHECK_EQ(batch_size, n); - std::vector blanks(context_size, blank_id); + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + Hypotheses blank_hyp({{blanks, 0}}); std::deque finalized; diff --git a/sherpa/csrc/online-transducer-greedy-search-decoder.cc b/sherpa/csrc/online-transducer-greedy-search-decoder.cc index 5ba1921db..a274bbd9e 100644 --- a/sherpa/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa/csrc/online-transducer-greedy-search-decoder.cc @@ -28,7 +28,8 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 OnlineTransducerDecoderResult r; - r.tokens.resize(context_size, blank_id); + r.tokens.resize(context_size, -1); + r.tokens.back() = blank_id; return r; } diff --git a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc index 61f72a55a..41d1a1c2d 100644 --- a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc @@ -66,7 +66,9 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 // - std::vector blanks(context_size, blank_id); + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + Hypotheses blank_hyp({{blanks, 0}}); OnlineTransducerDecoderResult r; @@ -121,7 +123,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( int32_t N = encoder_out.size(0); int32_t T = encoder_out.size(1); - if (ss) SHERPA_CHECK_EQ(N, num_streams); + if (ss) { + SHERPA_CHECK_EQ(N, num_streams); + } std::vector cur; cur.reserve(N); From cea1dbd4a1e8f75c1ada5965a8ae863f8e0c3eb3 Mon Sep 17 00:00:00 2001 From: Shayne Mei <31003116+shaynemei@users.noreply.github.com> Date: Sat, 26 Aug 2023 23:07:00 -0700 Subject: [PATCH 3/5] Add softmax temperature to online/offline recognizer (#466) --- get_version.py | 3 ++- setup.py | 1 - sherpa/bin/offline_transducer_asr.py | 11 +++++++++++ sherpa/bin/offline_transducer_server.py | 10 ++++++++++ sherpa/bin/online_transducer_asr.py | 10 ++++++++++ sherpa/bin/streaming_server.py | 10 ++++++++++ sherpa/cpp_api/offline-recognizer.cc | 7 ++++++- sherpa/cpp_api/offline-recognizer.h | 3 +++ sherpa/cpp_api/online-recognizer.cc | 9 +++++++-- sherpa/cpp_api/online-recognizer.h | 3 +++ ...offline-transducer-modified-beam-search-decoder.cc | 2 +- .../offline-transducer-modified-beam-search-decoder.h | 8 ++++---- .../online-transducer-modified-beam-search-decoder.cc | 2 +- .../online-transducer-modified-beam-search-decoder.h | 5 +++-- sherpa/python/csrc/offline-recognizer.cc | 4 ++++ sherpa/python/csrc/online-recognizer.cc | 7 ++++++- 16 files changed, 81 insertions(+), 14 deletions(-) diff --git a/get_version.py b/get_version.py index fd112f358..15bd8ae1f 100755 --- a/get_version.py +++ b/get_version.py @@ -95,7 +95,8 @@ def get_package_version(): if not is_stable(): dt = datetime.datetime.utcnow() package_version = ( - f"{latest_version}.dev{dt.year}{dt.month:02d}{dt.day:02d}{local_version}" + f"{latest_version}.dev{dt.year}{dt.month:02d}{dt.day:02d}" + f"{local_version}" ) else: package_version = f"{latest_version}" diff --git a/setup.py b/setup.py index 702848c2d..c9340e30d 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import os -import re import sys from pathlib import Path diff --git a/sherpa/bin/offline_transducer_asr.py b/sherpa/bin/offline_transducer_asr.py index 62ba3bd62..6c3a0d279 100755 --- a/sherpa/bin/offline_transducer_asr.py +++ b/sherpa/bin/offline_transducer_asr.py @@ -225,6 +225,15 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): """, ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="""Used only when --decoding-method is modified_beam_search. + It specifies the softmax temperature. + """, + ) + def add_fast_beam_search_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -323,6 +332,7 @@ def check_args(args): if args.decoding_method == "modified_beam_search": assert args.num_active_paths > 0, args.num_active_paths + assert args.temperature > 0, args.temperature if args.decoding_method == "fast_beam_search" and args.LG: if not Path(args.LG).is_file(): @@ -406,6 +416,7 @@ def create_recognizer(args) -> sherpa.OfflineRecognizer: feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, + temperature=args.temperature, ) recognizer = sherpa.OfflineRecognizer(config) diff --git a/sherpa/bin/offline_transducer_server.py b/sherpa/bin/offline_transducer_server.py index 02959a230..58b96b27a 100755 --- a/sherpa/bin/offline_transducer_server.py +++ b/sherpa/bin/offline_transducer_server.py @@ -117,6 +117,14 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): It specifies number of active paths to keep during decoding. """, ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="""Used only when --decoding-method is modified_beam_search. + It specifies the softmax temperature. + """, + ) def add_fast_beam_search_args(parser: argparse.ArgumentParser): @@ -210,6 +218,7 @@ def check_args(args): if args.decoding_method == "modified_beam_search": assert args.num_active_paths > 0, args.num_active_paths + assert args.temperature > 0, args.temperature if args.decoding_method == "fast_beam_search" and args.LG: if not Path(args.LG).is_file(): @@ -639,6 +648,7 @@ def create_recognizer(args) -> sherpa.OfflineRecognizer: feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, + temperature=args.temperature ) recognizer = sherpa.OfflineRecognizer(config) diff --git a/sherpa/bin/online_transducer_asr.py b/sherpa/bin/online_transducer_asr.py index a00920c07..80a356871 100755 --- a/sherpa/bin/online_transducer_asr.py +++ b/sherpa/bin/online_transducer_asr.py @@ -214,6 +214,14 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): Used only when --decoding-method=modified_beam_search """, ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="""Used only when --decoding-method is modified_beam_search. + It specifies the softmax temperature. + """, + ) def add_fast_beam_search_args(parser: argparse.ArgumentParser): @@ -313,6 +321,7 @@ def check_args(args): if args.decoding_method == "modified_beam_search": assert args.num_active_paths > 0, args.num_active_paths + assert args.temperature > 0, args.temperature if args.decoding_method == "fast_beam_search" and args.LG: if not Path(args.LG).is_file(): @@ -396,6 +405,7 @@ def create_recognizer(args) -> sherpa.OnlineRecognizer: feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, + temperature=args.temperature ) recognizer = sherpa.OnlineRecognizer(config) diff --git a/sherpa/bin/streaming_server.py b/sherpa/bin/streaming_server.py index 9ce00cc19..d81947c9e 100755 --- a/sherpa/bin/streaming_server.py +++ b/sherpa/bin/streaming_server.py @@ -177,6 +177,14 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): It specifies number of active paths to keep during decoding. """, ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="""Used only when --decoding-method is modified_beam_search. + It specifies the softmax temperature. + """, + ) def add_endpointing_args(parser: argparse.ArgumentParser): @@ -405,6 +413,7 @@ def create_recognizer(args) -> sherpa.OnlineRecognizer: tokens=args.tokens, use_gpu=args.use_gpu, num_active_paths=args.num_active_paths, + temperature=args.temperature, feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, @@ -768,6 +777,7 @@ def check_args(args): if args.decoding_method == "modified_beam_search": assert args.num_active_paths > 0, args.num_active_paths + assert args.temperature > 0, args.temperature if args.decoding_method == "fast_beam_search" and args.LG: if not Path(args.LG).is_file(): diff --git a/sherpa/cpp_api/offline-recognizer.cc b/sherpa/cpp_api/offline-recognizer.cc index 9b299d871..9c9798bff 100644 --- a/sherpa/cpp_api/offline-recognizer.cc +++ b/sherpa/cpp_api/offline-recognizer.cc @@ -109,6 +109,10 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { po->Register("context-score", &context_score, "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); + + po->Register("temperature", &temperature, + "Softmax temperature,. " + "Used only when decoding_method is modified_beam_search."); } void OfflineRecognizerConfig::Validate() const { @@ -150,7 +154,8 @@ std::string OfflineRecognizerConfig::ToString() const { os << "use_gpu=" << (use_gpu ? "True" : "False") << ", "; os << "decoding_method=\"" << decoding_method << "\", "; os << "num_active_paths=" << num_active_paths << ", "; - os << "context_score=" << context_score << ")"; + os << "context_score=" << context_score << ", "; + os << "temperature=" << temperature << ")"; return os.str(); } diff --git a/sherpa/cpp_api/offline-recognizer.h b/sherpa/cpp_api/offline-recognizer.h index 373f25d6c..8bff2a9ef 100644 --- a/sherpa/cpp_api/offline-recognizer.h +++ b/sherpa/cpp_api/offline-recognizer.h @@ -67,6 +67,9 @@ struct OfflineRecognizerConfig { /// used only for modified_beam_search float context_score = 1.5; + // temperature for the softmax in the joiner + float temperature = 1.0; + void Register(ParseOptions *po); void Validate() const; diff --git a/sherpa/cpp_api/online-recognizer.cc b/sherpa/cpp_api/online-recognizer.cc index d40d7406f..abf1a6941 100644 --- a/sherpa/cpp_api/online-recognizer.cc +++ b/sherpa/cpp_api/online-recognizer.cc @@ -113,6 +113,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "and streaming Zipformer, i.e, models from " "pruned_transducer_stateless7_streaming in icefall." "Number of frames before subsampling during decoding."); + + po->Register("temperature", &temperature, + "Softmax temperature,. " + "Used only when decoding_method is modified_beam_search."); } void OnlineRecognizerConfig::Validate() const { @@ -172,7 +176,8 @@ std::string OnlineRecognizerConfig::ToString() const { os << "context_score=" << context_score << ", "; os << "left_context=" << left_context << ", "; os << "right_context=" << right_context << ", "; - os << "chunk_size=" << chunk_size << ")"; + os << "chunk_size=" << chunk_size << ", "; + os << "temperature=" << temperature << ")"; return os.str(); } @@ -296,7 +301,7 @@ class OnlineRecognizer::OnlineRecognizerImpl { std::make_unique(model_.get()); } else if (config.decoding_method == "modified_beam_search") { decoder_ = std::make_unique( - model_.get(), config.num_active_paths); + model_.get(), config.num_active_paths, config.temperature); } else if (config.decoding_method == "fast_beam_search") { config.fast_beam_search_config.Validate(); diff --git a/sherpa/cpp_api/online-recognizer.h b/sherpa/cpp_api/online-recognizer.h index 54c4b6fbb..fe9c277c0 100644 --- a/sherpa/cpp_api/online-recognizer.h +++ b/sherpa/cpp_api/online-recognizer.h @@ -69,6 +69,9 @@ struct OnlineRecognizerConfig { // In number of frames after subsampling int32_t chunk_size = 12; + // temperature for the softmax in the joiner + float temperature = 1.0; + void Register(ParseOptions *po); void Validate() const; diff --git a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc index f075ab289..09c30ebf4 100644 --- a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -186,7 +186,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( logits = logits.squeeze(1).squeeze(1); // now logits' shape is (num_hyps, vocab_size) - auto log_probs = logits.log_softmax(-1).cpu(); + auto log_probs = (logits / temperature_).log_softmax(-1).cpu(); log_probs.add_(ys_log_probs); diff --git a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h index e4ab3d6c1..a6141b71a 100644 --- a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h +++ b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.h @@ -15,9 +15,9 @@ namespace sherpa { class OfflineTransducerModifiedBeamSearchDecoder : public OfflineTransducerDecoder { public: - OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, - int32_t num_active_paths) - : model_(model), num_active_paths_(num_active_paths) {} + OfflineTransducerModifiedBeamSearchDecoder( + OfflineTransducerModel *model, int32_t num_active_paths, float temperature) + : model_(model), num_active_paths_(num_active_paths), temperature_(temperature) {} /** Run modified beam search given the output from the encoder model. * @@ -35,8 +35,8 @@ class OfflineTransducerModifiedBeamSearchDecoder private: OfflineTransducerModel *model_; // Not owned - int32_t num_active_paths_; + float temperature_ = 1.0; }; } // namespace sherpa diff --git a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc index 41d1a1c2d..f8b19b93d 100644 --- a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc @@ -173,7 +173,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( auto logits = model_->RunJoiner(cur_encoder_out, decoder_out); // logits has shape (num_hyps, vocab_size) - auto log_probs = logits.log_softmax(-1).cpu(); + auto log_probs = (logits / temperature_).log_softmax(-1).cpu(); log_probs.add_(ys_log_probs); diff --git a/sherpa/csrc/online-transducer-modified-beam-search-decoder.h b/sherpa/csrc/online-transducer-modified-beam-search-decoder.h index b13456563..3927e002c 100644 --- a/sherpa/csrc/online-transducer-modified-beam-search-decoder.h +++ b/sherpa/csrc/online-transducer-modified-beam-search-decoder.h @@ -15,8 +15,8 @@ class OnlineTransducerModifiedBeamSearchDecoder : public OnlineTransducerDecoder { public: explicit OnlineTransducerModifiedBeamSearchDecoder( - OnlineTransducerModel *model, int32_t num_active_paths) - : model_(model), num_active_paths_(num_active_paths) {} + OnlineTransducerModel *model, int32_t num_active_paths, float temperature) + : model_(model), num_active_paths_(num_active_paths), temperature_(temperature) {} OnlineTransducerDecoderResult GetEmptyResult() override; @@ -34,6 +34,7 @@ class OnlineTransducerModifiedBeamSearchDecoder private: OnlineTransducerModel *model_; // Not owned int32_t num_active_paths_; + float temperature_ = 1.0; }; } // namespace sherpa diff --git a/sherpa/python/csrc/offline-recognizer.cc b/sherpa/python/csrc/offline-recognizer.cc index d75e77f60..6a77a434c 100644 --- a/sherpa/python/csrc/offline-recognizer.cc +++ b/sherpa/python/csrc/offline-recognizer.cc @@ -155,6 +155,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) { // NOLINT .def(py::init([](const std::string &nn_model, const std::string &tokens, bool use_gpu = false, int32_t num_active_paths = 4, float context_score = 1.5, + float temperature = 1.0, const OfflineCtcDecoderConfig &ctc_decoder_config = {}, const FeatureConfig &feat_config = {}, const FastBeamSearchConfig &fast_beam_search_config = {}, @@ -171,11 +172,13 @@ static void PybindOfflineRecognizerConfig(py::module &m) { // NOLINT config->decoding_method = decoding_method; config->num_active_paths = num_active_paths; config->context_score = context_score; + config->temperature = temperature; return config; }), py::arg("nn_model"), py::arg("tokens"), py::arg("use_gpu") = false, py::arg("num_active_paths") = 4, py::arg("context_score") = 1.5, + py::arg("temperature") = 1.0, py::arg("ctc_decoder_config") = OfflineCtcDecoderConfig(), py::arg("feat_config") = FeatureConfig(), py::arg("fast_beam_search_config") = FastBeamSearchConfig(), @@ -193,6 +196,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) { // NOLINT .def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("num_active_paths", &PyClass::num_active_paths) .def_readwrite("context_score", &PyClass::context_score) + .def_readwrite("temperature", &PyClass::temperature) .def("validate", &PyClass::Validate); } diff --git a/sherpa/python/csrc/online-recognizer.cc b/sherpa/python/csrc/online-recognizer.cc index 598a12072..f4a9c3652 100644 --- a/sherpa/python/csrc/online-recognizer.cc +++ b/sherpa/python/csrc/online-recognizer.cc @@ -23,6 +23,7 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT int32_t num_active_paths = 4, float context_score = 1.5, int32_t left_context = 64, int32_t right_context = 0, int32_t chunk_size = 16, + float temperature = 1.0, const FeatureConfig &feat_config = {}, const EndpointConfig &endpoint_config = {}, const FastBeamSearchConfig &fast_beam_search_config = {}) @@ -45,6 +46,7 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT ans->left_context = left_context; ans->right_context = right_context; ans->chunk_size = chunk_size; + ans->temperature = temperature; return ans; }), @@ -55,9 +57,11 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT py::arg("decoding_method") = "greedy_search", py::arg("num_active_paths") = 4, py::arg("context_score") = 1.5, py::arg("left_context") = 64, py::arg("right_context") = 0, - py::arg("chunk_size") = 16, py::arg("feat_config") = FeatureConfig(), + py::arg("chunk_size") = 16, py::arg("temperature") = 1.0, + py::arg("feat_config") = FeatureConfig(), py::arg("endpoint_config") = EndpointConfig(), py::arg("fast_beam_search_config") = FastBeamSearchConfig()) + .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) .def_readwrite("fast_beam_search_config", @@ -75,6 +79,7 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT .def_readwrite("left_context", &PyClass::left_context) .def_readwrite("right_context", &PyClass::right_context) .def_readwrite("chunk_size", &PyClass::chunk_size) + .def_readwrite("temperature", &PyClass::temperature) .def("validate", &PyClass::Validate) .def("__str__", [](const PyClass &self) -> std::string { return self.ToString(); }); From 4369bd71325debe8752c48cabef2e816811a56f8 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 28 Aug 2023 15:41:10 +0800 Subject: [PATCH 4/5] Fix compile error (#467) --- sherpa/cpp_api/offline-recognizer-transducer-impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa/cpp_api/offline-recognizer-transducer-impl.h b/sherpa/cpp_api/offline-recognizer-transducer-impl.h index fbb13998d..b44582c5f 100644 --- a/sherpa/cpp_api/offline-recognizer-transducer-impl.h +++ b/sherpa/cpp_api/offline-recognizer-transducer-impl.h @@ -69,7 +69,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { std::make_unique(model_.get()); } else if (config.decoding_method == "modified_beam_search") { decoder_ = std::make_unique( - model_.get(), config.num_active_paths); + model_.get(), config.num_active_paths, config.temperature); } else if (config.decoding_method == "fast_beam_search") { config.fast_beam_search_config.Validate(); From 96de89a13e40c0b62de580fc6c8241050c51638a Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 28 Aug 2023 19:37:52 +0800 Subject: [PATCH 5/5] Fix context graph bug (#468) --- sherpa/csrc/context-graph.cc | 17 ++++------------- sherpa/csrc/context-graph.h | 6 +++--- sherpa/csrc/test-context-graph.cc | 7 ++++--- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/sherpa/csrc/context-graph.cc b/sherpa/csrc/context-graph.cc index f5f67a61b..8e71a8f1b 100644 --- a/sherpa/csrc/context-graph.cc +++ b/sherpa/csrc/context-graph.cc @@ -33,7 +33,7 @@ void ContextGraph::Build( bool is_end = j == (static_cast(token_ids[i].size()) - 1); node->next[token] = std::make_unique( token, context_score_, node->node_score + context_score_, - is_end ? 0 : node->local_node_score + context_score_, is_end); + is_end ? node->node_score + context_score_ : 0, is_end); } node = node->next[token].get(); } @@ -48,7 +48,6 @@ std::pair ContextGraph::ForwardOneStep( if (1 == state->next.count(token)) { node = state->next.at(token).get(); score = node->token_score; - if (state->is_end) score += state->node_score; } else { node = state->fail; while (0 == node->next.count(token)) { @@ -58,24 +57,15 @@ std::pair ContextGraph::ForwardOneStep( if (1 == node->next.count(token)) { node = node->next.at(token).get(); } - score = node->node_score - state->local_node_score; + score = node->node_score - state->node_score; } SHERPA_CHECK(nullptr != node); - float matched_score = 0; - auto output = node->output; - while (nullptr != output) { - matched_score += output->node_score; - output = output->output; - } - return std::make_pair(score + matched_score, node); + return std::make_pair(score + node->output_score, node); } std::pair ContextGraph::Finalize( const ContextState *state) const { float score = -state->node_score; - if (state->is_end) { - score = 0; - } return std::make_pair(score, root_.get()); } @@ -112,6 +102,7 @@ void ContextGraph::FillFailOutput() const { } } kv.second->output = output; + kv.second->output_score += output == nullptr ? 0 : output->output_score; node_queue.push(kv.second.get()); } } diff --git a/sherpa/csrc/context-graph.h b/sherpa/csrc/context-graph.h index 18e502136..ab591b0c7 100644 --- a/sherpa/csrc/context-graph.h +++ b/sherpa/csrc/context-graph.h @@ -35,7 +35,7 @@ struct ContextState { int32_t token; float token_score; float node_score; - float local_node_score; + float output_score; bool is_end; std::unordered_map> next; const ContextState *fail = nullptr; @@ -43,11 +43,11 @@ struct ContextState { ContextState() = default; ContextState(int32_t token, float token_score, float node_score, - float local_node_score, bool is_end) + float output_score, bool is_end) : token(token), token_score(token_score), node_score(node_score), - local_node_score(local_node_score), + output_score(output_score), is_end(is_end) {} }; diff --git a/sherpa/csrc/test-context-graph.cc b/sherpa/csrc/test-context-graph.cc index d1f820027..00b18903e 100644 --- a/sherpa/csrc/test-context-graph.cc +++ b/sherpa/csrc/test-context-graph.cc @@ -29,14 +29,15 @@ TEST(ContextGraph, TestBasic) { std::vector contexts_str( {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); std::vector> contexts; - for (int32_t i = 0; i < contexts_str.size(); ++i) { + for (size_t i = 0; i < contexts_str.size(); ++i) { contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); } auto context_graph = ContextGraph(contexts, 1); auto queries = std::map{ - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, - {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; for (const auto &iter : queries) { float total_scores = 0;