Skip to content

Commit

Permalink
Merge with master
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Aug 30, 2023
2 parents a6a5c3c + 96de89a commit 18ed8c0
Show file tree
Hide file tree
Showing 23 changed files with 116 additions and 41 deletions.
3 changes: 2 additions & 1 deletion get_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def get_package_version():
if not is_stable():
dt = datetime.datetime.utcnow()
package_version = (
f"{latest_version}.dev{dt.year}{dt.month:02d}{dt.day:02d}{local_version}"
f"{latest_version}.dev{dt.year}{dt.month:02d}{dt.day:02d}"
f"{local_version}"
)
else:
package_version = f"{latest_version}"
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python3

import os
import re
import sys
from pathlib import Path

Expand Down
11 changes: 11 additions & 0 deletions sherpa/bin/offline_transducer_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
""",
)

parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Used only when --decoding-method is modified_beam_search.
It specifies the softmax temperature.
""",
)


def add_fast_beam_search_args(parser: argparse.ArgumentParser):
parser.add_argument(
Expand Down Expand Up @@ -330,6 +339,7 @@ def check_args(args):

if args.decoding_method == "modified_beam_search":
assert args.num_active_paths > 0, args.num_active_paths
assert args.temperature > 0, args.temperature

if args.decoding_method == "fast_beam_search" and args.LG:
if not Path(args.LG).is_file():
Expand Down Expand Up @@ -414,6 +424,7 @@ def create_recognizer(args) -> sherpa.OfflineRecognizer:
feat_config=feat_config,
decoding_method=args.decoding_method,
fast_beam_search_config=fast_beam_search_config,
temperature=args.temperature,
)

recognizer = sherpa.OfflineRecognizer(config)
Expand Down
10 changes: 10 additions & 0 deletions sherpa/bin/offline_transducer_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
It specifies number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Used only when --decoding-method is modified_beam_search.
It specifies the softmax temperature.
""",
)


def add_fast_beam_search_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -217,6 +225,7 @@ def check_args(args):

if args.decoding_method == "modified_beam_search":
assert args.num_active_paths > 0, args.num_active_paths
assert args.temperature > 0, args.temperature

if args.decoding_method == "fast_beam_search" and args.LG:
if not Path(args.LG).is_file():
Expand Down Expand Up @@ -647,6 +656,7 @@ def create_recognizer(args) -> sherpa.OfflineRecognizer:
feat_config=feat_config,
decoding_method=args.decoding_method,
fast_beam_search_config=fast_beam_search_config,
temperature=args.temperature
)

recognizer = sherpa.OfflineRecognizer(config)
Expand Down
10 changes: 10 additions & 0 deletions sherpa/bin/online_transducer_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Used only when --decoding-method is modified_beam_search.
It specifies the softmax temperature.
""",
)


def add_fast_beam_search_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -320,6 +328,7 @@ def check_args(args):

if args.decoding_method == "modified_beam_search":
assert args.num_active_paths > 0, args.num_active_paths
assert args.temperature > 0, args.temperature

if args.decoding_method == "fast_beam_search" and args.LG:
if not Path(args.LG).is_file():
Expand Down Expand Up @@ -404,6 +413,7 @@ def create_recognizer(args) -> sherpa.OnlineRecognizer:
feat_config=feat_config,
decoding_method=args.decoding_method,
fast_beam_search_config=fast_beam_search_config,
temperature=args.temperature
)

recognizer = sherpa.OnlineRecognizer(config)
Expand Down
10 changes: 10 additions & 0 deletions sherpa/bin/streaming_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,14 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
It specifies number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Used only when --decoding-method is modified_beam_search.
It specifies the softmax temperature.
""",
)


def add_endpointing_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -413,6 +421,7 @@ def create_recognizer(args) -> sherpa.OnlineRecognizer:
use_gpu=args.use_gpu,
num_active_paths=args.num_active_paths,
use_bbpe=args.use_bbpe,
temperature=args.temperature,
feat_config=feat_config,
decoding_method=args.decoding_method,
fast_beam_search_config=fast_beam_search_config,
Expand Down Expand Up @@ -776,6 +785,7 @@ def check_args(args):

if args.decoding_method == "modified_beam_search":
assert args.num_active_paths > 0, args.num_active_paths
assert args.temperature > 0, args.temperature

if args.decoding_method == "fast_beam_search" and args.LG:
if not Path(args.LG).is_file():
Expand Down
6 changes: 6 additions & 0 deletions sherpa/cpp_api/feature-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ void FeatureConfig::Register(ParseOptions *po) {
fbank_opts.mel_opts.num_bins = 80;
RegisterMelBanksOptions(po, &fbank_opts.mel_opts);

fbank_opts.mel_opts.high_freq = -400;
fbank_opts.frame_opts.remove_dc_offset = true;
fbank_opts.frame_opts.round_to_power_of_two = true;
fbank_opts.energy_floor = 1e-10;
fbank_opts.frame_opts.snip_edges = false;
fbank_opts.frame_opts.samp_freq = 16000;
po->Register("normalize-samples", &normalize_samples,
"true to use samples in the range [-1, 1]. "
"false to use samples in the range [-32768, 32767]. "
Expand Down
2 changes: 1 addition & 1 deletion sherpa/cpp_api/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config.decoding_method == "modified_beam_search") {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), config.num_active_paths);
model_.get(), config.num_active_paths, config.temperature);
} else if (config.decoding_method == "fast_beam_search") {
config.fast_beam_search_config.Validate();

Expand Down
7 changes: 6 additions & 1 deletion sherpa/cpp_api/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
"languages or multilingual datasets, it can further break "
"the multi-byte unicode characters into byte sequence and "
"then train some kind of sub-char bpes.");

po->Register("temperature", &temperature,
"Softmax temperature. "
"Used only when decoding_method is modified_beam_search.");
}

void OfflineRecognizerConfig::Validate() const {
Expand Down Expand Up @@ -158,7 +162,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "decoding_method=\"" << decoding_method << "\", ";
os << "num_active_paths=" << num_active_paths << ", ";
os << "context_score=" << context_score << ", ";
os << "use_bbpe=" << use_bbpe << ")";
os << "use_bbpe=" << use_bbpe << ", ";
os << "temperature=" << temperature << ")";

return os.str();
}
Expand Down
3 changes: 3 additions & 0 deletions sherpa/cpp_api/offline-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ struct OfflineRecognizerConfig {
// True if the model used is trained with byte level bpe.
bool use_bbpe = false;

// temperature for the softmax in the joiner
float temperature = 1.0;

void Register(ParseOptions *po);

void Validate() const;
Expand Down
9 changes: 7 additions & 2 deletions sherpa/cpp_api/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"languages or multilingual datasets, it can further break "
"the multi-byte unicode characters into byte sequence and "
"then train some kind of sub-char bpes.");

po->Register("temperature", &temperature,
"Softmax temperature. "
"Used only when decoding_method is modified_beam_search.");
}

void OnlineRecognizerConfig::Validate() const {
Expand Down Expand Up @@ -181,7 +185,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "left_context=" << left_context << ", ";
os << "right_context=" << right_context << ", ";
os << "chunk_size=" << chunk_size << ", ";
os << "use_bbpe=" << use_bbpe << ")";
os << "use_bbpe=" << use_bbpe << ", ";
os << "temperature=" << temperature << ")";
return os.str();
}

Expand Down Expand Up @@ -312,7 +317,7 @@ class OnlineRecognizer::OnlineRecognizerImpl {
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
} else if (config.decoding_method == "modified_beam_search") {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), config.num_active_paths);
model_.get(), config.num_active_paths, config.temperature);
} else if (config.decoding_method == "fast_beam_search") {
config.fast_beam_search_config.Validate();

Expand Down
3 changes: 3 additions & 0 deletions sherpa/cpp_api/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ struct OnlineRecognizerConfig {
// True if the model used is trained with byte level bpe.
bool use_bbpe = false;

// temperature for the softmax in the joiner
float temperature = 1.0;

void Register(ParseOptions *po);

void Validate() const;
Expand Down
17 changes: 4 additions & 13 deletions sherpa/csrc/context-graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void ContextGraph::Build(
bool is_end = j == (token_ids[i].size() - 1);
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? 0 : node->local_node_score + context_score_, is_end);
is_end ? node->node_score + context_score_ : 0, is_end);
}
node = node->next[token].get();
}
Expand All @@ -48,7 +48,6 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
if (state->is_end) score += state->node_score;
} else {
node = state->fail;
while (0 == node->next.count(token)) {
Expand All @@ -58,24 +57,15 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
if (1 == node->next.count(token)) {
node = node->next.at(token).get();
}
score = node->node_score - state->local_node_score;
score = node->node_score - state->node_score;
}
SHERPA_CHECK(nullptr != node);
float matched_score = 0;
auto output = node->output;
while (nullptr != output) {
matched_score += output->node_score;
output = output->output;
}
return std::make_pair(score + matched_score, node);
return std::make_pair(score + node->output_score, node);
}

std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
if (state->is_end) {
score = 0;
}
return std::make_pair(score, root_.get());
}

Expand Down Expand Up @@ -112,6 +102,7 @@ void ContextGraph::FillFailOutput() const {
}
}
kv.second->output = output;
kv.second->output_score += output == nullptr ? 0 : output->output_score;
node_queue.push(kv.second.get());
}
}
Expand Down
6 changes: 3 additions & 3 deletions sherpa/csrc/context-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ struct ContextState {
int32_t token;
float token_score;
float node_score;
float local_node_score;
float output_score;
bool is_end;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;

ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float local_node_score, bool is_end)
float output_score, bool is_end)
: token(token),
token_score(token_score),
node_score(node_score),
local_node_score(local_node_score),
output_score(output_score),
is_end(is_end) {}
};

Expand Down
9 changes: 7 additions & 2 deletions sherpa/csrc/offline-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,22 @@ OfflineTransducerGreedySearchDecoder::Decode(torch::Tensor encoder_out,

std::vector<OfflineTransducerDecoderResult> results(N);

std::vector<int32_t> padding(context_size, blank_id);
std::vector<int32_t> padding(context_size, -1);
padding.back() = blank_id;

for (auto &r : results) {
// We will remove the padding at the end
r.tokens = padding;
}

auto decoder_input =
torch::full({N, context_size}, blank_id,
torch::full({N, context_size}, -1,
torch::dtype(torch::kLong)
.memory_format(torch::MemoryFormat::Contiguous));

// set the last column to blank_id, i.e., decoder_input[:, -1] = blank_id
decoder_input.index({torch::indexing::Slice(), -1}) = blank_id;

// its shape is (N, 1, joiner_dim)
auto decoder_out = model_->RunDecoder(decoder_input.to(device));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(

if (ss != nullptr) SHERPA_CHECK_EQ(batch_size, n);

std::vector<int32_t> blanks(context_size, blank_id);
std::vector<int32_t> blanks(context_size, -1);
blanks.back() = blank_id;

Hypotheses blank_hyp({{blanks, 0}});

std::deque<Hypotheses> finalized;
Expand Down Expand Up @@ -184,7 +186,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
logits = logits.squeeze(1).squeeze(1);
// now logits' shape is (num_hyps, vocab_size)

auto log_probs = logits.log_softmax(-1).cpu();
auto log_probs = (logits / temperature_).log_softmax(-1).cpu();

log_probs.add_(ys_log_probs);

Expand Down
8 changes: 4 additions & 4 deletions sherpa/csrc/offline-transducer-modified-beam-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace sherpa {
class OfflineTransducerModifiedBeamSearchDecoder
: public OfflineTransducerDecoder {
public:
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
int32_t num_active_paths)
: model_(model), num_active_paths_(num_active_paths) {}
OfflineTransducerModifiedBeamSearchDecoder(
OfflineTransducerModel *model, int32_t num_active_paths, float temperature)
: model_(model), num_active_paths_(num_active_paths), temperature_(temperature) {}

/** Run modified beam search given the output from the encoder model.
*
Expand All @@ -35,8 +35,8 @@ class OfflineTransducerModifiedBeamSearchDecoder

private:
OfflineTransducerModel *model_; // Not owned

int32_t num_active_paths_;
float temperature_ = 1.0;
};

} // namespace sherpa
Expand Down
3 changes: 2 additions & 1 deletion sherpa/csrc/online-transducer-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r;
r.tokens.resize(context_size, blank_id);
r.tokens.resize(context_size, -1);
r.tokens.back() = blank_id;

return r;
}
Expand Down
Loading

0 comments on commit 18ed8c0

Please sign in to comment.