Revert "RNN VAD: Replace Ooura with PFFFT for the pitch auto correlation."
This reverts commit 8fcd6537f242ffd74154a62dad410e573e2efc4b.
Reason for revert: broke internal projects.
Original change's description:
> RNN VAD: Replace Ooura with PFFFT for the pitch auto correlation.
>
> Bug: webrtc:9577, webrtc:10480
> Change-Id: I6d58866d48b8eaaa4102551b88d4f55133d1915c
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/130482
> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#27387}
TBR=gustaf@webrtc.org,alessiob@webrtc.org
Change-Id: Ia05057326ebc277f334b13db0bfec9d4442903c2
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: webrtc:9577, webrtc:10480
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/130369
Reviewed-by: Qingsi Wang <qingsi@webrtc.org>
Commit-Queue: Qingsi Wang <qingsi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27405}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 237c809..7379d41 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -11,8 +11,6 @@
rtc_source_set("rnn_vad") {
visibility = [ "../*" ]
sources = [
- "auto_correlation.cc",
- "auto_correlation.h",
"common.h",
"features_extraction.cc",
"features_extraction.h",
@@ -39,9 +37,9 @@
"..:biquad_filter",
"../../../../api:array_view",
"../../../../api:function_view",
+ "../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../rtc_base:rtc_base_approved",
- "../../utility:pffft_wrapper",
"//third_party/rnnoise:kiss_fft",
"//third_party/rnnoise:rnn_vad",
]
@@ -55,7 +53,6 @@
"test_utils.h",
]
deps = [
- ":rnn_vad",
"../../../../api:array_view",
"../../../../api:scoped_refptr",
"../../../../rtc_base:checks",
@@ -89,7 +86,6 @@
rtc_source_set("unittests") {
testonly = true
sources = [
- "auto_correlation_unittest.cc",
"features_extraction_unittest.cc",
"fft_util_unittest.cc",
"lp_residual_unittest.cc",
diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
deleted file mode 100644
index d932c78..0000000
--- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
- *
- * Use of this source code is governed by a BSD-style license
- * that can be found in the LICENSE file in the root of the source
- * tree. An additional intellectual property rights grant can be found
- * in the file PATENTS. All contributing project authors may
- * be found in the AUTHORS file in the root of the source tree.
- */
-
-#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
-
-#include <algorithm>
-
-#include "rtc_base/checks.h"
-
-namespace webrtc {
-namespace rnn_vad {
-namespace {
-
-constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
-static_assert(1 << kAutoCorrelationFftOrder >
- kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
- "");
-
-} // namespace
-
-AutoCorrelationCalculator::AutoCorrelationCalculator()
- : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
- tmp_(fft_.CreateBuffer()),
- X_(fft_.CreateBuffer()),
- H_(fft_.CreateBuffer()) {}
-
-AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
-
-// The auto-correlations coefficients are computed as follows:
-// |.........|...........| <- pitch buffer
-// [ x (fixed) ]
-// [ y_0 ]
-// [ y_{m-1} ]
-// x and y are sub-array of equal length; x is never moved, whereas y slides.
-// The cross-correlation between y_0 and x corresponds to the auto-correlation
-// for the maximum pitch period. Hence, the first value in |auto_corr| has an
-// inverted lag equal to 0 that corresponds to a lag equal to the maximum
-// pitch period.
-void AutoCorrelationCalculator::ComputeOnPitchBuffer(
- rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
- rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
- RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
- RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
- constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
- constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
- static_assert(kConvolutionLength == kFrameSize20ms12kHz,
- "Mismatch between pitch buffer size, frame size and maximum "
- "pitch period.");
- static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
- "The FFT length is not sufficiently big to avoid cyclic "
- "convolution errors.");
- auto tmp = tmp_->GetView();
-
- // Compute the FFT for the reversed reference frame - i.e.,
- // pitch_buf[-kConvolutionLength:].
- std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
- tmp.begin());
- std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
- fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
-
- // Compute the FFT for the sliding frames chunk. The sliding frames are
- // defined as pitch_buf[i:i+kConvolutionLength] where i in
- // [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
- // defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
- std::copy(pitch_buf.begin(),
- pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
- tmp.begin());
- std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
- 0.f);
- fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
-
- // Convolve in the frequency domain.
- constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
- std::fill(tmp.begin(), tmp.end(), 0.f);
- fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
- fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
-
- // Extract the auto-correlation coefficients.
- std::copy(tmp.begin() + kConvolutionLength - 1,
- tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
- auto_corr.begin());
-}
-
-} // namespace rnn_vad
-} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h
deleted file mode 100644
index de7f453..0000000
--- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
- *
- * Use of this source code is governed by a BSD-style license
- * that can be found in the LICENSE file in the root of the source
- * tree. An additional intellectual property rights grant can be found
- * in the file PATENTS. All contributing project authors may
- * be found in the AUTHORS file in the root of the source tree.
- */
-
-#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
-#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
-
-#include <memory>
-
-#include "api/array_view.h"
-#include "modules/audio_processing/agc2/rnn_vad/common.h"
-#include "modules/audio_processing/utility/pffft_wrapper.h"
-
-namespace webrtc {
-namespace rnn_vad {
-
-// Class to compute the auto correlation on the pitch buffer for a target pitch
-// interval.
-class AutoCorrelationCalculator {
- public:
- AutoCorrelationCalculator();
- AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete;
- AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) =
- delete;
- ~AutoCorrelationCalculator();
-
- // Computes the auto-correlation coefficients for a target pitch interval.
- // |auto_corr| indexes are inverted lags.
- void ComputeOnPitchBuffer(
- rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
- rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
-
- private:
- Pffft fft_;
- std::unique_ptr<Pffft::FloatBuffer> tmp_;
- std::unique_ptr<Pffft::FloatBuffer> X_;
- std::unique_ptr<Pffft::FloatBuffer> H_;
-};
-
-} // namespace rnn_vad
-} // namespace webrtc
-
-#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
deleted file mode 100644
index a5e456a..0000000
--- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
- *
- * Use of this source code is governed by a BSD-style license
- * that can be found in the LICENSE file in the root of the source
- * tree. An additional intellectual property rights grant can be found
- * in the file PATENTS. All contributing project authors may
- * be found in the AUTHORS file in the root of the source tree.
- */
-
-#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
-
-#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
-#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
-#include "test/gtest.h"
-
-namespace webrtc {
-namespace rnn_vad {
-namespace test {
-
-TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
- PitchTestData test_data;
- std::array<float, kBufSize12kHz> pitch_buf_decimated;
- Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
- std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
- {
- // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
- // FloatingPointExceptionObserver fpe_observer;
- AutoCorrelationCalculator auto_corr_calculator;
- auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
- computed_output);
- }
- auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
- ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
- computed_output, 3e-3f);
-}
-
-// Check that the auto correlation function computes the right thing for a
-// simple use case.
-TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
- // Create constant signal with no pitch.
- std::array<float, kBufSize12kHz> pitch_buf_decimated;
- std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
- std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
- {
- // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
- // FloatingPointExceptionObserver fpe_observer;
- AutoCorrelationCalculator auto_corr_calculator;
- auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
- computed_output);
- }
- // The expected output is constantly the length of the fixed 'x'
- // array in ComputePitchAutoCorrelation.
- std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
- std::fill(expected_output.begin(), expected_output.end(),
- kBufSize12kHz - kMaxPitch12kHz);
- ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
-}
-
-} // namespace test
-} // namespace rnn_vad
-} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h
index 2f16cd4..b98438d 100644
--- a/modules/audio_processing/agc2/rnn_vad/common.h
+++ b/modules/audio_processing/agc2/rnn_vad/common.h
@@ -20,21 +20,18 @@
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
-// Pitch buffer.
+// Pitch analysis params.
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
-// 24 kHz analysis.
// Define a higher minimum pitch period for the initial search. This is used to
// avoid searching for very short periods, for which a refinement step is
// responsible.
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
-static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
-constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// 12 kHz analysis.
constexpr size_t kSampleRate12kHz = 12000;
@@ -43,10 +40,6 @@
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
-static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
-// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
-// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
-constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// 48 kHz constants.
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
index 1b3b459..aa0b751 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
@@ -19,7 +19,8 @@
namespace rnn_vad {
PitchEstimator::PitchEstimator()
- : pitch_buf_decimated_(kBufSize12kHz),
+ : fft_(RealFourier::Create(kAutoCorrelationFftOrder)),
+ pitch_buf_decimated_(kBufSize12kHz),
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
auto_corr_(kNumInvertedLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
@@ -33,16 +34,20 @@
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
- auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
- auto_corr_view_);
+ // Compute auto-correlation terms.
+ ComputePitchAutoCorrelation(pitch_buf_decimated_view_, kMaxPitch12kHz,
+ auto_corr_view_, fft_.get());
+
+ // Search for pitch at 12 kHz.
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
+
// Refine the pitch period estimation.
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
// to 24 kHz.
- pitch_candidates_inv_lags[0] *= 2;
- pitch_candidates_inv_lags[1] *= 2;
+ for (size_t i = 0; i < pitch_candidates_inv_lags.size(); ++i)
+ pitch_candidates_inv_lags[i] *= 2;
size_t pitch_inv_lag_48kHz =
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h
index 74133d0..5914535 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h
@@ -15,7 +15,7 @@
#include <vector>
#include "api/array_view.h"
-#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
+#include "common_audio/real_fourier.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
@@ -36,7 +36,7 @@
private:
PitchInfo last_pitch_48kHz_;
- AutoCorrelationCalculator auto_corr_calculator_;
+ std::unique_ptr<RealFourier> fft_;
std::vector<float> pitch_buf_decimated_;
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
std::vector<float> auto_corr_;
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index 0561c37..7c17dfb 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -13,6 +13,7 @@
#include <stdlib.h>
#include <algorithm>
#include <cmath>
+#include <complex>
#include <cstddef>
#include <numeric>
@@ -212,6 +213,64 @@
}
}
+void ComputePitchAutoCorrelation(
+ rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
+ size_t max_pitch_period,
+ rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
+ webrtc::RealFourier* fft) {
+ RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
+ RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
+ RTC_DCHECK(fft);
+
+ constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
+ constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;
+
+ RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
+ RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
+ freq_domain_fft_length);
+
+ // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
+ // x=pitch_buf[-convolution_length:] is equivalent to convolution of
+ // y_i and reversed(x). New notation: h=reversed(x), x=y.
+ std::array<float, time_domain_fft_length> h{};
+ std::array<float, time_domain_fft_length> x{};
+
+ const size_t convolution_length = kBufSize12kHz - max_pitch_period;
+ // Check that the FFT-length is big enough to avoid cyclic
+ // convolution errors.
+ RTC_DCHECK_GT(time_domain_fft_length,
+ kNumInvertedLags12kHz + convolution_length);
+
+ // h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
+ std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
+ h.begin());
+
+ // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
+ std::copy(pitch_buf.begin(),
+ pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
+ x.begin());
+
+ // Shift to frequency domain.
+ std::array<std::complex<float>, freq_domain_fft_length> X{};
+ std::array<std::complex<float>, freq_domain_fft_length> H{};
+ fft->Forward(&x[0], &X[0]);
+ fft->Forward(&h[0], &H[0]);
+
+ // Convolve in frequency domain.
+ for (size_t i = 0; i < X.size(); ++i) {
+ X[i] *= H[i];
+ }
+
+ // Shift back to time domain.
+ std::array<float, time_domain_fft_length> x_conv_h;
+ fft->Inverse(&X[0], &x_conv_h[0]);
+
+ // Collect the result.
+ std::copy(x_conv_h.begin() + convolution_length - 1,
+ x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
+ auto_corr.begin());
+}
+
std::array<size_t, 2> FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
index 6ccd165..aabf713 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
@@ -15,12 +15,25 @@
#include <array>
#include "api/array_view.h"
+#include "common_audio/real_fourier.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
namespace webrtc {
namespace rnn_vad {
+// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
+// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags|].
+static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
+static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
+constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
+constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
+constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
+
+static_assert(1 << kAutoCorrelationFftOrder >
+ kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
+ "");
+
// Performs 2x decimation without any anti-aliasing filter.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst);
@@ -48,6 +61,25 @@
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
+// Computes the auto-correlation coefficients for a given pitch interval.
+// |auto_corr| indexes are inverted lags.
+//
+// The auto-correlations coefficients are computed as follows:
+// |.........|...........| <- pitch buffer
+// [ x (fixed) ]
+// [ y_0 ]
+// [ y_{m-1} ]
+// x and y are sub-array of equal length; x is never moved, whereas y slides.
+// The cross-correlation between y_0 and x corresponds to the auto-correlation
+// for the maximum pitch period. Hence, the first value in |auto_corr| has an
+// inverted lag equal to 0 that corresponds to a lag equal to the maximum pitch
+// period.
+void ComputePitchAutoCorrelation(
+ rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
+ size_t max_pitch_period,
+ rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
+ webrtc::RealFourier* fft);
+
// Given the auto-correlation coefficients stored according to
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
// and the second best pitch periods.
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index bd2ea24..8ff6ac1 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -9,6 +9,7 @@
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
+#include "common_audio/real_fourier.h"
#include <array>
#include <tuple>
@@ -29,6 +30,34 @@
};
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
+constexpr size_t kNumPitchBufSquareEnergies = 385;
+constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
+constexpr size_t kTestDataSize =
+ kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
+
+class TestData {
+ public:
+ TestData() {
+ auto test_data_reader = CreatePitchSearchTestDataReader();
+ test_data_reader->ReadChunk(test_data_);
+ }
+ rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() {
+ return {test_data_.data(), kBufSize24kHz};
+ }
+ rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
+ GetPitchBufSquareEnergiesView() {
+ return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
+ }
+ rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
+ GetPitchBufAutoCorrCoeffsView() {
+ return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
+ kNumPitchBufAutoCorrCoeffs};
+ }
+
+ private:
+ std::array<float, kTestDataSize> test_data_;
+};
+
} // namespace
class ComputePitchGainThresholdTest
@@ -78,7 +107,7 @@
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
- PitchTestData test_data;
+ TestData test_data;
std::array<float, kNumPitchBufSquareEnergies> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -91,8 +120,51 @@
computed_output, 3e-2f);
}
+TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) {
+ TestData test_data;
+ std::array<float, kBufSize12kHz> pitch_buf_decimated;
+ Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
+ std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
+ {
+ // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+ // FloatingPointExceptionObserver fpe_observer;
+ std::unique_ptr<RealFourier> fft =
+ RealFourier::Create(kAutoCorrelationFftOrder);
+ ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz,
+ computed_output, fft.get());
+ }
+ auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
+ ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
+ computed_output, 3e-3f);
+}
+
+// Check that the auto correlation function computes the right thing for a
+// simple use case.
+TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
+ // Create constant signal with no pitch.
+ std::array<float, kBufSize12kHz> pitch_buf_decimated;
+ std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
+
+ std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
+ {
+ // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+ // FloatingPointExceptionObserver fpe_observer;
+ std::unique_ptr<RealFourier> fft =
+ RealFourier::Create(kAutoCorrelationFftOrder);
+ ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz,
+ computed_output, fft.get());
+ }
+
+ // The expected output is constantly the length of the fixed 'x'
+ // array in ComputePitchAutoCorrelation.
+ std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
+ std::fill(expected_output.begin(), expected_output.end(),
+ kBufSize12kHz - kMaxPitch12kHz);
+ ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
+}
+
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
- PitchTestData test_data;
+ TestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
std::array<size_t, 2> pitch_candidates_inv_lags;
@@ -109,7 +181,7 @@
}
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
- PitchTestData test_data;
+ TestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
size_t pitch_inv_lag;
@@ -135,7 +207,7 @@
const float prev_pitch_gain = std::get<2>(params);
const int expected_pitch_period = std::get<3>(params);
const float expected_pitch_gain = std::get<4>(params);
- PitchTestData test_data;
+ TestData test_data;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 4dae8cd..8decbd0 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -111,28 +111,6 @@
return {std::move(ptr), ptr->data_length()};
}
-PitchTestData::PitchTestData() {
- auto test_data_reader = CreatePitchSearchTestDataReader();
- test_data_reader->ReadChunk(test_data_);
-}
-
-PitchTestData::~PitchTestData() = default;
-
-rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView() {
- return {test_data_.data(), kBufSize24kHz};
-}
-
-rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
-PitchTestData::GetPitchBufSquareEnergiesView() {
- return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
-}
-
-rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
-PitchTestData::GetPitchBufAutoCorrCoeffsView() {
- return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
- kNumPitchBufAutoCorrCoeffs};
-}
-
} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index f9d7376..15be85a 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -12,7 +12,6 @@
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#include <algorithm>
-#include <array>
#include <fstream>
#include <limits>
#include <memory>
@@ -21,7 +20,6 @@
#include <vector>
#include "api/array_view.h"
-#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
namespace webrtc {
@@ -120,27 +118,6 @@
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateVadProbsReader();
-constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
-constexpr size_t kNumPitchBufSquareEnergies = 385;
-constexpr size_t kPitchTestDataSize =
- kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
-
-// Class to retrieve a test pitch buffer content and the expected output for the
-// analysis steps.
-class PitchTestData {
- public:
- PitchTestData();
- ~PitchTestData();
- rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView();
- rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
- GetPitchBufSquareEnergiesView();
- rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
- GetPitchBufAutoCorrCoeffsView();
-
- private:
- std::array<float, kPitchTestDataSize> test_data_;
-};
-
} // namespace test
} // namespace rnn_vad
} // namespace webrtc