From 1270c627ec68df289a6b3325eba93aaced12695c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 4 Nov 2023 22:31:36 +0800 Subject: [PATCH] Support removing bytes equal to 0 from the output in text normalization (#56) --- CMakeLists.txt | 2 +- kaldifst/csrc/text-normalizer.cc | 50 ++++++++++++++++++++----- kaldifst/csrc/text-normalizer.h | 6 ++- kaldifst/python/csrc/text-normalizer.cc | 6 ++- 4 files changed, 51 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 482e197..7a0687f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(kaldifst CXX) -set(KALDIFST_VERSION "1.7.7") +set(KALDIFST_VERSION "1.7.8") if(NOT CMAKE_BUILD_TYPE) message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") diff --git a/kaldifst/csrc/text-normalizer.cc b/kaldifst/csrc/text-normalizer.cc index 7e3110a..bcbf8ab 100644 --- a/kaldifst/csrc/text-normalizer.cc +++ b/kaldifst/csrc/text-normalizer.cc @@ -8,7 +8,6 @@ #include #include -#include "fst/arcsort.h" #include "kaldifst/csrc/kaldi-fst-io.h" #include "kaldifst/csrc/table-matcher.h" @@ -48,6 +47,44 @@ static fst::StdVectorFst StringToFst(const std::string &text) { return ans; } +static std::string FstToString(const fst::StdVectorFst &fst, + bool remove_output_zero) { + std::string ans; + + using Weight = typename fst::StdArc::Weight; + using Arc = fst::StdArc; + auto s = fst.Start(); + if (s == fst::kNoStateId) { + // this is an empty FST + return ""; + } + while (fst.Final(s) == Weight::Zero()) { + fst::ArcIterator> aiter(fst, s); + if (aiter.Done()) { + // not reached final. 
+ return ""; + } + + const auto &arc = aiter.Value(); + if (arc.olabel != 0 || !remove_output_zero) { + ans.push_back(arc.olabel); + } + + s = arc.nextstate; + if (s == fst::kNoStateId) { + // Transition to invalid state + return ""; + } + + aiter.Next(); + if (!aiter.Done()) { + // not a linear FST + return ""; + } + } + return ans; +} + TextNormalizer::TextNormalizer(const std::string &rule) { rule_ = std::unique_ptr( CastOrConvertToConstFst(fst::ReadFstKaldiGeneric(rule))); @@ -56,7 +93,8 @@ TextNormalizer::TextNormalizer(const std::string &rule) { TextNormalizer::TextNormalizer(std::unique_ptr rule) : rule_(std::move(rule)) {} -std::string TextNormalizer::Normalize(const std::string &s) const { +std::string TextNormalizer::Normalize(const std::string &s, + bool remove_output_zero /*=true*/) const { // Step 1: Convert the input text into an FST fst::StdVectorFst text = StringToFst(s); @@ -68,13 +106,7 @@ std::string TextNormalizer::Normalize(const std::string &s) const { fst::StdVectorFst one_best; fst::ShortestPath(composed_fst, &one_best, 1); - // Step 4: Concatenate the output labels of the best path - fst::StringPrinter string_printer(fst::StringTokenType::BYTE); - - std::string normalized; - string_printer(one_best, &normalized); - - return normalized; + return FstToString(one_best, remove_output_zero); } } // namespace kaldifst diff --git a/kaldifst/csrc/text-normalizer.h b/kaldifst/csrc/text-normalizer.h index ce5dd17..e304a95 100644 --- a/kaldifst/csrc/text-normalizer.h +++ b/kaldifst/csrc/text-normalizer.h @@ -20,7 +20,11 @@ class TextNormalizer { explicit TextNormalizer(std::unique_ptr rule); - std::string Normalize(const std::string &s) const; + // @param s The input text to be normalized + // @param remove_output_zero True to remove bytes whose value is 0 from the + // output. 
+ std::string Normalize(const std::string &s, + bool remove_output_zero = true) const; private: std::unique_ptr rule_; diff --git a/kaldifst/python/csrc/text-normalizer.cc b/kaldifst/python/csrc/text-normalizer.cc index ba6bd4d..5d63716 100644 --- a/kaldifst/python/csrc/text-normalizer.cc +++ b/kaldifst/python/csrc/text-normalizer.cc @@ -14,8 +14,10 @@ void PybindTextNormalizer(py::module *m) { using PyClass = TextNormalizer; py::class_(*m, "TextNormalizer") .def(py::init(), py::arg("rule")) - .def("normalize", &PyClass::Normalize) - .def("__call__", &PyClass::Normalize); + .def("normalize", &PyClass::Normalize, py::arg("s"), + py::arg("remove_output_zero") = true) + .def("__call__", &PyClass::Normalize, py::arg("s"), + py::arg("remove_output_zero") = true); } } // namespace kaldifst