RNN VAD: Pitch periods as integers and for-if-break optimization
This CL includes two changes:
1. the type for (inverted) lags and pitch periods changed from
size_t to int to reduce the chance of bugs with pitch period
manipulations
2. CheckLowerPitchPeriodsAndComputePitchGain() is optimized by
replacing an unnecessary if statement inside the loop with the
predetermined number of loops
Bug: webrtc:10480
Change-Id: I38432699254b37a2c0111279c28be8dc65b87e9b
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/139252
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Fredrik Hernqvist <fhernqvist@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32521}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index fcf179c..7822901 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -92,6 +92,7 @@
":rnn_vad_common",
"../../../../api:array_view",
"../../../../rtc_base:checks",
+ "../../../../rtc_base:safe_compare",
]
}
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
index 1b3b459..df73274 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
@@ -35,20 +35,22 @@
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
- std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
- auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
+ CandidatePitchPeriods pitch_candidates_inverted_lags =
+ FindBestPitchPeriods(auto_corr_view_, pitch_buf_decimated_view_,
+ static_cast<int>(kMaxPitch12kHz));
// Refine the pitch period estimation.
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
// to 24 kHz.
- pitch_candidates_inv_lags[0] *= 2;
- pitch_candidates_inv_lags[1] *= 2;
- size_t pitch_inv_lag_48kHz =
- RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
+ pitch_candidates_inverted_lags.best *= 2;
+ pitch_candidates_inverted_lags.second_best *= 2;
+ const int pitch_inv_lag_48kHz =
+ RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
- RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
+ RTC_DCHECK_LT(pitch_inv_lag_48kHz, static_cast<int>(kMaxPitch48kHz));
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
- pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
+ pitch_buf, static_cast<int>(kMaxPitch48kHz) - pitch_inv_lag_48kHz,
+ last_pitch_48kHz_);
return last_pitch_48kHz_;
}
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index f24a76f..922669a 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -19,38 +19,41 @@
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
namespace {
+constexpr int kMaxPitch24kHzInt = static_cast<int>(kMaxPitch24kHz);
+
// Converts a lag to an inverted lag (only for 24kHz).
-size_t GetInvertedLag(size_t lag) {
- RTC_DCHECK_LE(lag, kMaxPitch24kHz);
- return kMaxPitch24kHz - lag;
+int GetInvertedLag(int lag) {
+ RTC_DCHECK_LE(lag, kMaxPitch24kHzInt);
+ return kMaxPitch24kHzInt - lag;
}
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
- size_t inv_lag,
- size_t max_pitch_period) {
- RTC_DCHECK_LT(inv_lag, pitch_buf.size());
- RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
- RTC_DCHECK_LE(inv_lag, max_pitch_period);
+ int inv_lag,
+ int max_pitch_period) {
+ RTC_DCHECK_LT(inv_lag, static_cast<int>(pitch_buf.size()));
+ RTC_DCHECK_LT(max_pitch_period, static_cast<int>(pitch_buf.size()));
+ RTC_DCHECK_LE(inv_lag, static_cast<int>(max_pitch_period));
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
- return std::inner_product(pitch_buf.begin() + max_pitch_period,
- pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
+ return std::inner_product(
+ pitch_buf.begin() + static_cast<size_t>(max_pitch_period),
+ pitch_buf.end(), pitch_buf.begin() + static_cast<size_t>(inv_lag), 0.f);
}
-// Computes a pseudo-interpolation offset for an estimated pitch period |lag| by
-// looking at the auto-correlation coefficients in the neighborhood of |lag|.
-// (namely, |prev_auto_corr|, |lag_auto_corr| and |next_auto_corr|). The output
-// is a lag in {-1, 0, +1}.
+// Given the auto-correlation coefficients for a lag and its neighbors, computes
+// a pseudo-interpolation offset to be applied to the pitch period associated to
+// the central auto-correlation coefficient |lag_auto_corr|. The output is a lag
+// in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it
// is relevant only if the spectral analysis works at a sample rate that is
// twice as that of the pitch buffer (not so important instead for the estimated
// pitch period feature fed into the RNN).
-int GetPitchPseudoInterpolationOffset(size_t lag,
- float prev_auto_corr,
+int GetPitchPseudoInterpolationOffset(float prev_auto_corr,
float lag_auto_corr,
float next_auto_corr) {
const float& a = prev_auto_corr;
@@ -68,20 +71,19 @@
// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
// output sample rate is twice as that of |lag|.
-size_t PitchPseudoInterpolationLagPitchBuf(
- size_t lag,
+int PitchPseudoInterpolationLagPitchBuf(
+ int lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
- if (lag > 0 && lag < kMaxPitch24kHz) {
+ if (lag > 0 && lag < kMaxPitch24kHzInt) {
offset = GetPitchPseudoInterpolationOffset(
- lag,
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
- kMaxPitch24kHz),
+ kMaxPitch24kHzInt),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
- kMaxPitch24kHz),
+ kMaxPitch24kHzInt),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
- kMaxPitch24kHz));
+ kMaxPitch24kHzInt));
}
return 2 * lag + offset;
}
@@ -89,15 +91,14 @@
// Refines a pitch period |inv_lag| encoded as inverted lag with
// pseudo-interpolation. The output sample rate is twice as that of
// |inv_lag|.
-size_t PitchPseudoInterpolationInvLagAutoCorr(
- size_t inv_lag,
+int PitchPseudoInterpolationInvLagAutoCorr(
+ int inv_lag,
rtc::ArrayView<const float> auto_corr) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
- if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) {
- offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1],
- auto_corr[inv_lag],
- auto_corr[inv_lag - 1]);
+ if (inv_lag > 0 && inv_lag < static_cast<int>(auto_corr.size()) - 1) {
+ offset = GetPitchPseudoInterpolationOffset(
+ auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]);
}
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
// be subtracted since |inv_lag| is an inverted lag but offset is a lag.
@@ -198,8 +199,8 @@
void ComputeSlidingFrameSquareEnergies(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
- float yy =
- ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
+ float yy = ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHzInt,
+ kMaxPitch24kHzInt);
yy_values[0] = yy;
for (size_t i = 1; i < yy_values.size(); ++i) {
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
@@ -213,14 +214,14 @@
}
}
-std::array<size_t, 2> FindBestPitchPeriods(
+CandidatePitchPeriods FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
- size_t max_pitch_period) {
+ int max_pitch_period) {
// Stores a pitch candidate period and strength information.
struct PitchCandidate {
// Pitch period encoded as inverted lag.
- size_t period_inverted_lag = 0;
+ int period_inverted_lag = 0;
// Pitch strength encoded as a ratio.
float strength_numerator = -1.f;
float strength_denominator = 0.f;
@@ -232,9 +233,10 @@
}
};
- RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
- RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
- const size_t frame_size = pitch_buf.size() - max_pitch_period;
+ RTC_DCHECK_GT(max_pitch_period, static_cast<int>(auto_corr.size()));
+ RTC_DCHECK_LT(max_pitch_period, static_cast<int>(pitch_buf.size()));
+ const int frame_size = static_cast<int>(pitch_buf.size()) - max_pitch_period;
+ RTC_DCHECK_GT(frame_size, 0);
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
float yy =
std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
@@ -245,7 +247,8 @@
PitchCandidate best;
PitchCandidate second_best;
second_best.period_inverted_lag = 1;
- for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
+ for (int inv_lag = 0; inv_lag < static_cast<int>(auto_corr.size());
+ ++inv_lag) {
// A pitch candidate must have positive correlation.
if (auto_corr[inv_lag] > 0) {
candidate.period_inverted_lag = inv_lag;
@@ -267,32 +270,35 @@
yy += new_coeff * new_coeff;
yy = std::max(0.f, yy);
}
- return {{best.period_inverted_lag, second_best.period_inverted_lag}};
+ return {best.period_inverted_lag, second_best.period_inverted_lag};
}
-size_t RefinePitchPeriod48kHz(
+int RefinePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
- rtc::ArrayView<const size_t, 2> inv_lags) {
+ CandidatePitchPeriods pitch_candidates_inverted_lags) {
// Compute the auto-correlation terms only for neighbors of the given pitch
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
// for a few lag values).
- std::array<float, kNumInvertedLags24kHz> auto_corr;
- auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods().
- auto is_neighbor = [](size_t i, size_t j) {
+ std::array<float, kNumInvertedLags24kHz> auto_correlation;
+ auto_correlation.fill(
+ 0.f); // Zeros become ignored lags in FindBestPitchPeriods().
+ auto is_neighbor = [](int i, int j) {
return ((i > j) ? (i - j) : (j - i)) <= 2;
};
- for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
- if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1]))
- auto_corr[inv_lag] =
- ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz);
+ // TODO(https://crbug.com/webrtc/10480): Optimize by removing the loop.
+ for (int inverted_lag = 0; rtc::SafeLt(inverted_lag, auto_correlation.size());
+ ++inverted_lag) {
+ if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) ||
+ is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best))
+ auto_correlation[inverted_lag] = ComputeAutoCorrelationCoeff(
+ pitch_buf, inverted_lag, kMaxPitch24kHzInt);
}
// Find best pitch at 24 kHz.
- const auto pitch_candidates_inv_lags = FindBestPitchPeriods(
- {auto_corr.data(), auto_corr.size()},
- {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz);
- const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best.
+ const CandidatePitchPeriods pitch_candidates_24kHz =
+ FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHzInt);
// Pseudo-interpolation.
- return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr);
+ return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best,
+ auto_correlation);
}
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
@@ -327,15 +333,15 @@
};
// Initial pitch candidate gain.
RefinedPitchCandidate best_pitch;
- best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
- static_cast<int>(kMaxPitch24kHz - 1));
+ best_pitch.period_24kHz =
+ std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHzInt - 1);
best_pitch.xy = ComputeAutoCorrelationCoeff(
- pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
+ pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHzInt);
best_pitch.yy = yy_values[best_pitch.period_24kHz];
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
// Store the initial pitch period information.
- const size_t initial_pitch_period = best_pitch.period_24kHz;
+ const int initial_pitch_period = best_pitch.period_24kHz;
const float initial_pitch_gain = best_pitch.gain;
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
@@ -343,12 +349,13 @@
RTC_DCHECK_GT(k, 0);
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
};
- for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
- ++k) {
+ // |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals
+ // kMinPitch24kHz.
+ const int max_k =
+ (2 * initial_pitch_period) / (2 * static_cast<int>(kMinPitch24kHz) - 1);
+ for (int k = 2; k <= max_k; ++k) {
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
- if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
- break;
- }
+ RTC_DCHECK_GE(candidate_pitch_period, static_cast<int>(kMinPitch24kHz));
// When looking at |candidate_pitch_period|, we also look at one of its
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
@@ -356,8 +363,7 @@
int candidate_pitch_secondary_period = alternative_period(
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
- if (k == 2 &&
- candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
+ if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHzInt) {
candidate_pitch_secondary_period = initial_pitch_period;
}
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
@@ -367,10 +373,10 @@
// |candidate_pitch_period| by also looking at its possible sub-harmonic
// |candidate_pitch_secondary_period|.
float xy_primary_period = ComputeAutoCorrelationCoeff(
- pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
+ pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHzInt);
float xy_secondary_period = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
- kMaxPitch24kHz);
+ kMaxPitch24kHzInt);
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
float yy = 0.5f * (yy_values[candidate_pitch_period] +
yy_values[candidate_pitch_secondary_period]);
@@ -393,7 +399,7 @@
: best_pitch.xy / (best_pitch.yy + 1.f);
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
int final_pitch_period_48kHz = std::max(
- kMinPitch48kHz,
+ static_cast<int>(kMinPitch48kHz),
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
return {final_pitch_period_48kHz, final_pitch_gain};
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
index 2cc5ce6..cab6286 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
@@ -14,6 +14,7 @@
#include <stddef.h>
#include <array>
+#include <utility>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -49,20 +50,26 @@
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
-// Given the auto-correlation coefficients stored according to
-// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
-// and the second best pitch periods.
-std::array<size_t, 2> FindBestPitchPeriods(
+// Top-2 pitch period candidates.
+struct CandidatePitchPeriods {
+ int best;
+ int second_best;
+};
+
+// Computes the candidate pitch periods given the auto-correlation coefficients
+// stored according to ComputePitchAutoCorrelation() (i.e., using inverted
+// lags). The return periods are inverted lags.
+CandidatePitchPeriods FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
- size_t max_pitch_period);
+ int max_pitch_period);
// Refines the pitch period estimation given the pitch buffer |pitch_buf| and
-// the initial pitch period estimation |inv_lags|. Returns an inverted lag at
-// 48 kHz.
-size_t RefinePitchPeriod48kHz(
+// the initial pitch period estimation |pitch_candidates_inverted_lags|.
+// Returns an inverted lag at 48 kHz.
+int RefinePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
- rtc::ArrayView<const size_t, 2> inv_lags);
+ CandidatePitchPeriods pitch_candidates_inverted_lags);
// Refines the pitch period estimation and compute the pitch gain. Returns the
// refined pitch estimation data at 48 kHz.
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index 23ff49a..37fb15f 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -104,31 +104,29 @@
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
- std::array<size_t, 2> pitch_candidates_inv_lags;
+ CandidatePitchPeriods pitch_candidates;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
- pitch_candidates_inv_lags =
- FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
- pitch_buf_decimated, kMaxPitch12kHz);
+ pitch_candidates = FindBestPitchPeriods(auto_corr_view, pitch_buf_decimated,
+ kMaxPitch12kHz);
}
- EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
- EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
+ EXPECT_EQ(pitch_candidates.best, 140);
+ EXPECT_EQ(pitch_candidates.second_best, 142);
}
// Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
PitchTestData test_data;
- size_t pitch_inv_lag;
- {
- // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
- // FloatingPointExceptionObserver fpe_observer;
- const std::array<size_t, 2> pitch_candidates_inv_lags = {280, 284};
- pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
- pitch_candidates_inv_lags);
- }
- EXPECT_EQ(560u, pitch_inv_lag);
+ // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+ // FloatingPointExceptionObserver fpe_observer;
+ EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
+ /*pitch_candidates=*/{280, 284}),
+ 560);
+ EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
+ /*pitch_candidates=*/{260, 284}),
+ 568);
}
class CheckLowerPitchPeriodsAndComputePitchGainTest