RNN VAD: pitch search optimizations (part 4)
Add inverted lags index to simplify the loop in
`FindBestPitchPeriod48kHz()`. Instead of looping over 294 items,
only loop over the relevant ones (up to 10) by keeping track of
the relevant indexes.
The benchmark shows a slight improvement: the speed factor reported below (the `x` values) increases by about 6x on average.
Benchmarked as follows:
```
out/release/modules_unittests \
--gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
--gtest_also_run_disabled_tests --logs
```
Results:
      | baseline             | this CL
------+----------------------+------------------------
run 1 | 22.8319 +/- 1.46554  | 22.1951 +/- 0.747611
      | 389.367x             | 400.539x
------+----------------------+------------------------
run 2 | 22.4286 +/- 0.726449 | 22.2718 +/- 0.963738
      | 396.369x             | 399.16x
------+----------------------+------------------------
run 3 | 22.5688 +/- 0.831341 | 22.4166 +/- 0.953362
      | 393.906x             | 396.581x
This CL also moved `PitchPseudoInterpolationInvLagAutoCorr()`
into `FindBestPitchPeriod48kHz()`.
Bug: webrtc:10480
Change-Id: Id4e6d755045c3198a80fa94a0a7463577d909b7e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191764
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32590}
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index d7ba65f..262c386 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -79,24 +79,6 @@
return 2 * lag + offset;
}
-// Refines a pitch period |inverted_lag| encoded as inverted lag with
-// pseudo-interpolation. The output sample rate is twice as that of
-// |inverted_lag|.
-int PitchPseudoInterpolationInvLagAutoCorr(
- int inverted_lag,
- rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation) {
- int offset = 0;
- // Cannot apply pseudo-interpolation at the boundaries.
- if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) {
- offset = GetPitchPseudoInterpolationOffset(
- auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag],
- auto_correlation[inverted_lag - 1]);
- }
- // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
- // be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
- return 2 * inverted_lag + offset;
-}
-
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
@@ -129,35 +111,75 @@
int max;
};
+// Number of pitch periods analyzed on each side (left and right) of a pitch candidate.
+constexpr int kPitchNeighborhoodRadius = 2;
+
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
// pitch buffer.
Range CreateInvertedLagRange(int inverted_lag) {
- constexpr int kRadius = 2;
- return {std::max(inverted_lag - kRadius, 0),
- std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)};
+ return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
+ std::min(inverted_lag + kPitchNeighborhoodRadius,
+ kInitialNumLags24kHz - 1)};
}
+constexpr int kNumPitchCandidates = 2; // Best and second best.
+// Maximum number of analyzed pitch periods.
+constexpr int kMaxPitchPeriods24kHz =
+ kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
+
+// Collection of inverted lags.
+class InvertedLagsIndex {
+ public:
+ InvertedLagsIndex() : num_entries_(0) {}
+ // Adds an inverted lag to the index. Cannot add more than
+ // `kMaxPitchPeriods24kHz` values.
+ void Append(int inverted_lag) {
+ RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
+ inverted_lags_[num_entries_++] = inverted_lag;
+ }
+ const int* data() const { return inverted_lags_.data(); }
+ int size() const { return num_entries_; }
+
+ private:
+ std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
+ int num_entries_;
+};
+
// Computes the auto correlation coefficients for the inverted lags in the
-// closed interval `inverted_lags`.
+// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
+// the inverted lags for the computed auto correlation values.
void ComputeAutoCorrelation(
Range inverted_lags,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
- rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation) {
+ rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
+ InvertedLagsIndex& inverted_lags_index) {
// Check valid range.
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
+ // Trick to avoid zero initialization of `auto_correlation`.
+ // Needed by the pseudo-interpolation.
+ if (inverted_lags.min > 0) {
+ auto_correlation[inverted_lags.min - 1] = 0.f;
+ }
+ if (inverted_lags.max < kInitialNumLags24kHz - 1) {
+ auto_correlation[inverted_lags.max + 1] = 0.f;
+ }
// Check valid `inverted_lag` indexes.
RTC_DCHECK_GE(inverted_lags.min, 0);
- RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size());
+ RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
++inverted_lag) {
auto_correlation[inverted_lag] =
ComputeAutoCorrelation(inverted_lag, pitch_buffer);
+ inverted_lags_index.Append(inverted_lag);
}
}
-int ComputePitchPeriod24kHz(
+// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
+// 48 kHz.
+int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
+ rtc::ArrayView<const int> inverted_lags,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
@@ -165,8 +187,7 @@
int best_inverted_lag = 0; // Pitch period.
float best_numerator = -1.f; // Pitch strength numerator.
float best_denominator = 0.f; // Pitch strength denominator.
- for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz;
- ++inverted_lag) {
+ for (int inverted_lag : inverted_lags) {
// A pitch candidate must have positive correlation.
if (auto_correlation[inverted_lag] > 0.f) {
// Auto-correlation energy normalized by frame energy.
@@ -181,7 +202,19 @@
}
}
}
- return best_inverted_lag;
+ // Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
+ // 48 kHz pitch period.
+ if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
+ // Cannot apply pseudo-interpolation at the boundaries.
+ return best_inverted_lag * 2;
+ }
+ int offset = GetPitchPseudoInterpolationOffset(
+ auto_correlation[best_inverted_lag + 1],
+ auto_correlation[best_inverted_lag],
+ auto_correlation[best_inverted_lag - 1]);
+ // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
+ // be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
+ return 2 * best_inverted_lag + offset;
}
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
@@ -332,10 +365,10 @@
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates) {
- // Compute the auto-correlation terms only for neighbors of the given pitch
- // candidates (similar to what is done in ComputePitchAutoCorrelation(), but
- // for a few lag values).
- std::array<float, kInitialNumLags24kHz> auto_correlation{};
+ // Compute the auto-correlation terms only for neighbors of the two pitch
+ // candidates (best and second best).
+ std::array<float, kInitialNumLags24kHz> auto_correlation;
+ InvertedLagsIndex inverted_lags_index;
// Create two inverted lag ranges so that `r1` precedes `r2`.
const bool swap_candidates =
pitch_candidates.best > pitch_candidates.second_best;
@@ -351,18 +384,17 @@
RTC_DCHECK_LE(r1.max, r2.max);
if (r1.max + 1 >= r2.min) {
// Overlapping or adjacent ranges.
- ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation);
+ ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
+ inverted_lags_index);
} else {
// Disjoint ranges.
- ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation);
- ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation);
+ ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
+ inverted_lags_index);
+ ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
+ inverted_lags_index);
}
- // Find best pitch at 24 kHz.
- const int pitch_candidate_24kHz =
- ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
- // Pseudo-interpolation.
- return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
- auto_correlation);
+ return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
+ auto_correlation, y_energy);
}
PitchInfo ComputeExtendedPitchPeriod48kHz(
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index fc715c6..152d569 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -128,9 +128,9 @@
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
PeriodBitExactnessGainWithinTolerance) {
PitchTestData test_data;
- std::vector<float> y_energy(kMaxPitch24kHz + 1);
- rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
- kMaxPitch24kHz + 1);
+ std::vector<float> y_energy(kRefineNumLags24kHz);
+ rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
+ kRefineNumLags24kHz);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
y_energy_view);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.