diff --git a/sherpa/bin/offline_transducer_asr.py b/sherpa/bin/offline_transducer_asr.py index 6c3a0d279..58650bdf8 100755 --- a/sherpa/bin/offline_transducer_asr.py +++ b/sherpa/bin/offline_transducer_asr.py @@ -155,6 +155,13 @@ def add_model_args(parser: argparse.ArgumentParser): help="Feature dimension of the model", ) + parser.add_argument( + "--use-bbpe", + type=str2bool, + default=False, + help="Whether the model to be used is trained with bbpe", + ) + def add_decoding_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -413,6 +420,7 @@ def create_recognizer(args) -> sherpa.OfflineRecognizer: use_gpu=args.use_gpu, num_active_paths=args.num_active_paths, context_score=args.context_score, + use_bbpe=args.use_bbpe, feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, diff --git a/sherpa/bin/offline_transducer_server.py b/sherpa/bin/offline_transducer_server.py index 58b96b27a..aae180c75 100755 --- a/sherpa/bin/offline_transducer_server.py +++ b/sherpa/bin/offline_transducer_server.py @@ -91,6 +91,13 @@ def add_model_args(parser: argparse.ArgumentParser): help="Feature dimension of the model", ) + parser.add_argument( + "--use-bbpe", + type=sherpa.str2bool, + default=False, + help="Whether the model to be used is trained with bbpe", + ) + def add_decoding_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -645,6 +652,7 @@ def create_recognizer(args) -> sherpa.OfflineRecognizer: tokens=args.tokens, use_gpu=args.use_gpu, num_active_paths=args.num_active_paths, + use_bbpe=args.use_bbpe, feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, diff --git a/sherpa/bin/online_transducer_asr.py b/sherpa/bin/online_transducer_asr.py index 80a356871..145f459cc 100755 --- a/sherpa/bin/online_transducer_asr.py +++ b/sherpa/bin/online_transducer_asr.py @@ -144,6 +144,13 @@ def add_model_args(parser: argparse.ArgumentParser): help="Feature dimension of the model", ) + parser.add_argument( + "--use-bbpe", + type=str2bool, + default=False, + help="Whether the model to be used is trained with bbpe", + ) + def add_decoding_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -402,6 +409,7 @@ def create_recognizer(args) -> sherpa.OnlineRecognizer: use_gpu=args.use_gpu, num_active_paths=args.num_active_paths, context_score=args.context_score, + use_bbpe=args.use_bbpe, feat_config=feat_config, decoding_method=args.decoding_method, fast_beam_search_config=fast_beam_search_config, diff --git a/sherpa/bin/streaming_server.py b/sherpa/bin/streaming_server.py index 73ec28a55..1398d1593 100755 --- a/sherpa/bin/streaming_server.py +++ b/sherpa/bin/streaming_server.py @@ -151,6 +151,13 @@ def add_model_args(parser: argparse.ArgumentParser): help="Feature dimension of the model", ) + parser.add_argument( + "--use-bbpe", + type=sherpa.str2bool, + default=False, + help="Whether the model to be used is trained with bbpe", + ) + def add_decoding_args(parser: argparse.ArgumentParser): parser.add_argument( @@ -413,6 +420,7 @@ def create_recognizer(args) -> sherpa.OnlineRecognizer: tokens=args.tokens, use_gpu=args.use_gpu, num_active_paths=args.num_active_paths, + use_bbpe=args.use_bbpe, temperature=args.temperature, feat_config=feat_config, decoding_method=args.decoding_method, diff --git a/sherpa/cpp_api/feature-config.cc b/sherpa/cpp_api/feature-config.cc index 8aa895cb7..c6449acac 100644 --- a/sherpa/cpp_api/feature-config.cc +++ b/sherpa/cpp_api/feature-config.cc @@ -40,6 
+40,12 @@ void FeatureConfig::Register(ParseOptions *po) { fbank_opts.mel_opts.num_bins = 80; RegisterMelBanksOptions(po, &fbank_opts.mel_opts); + fbank_opts.mel_opts.high_freq = -400; + fbank_opts.frame_opts.remove_dc_offset = true; + fbank_opts.frame_opts.round_to_power_of_two = true; + fbank_opts.energy_floor = 1e-10; + fbank_opts.frame_opts.snip_edges = false; + fbank_opts.frame_opts.samp_freq = 16000; po->Register("normalize-samples", &normalize_samples, "true to use samples in the range [-1, 1]. " "false to use samples in the range [-32768, 32767]. " diff --git a/sherpa/cpp_api/offline-recognizer-transducer-impl.h b/sherpa/cpp_api/offline-recognizer-transducer-impl.h index fbb13998d..f7cff07c2 100644 --- a/sherpa/cpp_api/offline-recognizer-transducer-impl.h +++ b/sherpa/cpp_api/offline-recognizer-transducer-impl.h @@ -12,6 +12,7 @@ #include "sherpa/cpp_api/feature-config.h" #include "sherpa/cpp_api/offline-recognizer-impl.h" +#include "sherpa/csrc/byte_util.h" #include "sherpa/csrc/context-graph.h" #include "sherpa/csrc/offline-conformer-transducer-model.h" #include "sherpa/csrc/offline-transducer-decoder.h" @@ -25,7 +26,7 @@ namespace sherpa { static OfflineRecognitionResult Convert( const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table, - int32_t frame_shift_ms, int32_t subsampling_factor) { + int32_t frame_shift_ms, int32_t subsampling_factor, bool use_bbpe) { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.timestamps.size()); @@ -37,6 +38,12 @@ static OfflineRecognitionResult Convert( r.tokens.push_back(std::move(sym)); } + + if (use_bbpe) { + auto bu = GetByteUtil(); + text = bu->Decode(text); + } + r.text = std::move(text); float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; @@ -69,7 +76,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { std::make_unique(model_.get()); } else if (config.decoding_method == "modified_beam_search") { decoder_ = std::make_unique( - model_.get(), config.num_active_paths); + model_.get(), config.num_active_paths, config.temperature); } else if (config.decoding_method == "fast_beam_search") { config.fast_beam_search_config.Validate(); @@ -133,7 +140,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { auto ans = Convert(results[i], symbol_table_, config_.feat_config.fbank_opts.frame_opts.frame_shift_ms, - model_->SubsamplingFactor()); + model_->SubsamplingFactor(), config_.use_bbpe); ss[i]->SetResult(ans); } diff --git a/sherpa/cpp_api/offline-recognizer.cc b/sherpa/cpp_api/offline-recognizer.cc index 9c9798bff..039e9ae4b 100644 --- a/sherpa/cpp_api/offline-recognizer.cc +++ b/sherpa/cpp_api/offline-recognizer.cc @@ -110,6 +110,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); + po->Register("use-bbpe", &use_bbpe, + "true if the model to use is trained with byte level bpe, " + "The byte level bpe modeling unit is mainly used on CJK " + "languages or multilingual datasets, it can further break " + "the multi-byte unicode characters into byte sequence and " + "then train some kind of sub-char bpes."); + po->Register("temperature", &temperature, "Softmax temperature,. 
" "Used only when decoding_method is modified_beam_search."); @@ -155,6 +162,7 @@ std::string OfflineRecognizerConfig::ToString() const { os << "decoding_method=\"" << decoding_method << "\", "; os << "num_active_paths=" << num_active_paths << ", "; os << "context_score=" << context_score << ", "; + os << "use_bbpe=" << (use_bbpe ? "True" : "False") << ", "; os << "temperature=" << temperature << ")"; return os.str(); diff --git a/sherpa/cpp_api/offline-recognizer.h b/sherpa/cpp_api/offline-recognizer.h index 8bff2a9ef..6fe7ccb4e 100644 --- a/sherpa/cpp_api/offline-recognizer.h +++ b/sherpa/cpp_api/offline-recognizer.h @@ -67,6 +67,9 @@ struct OfflineRecognizerConfig { /// used only for modified_beam_search float context_score = 1.5; + // True if the model used is trained with byte level bpe. + bool use_bbpe = false; + // temperature for the softmax in the joiner float temperature = 1.0; diff --git a/sherpa/cpp_api/online-recognizer.cc b/sherpa/cpp_api/online-recognizer.cc index abf1a6941..eedbea5a6 100644 --- a/sherpa/cpp_api/online-recognizer.cc +++ b/sherpa/cpp_api/online-recognizer.cc @@ -9,6 +9,7 @@ #include #include "nlohmann/json.hpp" +#include "sherpa/csrc/byte_util.h" #include "sherpa/csrc/file-utils.h" #include "sherpa/csrc/log.h" #include "sherpa/csrc/online-conformer-transducer-model.h" @@ -114,6 +115,13 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "pruned_transducer_stateless7_streaming in icefall." "Number of frames before subsampling during decoding."); + po->Register("use-bbpe", &use_bbpe, + "true if the model to use is trained with byte level bpe, " + "The byte level bpe modeling unit is mainly used on CJK " + "languages or multilingual datasets, it can further break " + "the multi-byte unicode characters into byte sequence and " + "then train some kind of sub-char bpes."); + po->Register("temperature", &temperature, "Softmax temperature,. " "Used only when decoding_method is modified_beam_search."); @@ -177,6 +185,7 @@ std::string OnlineRecognizerConfig::ToString() const { os << "left_context=" << left_context << ", "; os << "right_context=" << right_context << ", "; os << "chunk_size=" << chunk_size << ", "; + os << "use_bbpe=" << (use_bbpe ? "True" : "False") << ", "; os << "temperature=" << temperature << ")"; return os.str(); } @@ -184,7 +193,8 @@ std::string OnlineRecognizerConfig::ToString() const { static OnlineRecognitionResult Convert(const OnlineTransducerDecoderResult &src, const SymbolTable &sym_table, int32_t frame_shift_ms, - int32_t subsampling_factor) { + int32_t subsampling_factor, + bool use_bbpe) { OnlineRecognitionResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.timestamps.size()); @@ -196,6 +206,12 @@ static OnlineRecognitionResult Convert(const OnlineTransducerDecoderResult &src, r.tokens.push_back(std::move(sym)); } + + if (use_bbpe) { + auto bu = GetByteUtil(); + text = bu->Decode(text); + } + r.text = std::move(text); float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; @@ -440,7 +456,7 @@ class OnlineRecognizer::OnlineRecognizerImpl { auto ans = Convert(r, symbol_table_, config_.feat_config.fbank_opts.frame_opts.frame_shift_ms, - model_->SubsamplingFactor()); + model_->SubsamplingFactor(), config_.use_bbpe); ans.is_final = is_final; ans.segment = s->GetWavSegment(); diff --git a/sherpa/cpp_api/online-recognizer.h b/sherpa/cpp_api/online-recognizer.h index fe9c277c0..0423fbb03 100644 --- a/sherpa/cpp_api/online-recognizer.h +++ b/sherpa/cpp_api/online-recognizer.h @@ -69,6 +69,9 @@ struct OnlineRecognizerConfig { // In number of frames after subsampling int32_t chunk_size = 12; + // True if the model used is trained with byte level bpe. + bool use_bbpe = false; + // temperature for the softmax in the joiner float temperature = 1.0; diff --git a/sherpa/csrc/CMakeLists.txt b/sherpa/csrc/CMakeLists.txt index 4e1fb2230..9e65119f2 100644 --- a/sherpa/csrc/CMakeLists.txt +++ b/sherpa/csrc/CMakeLists.txt @@ -1,5 +1,6 @@ # Please sort the filenames alphabetically set(sherpa_srcs + byte_util.cc context-graph.cc fbank-features.cc file-utils.cc @@ -66,6 +67,7 @@ if(SHERPA_ENABLE_TESTS) # test-offline-conformer-transducer-model.cc # test-online-conv-emformer-transducer-model.cc + test-byte-util.cc test-context-graph.cc test-hypothesis.cc test-log.cc diff --git a/sherpa/csrc/byte_util.cc b/sherpa/csrc/byte_util.cc new file mode 100644 index 000000000..98560db20 --- /dev/null +++ b/sherpa/csrc/byte_util.cc @@ -0,0 +1,221 @@ +/** Copyright 2023 Xiaomi Corporation (authors: Wei Kang) + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa/csrc/byte_util.h" + +#include // NOLINT +#include + +#include "sherpa/csrc/log.h" + +namespace sherpa { + +ByteUtil::ByteUtil() { + // The table below is copied from + // https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py + // which is used to train byte level bpe, if you change the table in icefall + // you have to change the table below accordingly. 
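+  //
+  // Reading the table below by byte value: bytes 0x00-0x1f map to code
+  // points 256-287, printable ascii bytes 0x20-0x7e map to themselves, and
+  // bytes 0x7f-0xff map to code points between 288 and 422 (with a few gaps).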
+ byte2token_ = std::vector( + {256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, + 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, + 284, 285, 286, 287, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, + 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, + 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 126, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, + 301, 302, 303, 304, 305, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 321, 322, 323, 324, 325, 326, 327, 328, 330, 331, 332, 333, + 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, + 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, + 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, + 376, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, + 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, + 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, + 419, 420, 421, 422}); + max_token_ = 422; // the max number in above table + token2byte_ = + std::vector(max_token_ + 1, -1); // the max token in byte2token_ + // is 422, so we set the length + // of token2bytes_ 423. + for (size_t i = 0; i < byte2token_.size(); ++i) { + token2byte_[byte2token_[i]] = i; + } +} + +std::string ByteUtil::Encode(const std::string &str) const { + std::ostringstream oss; + const uint8_t *p = reinterpret_cast(str.data()); + for (size_t i = 0; i < str.size(); ++i) { + oss << CodePointToUTF8String(byte2token_[p[i]]); + } + return oss.str(); +} + +std::string ByteUtil::Decode(const std::string &str) const { + std::vector bytes; + UTF8StringToTokensAndMapToBytes(str, &bytes); + std::vector codes; + BytesToCodePoints(bytes.data(), bytes.size(), &codes); + std::ostringstream oss; + for (size_t i = 0; i < codes.size(); ++i) { + oss << CodePointToUTF8String(codes[i]); + } + return oss.str(); +} + +void ByteUtil::UTF8StringToTokensAndMapToBytes( + const std::string &str, std::vector *bytes) const { + const char *data = str.data(); + bytes->clear(); + const size_t length = str.size(); + for (size_t i = 0; i < length; /* no update */) { + int32_t c = data[i++] & 0xff; + if ((c & 0x80) == 0) { + if (c > max_token_ || token2byte_[c] == -1) { + SHERPA_LOG(WARNING) << "Skip OOV token, code point : " << c + << " utf8 char : " << CodePointToUTF8String(c); + continue; + } + bytes->push_back(token2byte_[c]); + } else { + if ((c & 0xc0) == 0x80) { + SHERPA_LOG(FATAL) << "Invalid utf8 string : " << str + << ", code point : " << c; + } + int32_t count = + (c >= 0xc0) + (c >= 0xe0) + (c >= 0xf0) + (c >= 0xf8) + (c >= 0xfc); + int32_t code = c & ((1 << (6 - count)) - 1); + while (count != 0) { + if (i == length) { + SHERPA_LOG(FATAL) + << "Invalid utf8 string : " << str << ", code point : " << code; + } + char cb = data[i++]; + if ((cb & 0xc0) != 0x80) { + SHERPA_LOG(FATAL) + << "Invalid utf8 string : " << str << ", code point : " << code; + } + code = (code << 6) | (cb & 0x3f); + count--; + } + if (code < 0) { + // This should not be able to happen. 
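+        // (Even for a malformed 6-byte sequence the accumulation above tops
+        // out at 2^31 - 1, so code cannot actually become negative; this
+        // branch is purely defensive.)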
+ SHERPA_LOG(FATAL) << "Invalid utf8 string : " << str + << ", code point : " << code; + } + if (code > max_token_ || token2byte_[code] == -1) { + SHERPA_LOG(WARNING) << "Skip OOV token, code point : " << code + << " utf8 char : " << CodePointToUTF8String(code); + continue; + } + bytes->push_back(token2byte_[code]); + } + } +} + +void ByteUtil::BytesToCodePoints(const uint8_t *bytes, int32_t length, + std::vector *codes) const { + if (length <= 0) { + return; + } + const char *data = reinterpret_cast(bytes); + int32_t idx = 1; // means starting from the next byte + for (int32_t i = 0; i < length; /* no update */) { + int32_t c = data[i++] & 0xff; + if ((c & 0x80) == 0) { + codes->push_back(c); + idx = i + 1; + } else { + if ((c & 0xc0) == 0x80) { + BytesToCodePoints(bytes + idx, length - idx, codes); + return; + } + int32_t count = + (c >= 0xc0) + (c >= 0xe0) + (c >= 0xf0) + (c >= 0xf8) + (c >= 0xfc); + int32_t code = c & ((1 << (6 - count)) - 1); + while (count != 0) { + if (i == length) { + BytesToCodePoints(bytes + idx, length - idx, codes); + return; + } + char cb = data[i++]; + if ((cb & 0xc0) != 0x80) { + BytesToCodePoints(bytes + idx, length - idx, codes); + return; + } + code = (code << 6) | (cb & 0x3f); + count--; + } + if (code < 0) { + BytesToCodePoints(bytes + idx, length - idx, codes); + return; + } + codes->push_back(code); + idx = i + 1; + } + } +} + +std::string ByteUtil::CodePointToUTF8String(int32_t code) const { + std::ostringstream ostr; + if (code < 0) { + SHERPA_LOG(FATAL) << "Invalid utf8 code point : " << code; + return ostr.str(); // Unreachable code. + } else if (code < 0x80) { + ostr << static_cast(code); + } else if (code < 0x800) { + ostr << static_cast((code >> 6) | 0xc0); + ostr << static_cast((code & 0x3f) | 0x80); + } else if (code < 0x10000) { + ostr << static_cast((code >> 12) | 0xe0); + ostr << static_cast(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast((code & 0x3f) | 0x80); + } else if (code < 0x200000) { + ostr << static_cast((code >> 18) | 0xf0); + ostr << static_cast(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast((code & 0x3f) | 0x80); + } else if (code < 0x4000000) { + ostr << static_cast((code >> 24) | 0xf8); + ostr << static_cast(((code >> 18) & 0x3f) | 0x80); + ostr << static_cast(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast((code & 0x3f) | 0x80); + } else { + ostr << static_cast((code >> 30) | 0xfc); + ostr << static_cast(((code >> 24) & 0x3f) | 0x80); + ostr << static_cast(((code >> 18) & 0x3f) | 0x80); + ostr << static_cast(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast((code & 0x3f) | 0x80); + } + return ostr.str(); +} + +const ByteUtilPtr GetByteUtil() { + static ByteUtilPtr bu = nullptr; + static std::once_flag init_flag; + + std::call_once(init_flag, + []() { bu = std::make_shared(ByteUtil()); }); + SHERPA_CHECK_NE(bu, nullptr); + return bu; +} + +} // namespace sherpa diff --git a/sherpa/csrc/byte_util.h b/sherpa/csrc/byte_util.h new file mode 100644 index 000000000..c89702ee7 --- /dev/null +++ b/sherpa/csrc/byte_util.h @@ -0,0 +1,115 @@ +/** + * Copyright 2023 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_CSRC_BYTE_UTIL_H_
+#define SHERPA_CSRC_BYTE_UTIL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace sherpa {
+
+class ByteUtil;
+using ByteUtilPtr = std::shared_ptr<ByteUtil>;
+
+/* This class implements the functions in byte_utils.py
+ * (https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py)
+ * It is used to decode the output hypotheses of models trained with
+ * byte-level bpe.
+ *
+ * Caution: The base characters (the byte token table) in the constructor MUST
+ * be the same as the `PRINTABLE_BASE_CHARS` in icefall.
+ */
+class ByteUtil {
+ public:
+  ByteUtil();
+  /*
+   * Encode a normal string (for example, the transcripts in a dataset) into a
+   * special sequence of utf8 characters that are all in the byte2token_ table
+   * (see the constructor). Non-ascii characters are broken into several
+   * characters (one character per byte), while printable ascii characters are
+   * kept as they are.
+   *
+   * @param str The original string.
+   *
+   * @returns Returns the encoded string.
+   */
+  std::string Encode(const std::string &str) const;
+
+  /* Decode a string produced by Encode back to its original form, i.e.,
+   * Decode(Encode(str)) should be equal to str.
+   *
+   * Note: The str here actually represents a sequence of bytes; the number of
+   * bytes equals the number of utf8 characters. We re-map these utf8
+   * characters back to bytes with token2byte_ and then convert the byte array
+   * to a string. Sometimes there are invalid bytes in the array; we drop
+   * these invalid bytes when decoding the byte array. See more examples in
+   * test-byte-util.cc.
+   *
+   * @returns Returns the decoded string.
+   */
+  std::string Decode(const std::string &str) const;
+
+ private:
+  int32_t max_token_;                // The max token in byte2token_.
+  std::vector<int16_t> token2byte_;  // map token to byte.
+  std::vector<int16_t> byte2token_;  // map byte to token.
+
+  /* Convert a utf8 code point to the corresponding character.
+   * @param code The utf8 code point.
+   *
+   * @return Returns the corresponding character (as std::string).
+   */
+  std::string CodePointToUTF8String(int32_t code) const;
+
+  /* Convert bytes to the corresponding utf8 code points.
+   *
+   * Note: We skip invalid bytes (i.e., bytes that cannot be combined into a
+   * valid utf8 character).
+   *
+   * @param bytes The pointer to the byte array.
+   * @param length The length of the byte array.
+   * @param codes The utf8 code points will be written here.
+   */
+  void BytesToCodePoints(const uint8_t *bytes, int32_t length,
+                         std::vector<int32_t> *codes) const;
+  /*
+   * The utf8 string here is expected to be an encoded string (a string
+   * encoded by Encode, or the recognition result from an ASR system built
+   * with byte-level bpe).
+   *
+   * This function first extracts the utf8 characters from str, then maps them
+   * to bytes with token2byte_.
+   *
+   * @param str The input string.
+   * @param bytes The converted bytes will be written here.
+   */
+  void UTF8StringToTokensAndMapToBytes(const std::string &str,
+                                       std::vector<uint8_t> *bytes) const;
+};
+
+/*
+ * Get the ByteUtil pointer; this guarantees that the ByteUtil object is
+ * initialized only once.
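+ *
+ * A typical call site (as in offline-recognizer-transducer-impl.h and
+ * online-recognizer.cc in this patch) is:
+ *
+ *   auto bu = GetByteUtil();
+ *   text = bu->Decode(text);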
+ */ +const ByteUtilPtr GetByteUtil(); + +} // namespace sherpa + +#endif // SHERPA_CSRC_BYTE_UTIL_H_ diff --git a/sherpa/csrc/context-graph.cc b/sherpa/csrc/context-graph.cc index f5f67a61b..ab1b87d32 100644 --- a/sherpa/csrc/context-graph.cc +++ b/sherpa/csrc/context-graph.cc @@ -25,15 +25,15 @@ namespace sherpa { void ContextGraph::Build( const std::vector> &token_ids) const { - for (int32_t i = 0; i < static_cast(token_ids.size()); ++i) { + for (size_t i = 0; i < token_ids.size(); ++i) { auto node = root_.get(); - for (int32_t j = 0; j < static_cast(token_ids[i].size()); ++j) { + for (size_t j = 0; j < token_ids[i].size(); ++j) { int32_t token = token_ids[i][j]; if (0 == node->next.count(token)) { - bool is_end = j == (static_cast(token_ids[i].size()) - 1); + bool is_end = j == (token_ids[i].size() - 1); node->next[token] = std::make_unique( token, context_score_, node->node_score + context_score_, - is_end ? 0 : node->local_node_score + context_score_, is_end); + is_end ? node->node_score + context_score_ : 0, is_end); } node = node->next[token].get(); } @@ -48,7 +48,6 @@ std::pair ContextGraph::ForwardOneStep( if (1 == state->next.count(token)) { node = state->next.at(token).get(); score = node->token_score; - if (state->is_end) score += state->node_score; } else { node = state->fail; while (0 == node->next.count(token)) { @@ -58,24 +57,15 @@ std::pair ContextGraph::ForwardOneStep( if (1 == node->next.count(token)) { node = node->next.at(token).get(); } - score = node->node_score - state->local_node_score; + score = node->node_score - state->node_score; } SHERPA_CHECK(nullptr != node); - float matched_score = 0; - auto output = node->output; - while (nullptr != output) { - matched_score += output->node_score; - output = output->output; - } - return std::make_pair(score + matched_score, node); + return std::make_pair(score + node->output_score, node); } std::pair ContextGraph::Finalize( const ContextState *state) const { float score = -state->node_score; - if (state->is_end) { - score = 0; - } return std::make_pair(score, root_.get()); } @@ -112,6 +102,7 @@ void ContextGraph::FillFailOutput() const { } } kv.second->output = output; + kv.second->output_score += output == nullptr ? 
0 : output->output_score; node_queue.push(kv.second.get()); } } diff --git a/sherpa/csrc/context-graph.h b/sherpa/csrc/context-graph.h index 18e502136..ab591b0c7 100644 --- a/sherpa/csrc/context-graph.h +++ b/sherpa/csrc/context-graph.h @@ -35,7 +35,7 @@ struct ContextState { int32_t token; float token_score; float node_score; - float local_node_score; + float output_score; bool is_end; std::unordered_map> next; const ContextState *fail = nullptr; @@ -43,11 +43,11 @@ struct ContextState { ContextState() = default; ContextState(int32_t token, float token_score, float node_score, - float local_node_score, bool is_end) + float output_score, bool is_end) : token(token), token_score(token_score), node_score(node_score), - local_node_score(local_node_score), + output_score(output_score), is_end(is_end) {} }; diff --git a/sherpa/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa/csrc/offline-transducer-greedy-search-decoder.cc index 287c42f1a..9b1ca5f06 100644 --- a/sherpa/csrc/offline-transducer-greedy-search-decoder.cc +++ b/sherpa/csrc/offline-transducer-greedy-search-decoder.cc @@ -68,17 +68,22 @@ OfflineTransducerGreedySearchDecoder::Decode(torch::Tensor encoder_out, std::vector results(N); - std::vector padding(context_size, blank_id); + std::vector padding(context_size, -1); + padding.back() = blank_id; + for (auto &r : results) { // We will remove the padding at the end r.tokens = padding; } auto decoder_input = - torch::full({N, context_size}, blank_id, + torch::full({N, context_size}, -1, torch::dtype(torch::kLong) .memory_format(torch::MemoryFormat::Contiguous)); + // set the last column to blank_id, i.e., decoder_input[:, -1] = blank_id + decoder_input.index({torch::indexing::Slice(), -1}) = blank_id; + // its shape is (N, 1, joiner_dim) auto decoder_out = model_->RunDecoder(decoder_input.to(device)); diff --git a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc index 39edc77fc..09c30ebf4 100644 --- a/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -100,7 +100,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( if (ss != nullptr) SHERPA_CHECK_EQ(batch_size, n); - std::vector blanks(context_size, blank_id); + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + Hypotheses blank_hyp({{blanks, 0}}); std::deque finalized; diff --git a/sherpa/csrc/online-transducer-greedy-search-decoder.cc b/sherpa/csrc/online-transducer-greedy-search-decoder.cc index 5ba1921db..a274bbd9e 100644 --- a/sherpa/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa/csrc/online-transducer-greedy-search-decoder.cc @@ -28,7 +28,8 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 OnlineTransducerDecoderResult r; - r.tokens.resize(context_size, blank_id); + r.tokens.resize(context_size, -1); + r.tokens.back() = blank_id; return r; } diff --git a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc index 3a8f64aa9..f8b19b93d 100644 --- a/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa/csrc/online-transducer-modified-beam-search-decoder.cc @@ -66,7 +66,9 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 // - std::vector blanks(context_size, 
blank_id); + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + Hypotheses blank_hyp({{blanks, 0}}); OnlineTransducerDecoderResult r; @@ -121,7 +123,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( int32_t N = encoder_out.size(0); int32_t T = encoder_out.size(1); - if (ss) SHERPA_CHECK_EQ(N, num_streams); + if (ss) { + SHERPA_CHECK_EQ(N, num_streams); + } std::vector cur; cur.reserve(N); diff --git a/sherpa/csrc/test-byte-util.cc b/sherpa/csrc/test-byte-util.cc new file mode 100644 index 000000000..4458278eb --- /dev/null +++ b/sherpa/csrc/test-byte-util.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2023 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "gtest/gtest.h" +#include "sherpa/csrc/byte_util.h" +#include "sherpa/csrc/log.h" + +namespace sherpa { + +TEST(ByteUtil, TestBasic) { + auto bu = GetByteUtil(); + std::string str = "Hello world"; + SHERPA_CHECK_EQ(bu->Decode(bu->Encode(str)), str); + + str = "世界人民大团结万岁"; + SHERPA_CHECK_EQ(bu->Decode(bu->Encode(str)), str); + + str = "美国 America vs China 中国 123 go!!!"; + SHERPA_CHECK_EQ(bu->Decode(bu->Encode(str)), str); +} + +TEST(ByteUtil, TestInvalidBytes) { + auto bu = GetByteUtil(); + std::string str = "ƍĩĴƎĩŗƋţŅƋ⁇Şœƌľţ"; + SHERPA_CHECK_EQ(bu->Decode(str), "我爱你中国"); + + str = "ƍĩĴĩŗƋţŅƋŞœƌľţ"; // drop one byte in 爱 + SHERPA_CHECK_EQ(bu->Decode(str), "我你中国"); + + str = "ƍĩƎĩŗƋţŅƋŞœƌľţ"; // drop one byte in 我 + SHERPA_CHECK_EQ(bu->Decode(str), "爱你中国"); + + str = "ƍĩĴƎĩŗƋţŅƋŞœƌţ"; // drop one byte in 国 + SHERPA_CHECK_EQ(bu->Decode(str), "我爱你中"); + + str = "ƍĩĴƎĩŗƋţŅƋœƌľ"; // drop one byte in 中 and 国 + SHERPA_CHECK_EQ(bu->Decode(str), "我爱你"); + + str = "ƍĩĴƎĩŗƋţŅƋlœƌoľve"; // replace one byte in 中 and 国 with l o + SHERPA_CHECK_EQ(bu->Decode(str), "我爱你love"); +} + +} // namespace sherpa diff --git a/sherpa/csrc/test-context-graph.cc b/sherpa/csrc/test-context-graph.cc index d1f820027..00b18903e 100644 --- a/sherpa/csrc/test-context-graph.cc +++ b/sherpa/csrc/test-context-graph.cc @@ -29,14 +29,15 @@ TEST(ContextGraph, TestBasic) { std::vector contexts_str( {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); std::vector> contexts; - for (int32_t i = 0; i < contexts_str.size(); ++i) { + for (size_t i = 0; i < contexts_str.size(); ++i) { contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); } auto context_graph = ContextGraph(contexts, 1); auto queries = std::map{ - {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, - {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; for (const auto &iter : queries) { float total_scores = 0; diff --git a/sherpa/python/csrc/offline-recognizer.cc b/sherpa/python/csrc/offline-recognizer.cc index 6a77a434c..f8a004486 100644 --- a/sherpa/python/csrc/offline-recognizer.cc 
+++ b/sherpa/python/csrc/offline-recognizer.cc @@ -154,7 +154,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) { // NOLINT py::class_(m, "OfflineRecognizerConfig") .def(py::init([](const std::string &nn_model, const std::string &tokens, bool use_gpu = false, int32_t num_active_paths = 4, - float context_score = 1.5, + float context_score = 1.5, bool use_bbpe = false, float temperature = 1.0, const OfflineCtcDecoderConfig &ctc_decoder_config = {}, const FeatureConfig &feat_config = {}, @@ -172,13 +172,14 @@ static void PybindOfflineRecognizerConfig(py::module &m) { // NOLINT config->decoding_method = decoding_method; config->num_active_paths = num_active_paths; config->context_score = context_score; + config->use_bbpe = use_bbpe; config->temperature = temperature; return config; }), py::arg("nn_model"), py::arg("tokens"), py::arg("use_gpu") = false, py::arg("num_active_paths") = 4, py::arg("context_score") = 1.5, - py::arg("temperature") = 1.0, + py::arg("use_bbpe") = false, py::arg("temperature") = 1.0, py::arg("ctc_decoder_config") = OfflineCtcDecoderConfig(), py::arg("feat_config") = FeatureConfig(), py::arg("fast_beam_search_config") = FastBeamSearchConfig(), @@ -196,6 +197,7 @@ static void PybindOfflineRecognizerConfig(py::module &m) { // NOLINT .def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("num_active_paths", &PyClass::num_active_paths) .def_readwrite("context_score", &PyClass::context_score) + .def_readwrite("use_bbpe", &PyClass::use_bbpe) .def_readwrite("temperature", &PyClass::temperature) .def("validate", &PyClass::Validate); } diff --git a/sherpa/python/csrc/online-recognizer.cc b/sherpa/python/csrc/online-recognizer.cc index f4a9c3652..53fb04a75 100644 --- a/sherpa/python/csrc/online-recognizer.cc +++ b/sherpa/python/csrc/online-recognizer.cc @@ -22,7 +22,7 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT const std::string &decoding_method = "greedy_search", int32_t num_active_paths = 4, float context_score = 1.5, int32_t left_context = 64, int32_t right_context = 0, - int32_t chunk_size = 16, + int32_t chunk_size = 16, bool use_bbpe = false, float temperature = 1.0, const FeatureConfig &feat_config = {}, const EndpointConfig &endpoint_config = {}, @@ -46,8 +46,8 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT ans->left_context = left_context; ans->right_context = right_context; ans->chunk_size = chunk_size; + ans->use_bbpe = use_bbpe; ans->temperature = temperature; - return ans; }), py::arg("nn_model"), py::arg("tokens"), @@ -57,7 +57,8 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT py::arg("decoding_method") = "greedy_search", py::arg("num_active_paths") = 4, py::arg("context_score") = 1.5, py::arg("left_context") = 64, py::arg("right_context") = 0, - py::arg("chunk_size") = 16, py::arg("temperature") = 1.0, + py::arg("chunk_size") = 16, py::arg("use_bbpe") = false, + py::arg("temperature") = 1.0, py::arg("feat_config") = FeatureConfig(), py::arg("endpoint_config") = EndpointConfig(), py::arg("fast_beam_search_config") = FastBeamSearchConfig()) @@ -79,6 +80,7 @@ static void PybindOnlineRecognizerConfig(py::module &m) { // NOLINT .def_readwrite("left_context", &PyClass::left_context) .def_readwrite("right_context", &PyClass::right_context) .def_readwrite("chunk_size", &PyClass::chunk_size) + .def_readwrite("use_bbpe", &PyClass::use_bbpe) .def_readwrite("temperature", &PyClass::temperature) .def("validate", &PyClass::Validate) .def("__str__",