diff --git a/examples/cpp/silero-vad-onnx.cpp b/examples/cpp/silero-vad-onnx.cpp index eb92296..dd2bf4e 100644 --- a/examples/cpp/silero-vad-onnx.cpp +++ b/examples/cpp/silero-vad-onnx.cpp @@ -120,8 +120,7 @@ class VadIterator void reset_states() { // Call reset before each audio start - std::memset(_h.data(), 0.0f, _h.size() * sizeof(float)); - std::memset(_c.data(), 0.0f, _c.size() * sizeof(float)); + std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); triggered = false; temp_end = 0; current_sample = 0; @@ -139,19 +138,16 @@ class VadIterator input.assign(data.begin(), data.end()); Ort::Value input_ort = Ort::Value::CreateTensor( memory_info, input.data(), input.size(), input_node_dims, 2); + Ort::Value state_ort = Ort::Value::CreateTensor( + memory_info, _state.data(), _state.size(), state_node_dims, 3); Ort::Value sr_ort = Ort::Value::CreateTensor( memory_info, sr.data(), sr.size(), sr_node_dims, 1); - Ort::Value h_ort = Ort::Value::CreateTensor( - memory_info, _h.data(), _h.size(), hc_node_dims, 3); - Ort::Value c_ort = Ort::Value::CreateTensor( - memory_info, _c.data(), _c.size(), hc_node_dims, 3); // Clear and add inputs ort_inputs.clear(); ort_inputs.emplace_back(std::move(input_ort)); + ort_inputs.emplace_back(std::move(state_ort)); ort_inputs.emplace_back(std::move(sr_ort)); - ort_inputs.emplace_back(std::move(h_ort)); - ort_inputs.emplace_back(std::move(c_ort)); // Infer ort_outputs = session->Run( @@ -161,10 +157,8 @@ class VadIterator // Output probability & update h,c recursively float speech_prob = ort_outputs[0].GetTensorMutableData()[0]; - float *hn = ort_outputs[1].GetTensorMutableData(); - std::memcpy(_h.data(), hn, size_hc * sizeof(float)); - float *cn = ort_outputs[2].GetTensorMutableData(); - std::memcpy(_c.data(), cn, size_hc * sizeof(float)); + float *stateN = ort_outputs[1].GetTensorMutableData(); + std::memcpy(_state.data(), stateN, size_state * sizeof(float)); // Push forward sample index current_sample += window_size_samples; @@ -376,27 +370,26 @@ class VadIterator // Inputs std::vector ort_inputs; - std::vector input_node_names = {"input", "sr", "h", "c"}; + std::vector input_node_names = {"input", "state", "sr"}; std::vector input; + unsigned int size_state = 2 * 1 * 128; // It's FIXED. + std::vector _state; std::vector sr; - unsigned int size_hc = 2 * 1 * 64; // It's FIXED. - std::vector _h; - std::vector _c; - int64_t input_node_dims[2] = {}; + int64_t input_node_dims[2] = {}; + const int64_t state_node_dims[3] = {2, 1, 128}; const int64_t sr_node_dims[1] = {1}; - const int64_t hc_node_dims[3] = {2, 1, 64}; // Outputs std::vector ort_outputs; - std::vector output_node_names = {"output", "hn", "cn"}; + std::vector output_node_names = {"output", "stateN"}; public: // Construction VadIterator(const std::wstring ModelPath, - int Sample_rate = 16000, int windows_frame_size = 64, + int Sample_rate = 16000, int windows_frame_size = 32, float Threshold = 0.5, int min_silence_duration_ms = 0, - int speech_pad_ms = 64, int min_speech_duration_ms = 64, + int speech_pad_ms = 32, int min_speech_duration_ms = 32, float max_speech_duration_s = std::numeric_limits::infinity()) { init_onnx_model(ModelPath); @@ -422,8 +415,7 @@ class VadIterator input_node_dims[0] = 1; input_node_dims[1] = window_size_samples; - _h.resize(size_hc); - _c.resize(size_hc); + _state.resize(size_state); sr.resize(1); sr[0] = sample_rate; };