Skip to content

Commit

Permalink
Replace FbankOptions with FeatureConfig to support 'normalize_samples…
Browse files Browse the repository at this point in the history
…' option (#546)
  • Loading branch information
megazone87 authored Feb 28, 2024
1 parent 9841097 commit dd8c367
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
4 changes: 2 additions & 2 deletions sherpa/cpp_api/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ class OnlineRecognizer::OnlineRecognizerImpl {
}

std::unique_ptr<OnlineStream> CreateStream() {
auto s = std::make_unique<OnlineStream>(config_.feat_config.fbank_opts);
auto s = std::make_unique<OnlineStream>(config_.feat_config);
InitOnlineStream(s.get());
return s;
}
Expand All @@ -359,7 +359,7 @@ class OnlineRecognizer::OnlineRecognizerImpl {
// model rather than each stream.
auto context_graph =
std::make_shared<ContextGraph>(contexts, config_.context_score);
auto s = std::make_unique<OnlineStream>(config_.feat_config.fbank_opts,
auto s = std::make_unique<OnlineStream>(config_.feat_config,
context_graph);
InitOnlineStream(s.get());
return s;
Expand Down
4 changes: 2 additions & 2 deletions sherpa/cpp_api/online-stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <string>
#include <vector>

#include "kaldifeat/csrc/feature-fbank.h"
#include "sherpa/cpp_api/feature-config.h"
#include "sherpa/csrc/context-graph.h"
#include "torch/script.h"

Expand Down Expand Up @@ -58,7 +58,7 @@ struct OnlineTransducerDecoderResult;

class OnlineStream {
public:
explicit OnlineStream(const kaldifeat::FbankOptions &opts,
explicit OnlineStream(const FeatureConfig &feat_config,
ContextGraphPtr context_graph = nullptr);
~OnlineStream();

Expand Down
16 changes: 10 additions & 6 deletions sherpa/csrc/online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <utility>
#include <vector>

#include "kaldifeat/csrc/feature-fbank.h"
#include "kaldifeat/csrc/online-feature.h"
#include "sherpa/cpp_api/endpoint.h"
#include "sherpa/csrc/context-graph.h"
Expand All @@ -36,15 +35,19 @@ namespace sherpa {

class OnlineStream::OnlineStreamImpl {
public:
explicit OnlineStreamImpl(const kaldifeat::FbankOptions &opts,
explicit OnlineStreamImpl(const FeatureConfig &feat_config,
ContextGraphPtr context_graph /*=nullptr*/)
: opts_(opts), context_graph_(context_graph) {
fbank_ = std::make_unique<kaldifeat::OnlineFbank>(opts);
: opts_(feat_config.fbank_opts), feat_config_(feat_config), context_graph_(context_graph) {
fbank_ = std::make_unique<kaldifeat::OnlineFbank>(opts_);
}

void AcceptWaveform(int32_t sampling_rate, torch::Tensor waveform) {
std::lock_guard<std::mutex> lock(feat_mutex_);

if (!feat_config_.normalize_samples) {
waveform.mul_(32767);
}

if (resampler_) {
if (sampling_rate != resampler_->GetInputSamplingRate()) {
SHERPA_LOG(FATAL) << "You changed the input sampling rate!! Expected: "
Expand Down Expand Up @@ -124,6 +127,7 @@ class OnlineStream::OnlineStreamImpl {
private:
kaldifeat::FbankOptions opts_;
std::unique_ptr<kaldifeat::OnlineFbank> fbank_;
FeatureConfig feat_config_;
mutable std::mutex feat_mutex_;

torch::IValue state_;
Expand All @@ -144,9 +148,9 @@ class OnlineStream::OnlineStreamImpl {
std::unique_ptr<LinearResample> resampler_;
};

OnlineStream::OnlineStream(const kaldifeat::FbankOptions &opts,
OnlineStream::OnlineStream(const FeatureConfig &feat_config,
ContextGraphPtr context_graph)
: impl_(std::make_unique<OnlineStreamImpl>(opts, context_graph)) {}
: impl_(std::make_unique<OnlineStreamImpl>(feat_config, context_graph)) {}

OnlineStream::~OnlineStream() = default;

Expand Down
2 changes: 1 addition & 1 deletion sherpa/csrc/test-online-stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ TEST(OnlineStream, Test) {
FeatureConfig feat_config;
feat_config.fbank_opts.mel_opts.num_bins = feature_dim;

OnlineStream s(feat_config.fbank_opts);
OnlineStream s(feat_config);
EXPECT_EQ(s.NumFramesReady(), 0);
auto a = torch::rand({500}, torch::kFloat);
s.AcceptWaveform(sampling_rate, a);
Expand Down

0 comments on commit dd8c367

Please sign in to comment.