Skip to content

Commit

Permalink
Support computing features for whisper (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 8, 2023
1 parent 7912c2f commit 01aed93
Show file tree
Hide file tree
Showing 22 changed files with 2,734 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ project(kaldifeat)
# remember to change the version in
# scripts/conda/kaldifeat/meta.yaml
# scripts/conda-cpu/kaldifeat/meta.yaml
set(kaldifeat_VERSION "1.25.1")
set(kaldifeat_VERSION "1.25.2")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
Expand Down
1 change: 1 addition & 0 deletions kaldifeat/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(kaldifeat_srcs
matrix-functions.cc
mel-computations.cc
online-feature.cc
whisper-fbank.cc
)

add_library(kaldifeat_core ${kaldifeat_srcs})
Expand Down
1 change: 1 addition & 0 deletions kaldifeat/csrc/CPPLINT.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
exclude_files=whisper-mel-bank.h
2 changes: 1 addition & 1 deletion kaldifeat/csrc/feature-fbank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
// note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// signal_frame shape: [x, 512]
// spectrum shape [x, 257
// spectrum shape [x, 257]
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
// signal_frame shape [x, 512]
Expand Down
9 changes: 9 additions & 0 deletions kaldifeat/csrc/feature-window.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
float *window_data = window.data_ptr<float>();

double a = M_2PI / (frame_length - 1);

if (opts.window_type == "hann") {
// see https://pytorch.org/docs/stable/generated/torch.hann_window.html
// We assume periodic is true
a = M_2PI / frame_length;
}

for (int32_t i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
if (opts.window_type == "hanning") {
Expand All @@ -39,6 +46,8 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
window_data[i] = sin(0.5 * a * i_fl);
} else if (opts.window_type == "hamming") {
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
} else if (opts.window_type == "hann") {
window_data[i] = 0.50 - 0.50 * cos(a * i_fl);
} else if (opts.window_type ==
"povey") { // like hamming but goes to zero at edges.
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
Expand Down
39 changes: 39 additions & 0 deletions kaldifeat/csrc/generate-whisper-melbank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)

import librosa
import numpy as np


def main():
    """Generate ``whisper-mel-bank.h`` in the current working directory.

    The header embeds the 80 x 201 mel filter matrix returned by
    ``librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)`` as a flat
    ``constexpr float`` array (flattened with ``reshape(-1)``), together
    with its row and column counts.
    """
    m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)
    assert m.shape == (80, 201)

    # Collect the pieces in a list and join once at the end instead of
    # growing one big string piece by piece.
    parts = [
        "// Auto-generated. Do NOT edit!\n\n",
        "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n",
        "\n",
        "#ifndef KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n",
        "#define KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n",
        # The header uses int32_t, so include <cstdint> to keep it
        # self-contained regardless of what the includer pulled in before.
        "\n#include <cstdint>\n\n",
        "namespace kaldifeat {\n\n",
        f"constexpr int32_t kWhisperMelRows = {m.shape[0]};\n",
        f"constexpr int32_t kWhisperMelCols = {m.shape[1]};\n",
        "\n",
        "constexpr float kWhisperMelArray[] = {\n",
    ]

    sep = ""
    for i, value in enumerate(m.reshape(-1).tolist()):
        parts.append(f"{sep}{value:.8f}")
        sep = ", "
        # Break the line after index 7, 14, 21, ...: the first line holds
        # 8 values and every following line holds 7.
        if i and i % 7 == 0:
            parts.append(",\n")
            sep = ""

    parts.append("};\n\n")
    parts.append("} // namespace kaldifeat\n\n")
    parts.append("#endif // KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n")

    with open("whisper-mel-bank.h", "w") as out:
        out.write("".join(parts))


if __name__ == "__main__":
main()
9 changes: 9 additions & 0 deletions kaldifeat/csrc/mel-computations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
}
}

// Builds a MelBanks from a precomputed filter matrix (used for Whisper).
//
// The buffer is read as a row-major float matrix of shape
// [num_rows, num_cols]; its transpose is stored in bins_mat_ on `device`.
// center_freqs_ is not initialized by this constructor.
//
// @param weights  Pointer to the first element of the matrix
// @param num_rows Number of rows of the matrix
// @param num_cols Number of columns of the matrix
// @param device   Device on which bins_mat_ is placed
MelBanks::MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
                   torch::Device device)
    : debug_(false), htk_mode_(false) {
  // from_blob() only wraps the caller's buffer and never takes ownership.
  // Without a copy, a CPU `device` makes to() a no-op and bins_mat_ would
  // keep pointing at `weights` (whose constness we cast away) for its whole
  // lifetime. clone() right away so bins_mat_ owns its storage.
  bins_mat_ = torch::from_blob(const_cast<float *>(weights),
                               {num_rows, num_cols}, torch::kFloat)
                  .clone()
                  .t()
                  .to(device);
}

// Applies the mel filters to `spectrum` by taking the matrix product
// spectrum x bins_mat_ and returning the result.
torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
  torch::Tensor mel_energies = spectrum.mm(bins_mat_);
  return mel_energies;
}
Expand Down
14 changes: 13 additions & 1 deletion kaldifeat/csrc/mel-computations.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ class MelBanks {
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
torch::Device device);

// Initialize from a precomputed 2-D weights matrix.
//
// Note: This constructor is for Whisper. It does not initialize
// center_freqs_.
//
// The buffer is read as a row-major float matrix of shape
// [num_rows, num_cols]; its transpose is what gets stored in bins_mat_,
// placed on `device`.
//
// @param weights Pointer to the first element of the matrix
// @param num_rows Number of mel bins
// @param num_cols (Number of fft bins)/2+1
// @param device Device on which bins_mat_ is placed
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
torch::Device device);

// Returns the number of mel bins.
//
// CAUTION: bins_mat_ is stored transposed, i.e. the mel bins are along its
// second dimension, so we return size(1) here and not size(0).
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }

Expand All @@ -89,7 +100,8 @@ class MelBanks {

private:
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
// Its shape is [num_fft_bins, num_bins].
// Its shape is [num_fft_bins, num_bins] for non-whisper.
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
torch::Tensor bins_mat_;

// center frequencies of bins, numbered from 0 ... num_bins-1.
Expand Down
78 changes: 78 additions & 0 deletions kaldifeat/csrc/whisper-fbank.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kaldifeat/csrc/whisper-fbank.h"

#include <cmath>
#include <vector>

#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/csrc/whisper-mel-bank.h"

#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif

namespace kaldifeat {

// Copies `opts`, builds the mel bank from the embedded Whisper filter table
// on opts.device, and then pins every frame-extraction setting to the fixed
// values below, overwriting whatever the caller supplied in frame_opts.
WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
    : opts_(opts),
      mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
                 opts.device) {
  auto &frame = opts_.frame_opts;

  // Framing: 16 kHz input, 25 ms windows every 10 ms, no rounding of the
  // window size to a power of two, no edge snipping.
  frame.samp_freq = 16000;
  frame.frame_length_ms = 25;
  frame.frame_shift_ms = 10;
  frame.round_to_power_of_two = false;
  frame.snip_edges = false;

  // Per-frame processing: everything off except the "hann" window.
  frame.dither = 0;
  frame.preemph_coeff = 0;
  frame.remove_dc_offset = false;
  frame.window_type = "hann";
}

// Computes Whisper-style log-mel features for a batch of windowed frames.
//
// The first two parameters are unused; they are only part of the signature.
//
// @param signal_frame  2-D tensor of shape [num_frames, N], where
//                      N == opts_.frame_opts.PaddedWindowSize().
// @return A 2-D tensor of shape [num_frames, mel_banks_.NumBins()].
//
// CAUTION: the dynamic-range floor below uses the maximum over the whole
// input, so the result for a frame depends on the other frames passed in
// the same call.
torch::Tensor WhisperFbankComputer::Compute(
    torch::Tensor /*signal_raw_log_energy*/, float /*vtln_warp*/,
    const torch::Tensor &signal_frame) {
  KALDIFEAT_ASSERT(signal_frame.dim() == 2);
  KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());

  // max() over an empty tensor throws, so handle a zero-frame batch up
  // front and return an empty feature matrix of the right width.
  if (signal_frame.size(0) == 0) {
    return torch::empty({0, mel_banks_.NumBins()}, signal_frame.options());
  }

  // Note: unlike the kaldi fbank, we need the power spectrum here, i.e. the
  // squared magnitude.
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
  // signal_frame shape: [x, N]
  // power shape: [x, N/2+1]
  torch::Tensor power = torch::fft::rfft(signal_frame).abs().pow(2);
#else
  // signal_frame shape [x, N]
  // real_imag shape [x, N/2+1, 2],
  //   where [..., 0] is the real part
  //         [..., 1] is the imaginary part
  torch::Tensor real_imag = torch::rfft(signal_frame, 1);
  torch::Tensor real = real_imag.index({"...", 0});
  torch::Tensor imag = real_imag.index({"...", 1});
  torch::Tensor power = (real.square() + imag.square());
#endif

  torch::Tensor mel_energies = mel_banks_.Compute(power);

  // log10 with a floor of 1e-10 to avoid log(0)
  torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();

  // Limit the dynamic range: nothing may be more than 8 (in log10 units)
  // below the global maximum.
  log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);

  // Final affine rescaling.
  torch::Tensor mel = (log_spec + 4.0) / 4.0;

  return mel;
}

} // namespace kaldifeat
74 changes: 74 additions & 0 deletions kaldifeat/csrc/whisper-fbank.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_

#include <string>
#include <vector>

#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"

namespace kaldifeat {

// Options accepted by WhisperFbankComputer.
struct WhisperFbankOptions {
  // Frame extraction settings. Note that the WhisperFbankComputer
  // constructor overwrites these with its own fixed values.
  FrameExtractionOptions frame_opts;

  // Device on which the mel filter bank is placed. Defaults to CPU.
  torch::Device device{"cpu"};

  // Returns a human-readable, single-line description of the options.
  std::string ToString() const {
    std::ostringstream oss;
    oss << "WhisperFbankOptions("
        << "frame_opts=" << frame_opts.ToString() << ", "
        << "device=\"" << device << "\")";
    return oss.str();
  }
};

class WhisperFbankComputer {
public:
// note: Only frame_opts.device is used. All other fields from frame_opts
// are ignored
explicit WhisperFbankComputer(const WhisperFbankOptions &opts = {});

int32_t Dim() const { return 80; }

const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}

const WhisperFbankOptions &GetOptions() const { return opts_; }

torch::Tensor Compute(torch::Tensor /*signal_raw_log_energy*/,
float /*vtln_warp*/, const torch::Tensor &signal_frame);

// if true, compute log_energy_pre_window but after dithering and dc removal
bool NeedRawLogEnergy() const { return false; }
using Options = WhisperFbankOptions;

private:
WhisperFbankOptions opts_;
MelBanks mel_banks_;
};

using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;

} // namespace kaldifeat

#endif // KALDIFEAT_CSRC_WHISPER_FBANK_H_
Loading

0 comments on commit 01aed93

Please sign in to comment.