RNN VAD: pitch search optimizations (part 4)

Add an inverted lags index to simplify the search loop in
`ComputePitchPeriod48kHz()`. Instead of looping over all 294 items,
the loop only visits the relevant ones (up to 10), i.e., the inverted
lags for which the auto-correlation has actually been computed.

The benchmark shows a slight improvement (about +6x on average across
the runs reported below).

Benchmarked as follows:
```
out/release/modules_unittests \
  --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
  --gtest_also_run_disabled_tests --logs
```

Results:

      | baseline             | this CL
------+----------------------+------------------------
run 1 | 22.8319 +/- 1.46554  | 22.1951 +/- 0.747611
      | 389.367x             | 400.539x
------+----------------------+------------------------
run 2 | 22.4286 +/- 0.726449 | 22.2718 +/- 0.963738
      | 396.369x             | 399.16x
------+----------------------+------------------------
run 3 | 22.5688 +/- 0.831341 | 22.4166 +/- 0.953362
      | 393.906x             | 396.581x

This CL also inlines `PitchPseudoInterpolationInvLagAutoCorr()`
into `ComputePitchPeriod48kHz()` and removes the standalone helper.

Bug: webrtc:10480
Change-Id: Id4e6d755045c3198a80fa94a0a7463577d909b7e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191764
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32590}
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index d7ba65f..262c386 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -79,24 +79,6 @@
   return 2 * lag + offset;
 }
 
-// Refines a pitch period |inverted_lag| encoded as inverted lag with
-// pseudo-interpolation. The output sample rate is twice as that of
-// |inverted_lag|.
-int PitchPseudoInterpolationInvLagAutoCorr(
-    int inverted_lag,
-    rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation) {
-  int offset = 0;
-  // Cannot apply pseudo-interpolation at the boundaries.
-  if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) {
-    offset = GetPitchPseudoInterpolationOffset(
-        auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag],
-        auto_correlation[inverted_lag - 1]);
-  }
-  // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
-  // be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
-  return 2 * inverted_lag + offset;
-}
-
 // Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
 // looking for sub-harmonics.
 // The values have been chosen to serve the following algorithm. Given the
@@ -129,35 +111,75 @@
   int max;
 };
 
+// Number of analyzed pitches to the left(right) of a pitch candidate.
+constexpr int kPitchNeighborhoodRadius = 2;
+
 // Creates a pitch period interval centered in `inverted_lag` with hard-coded
 // radius. Clipping is applied so that the interval is always valid for a 24 kHz
 // pitch buffer.
 Range CreateInvertedLagRange(int inverted_lag) {
-  constexpr int kRadius = 2;
-  return {std::max(inverted_lag - kRadius, 0),
-          std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1)};
+  return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
+          std::min(inverted_lag + kPitchNeighborhoodRadius,
+                   kInitialNumLags24kHz - 1)};
 }
 
+constexpr int kNumPitchCandidates = 2;  // Best and second best.
+// Maximum number of analyzed pitch periods.
+constexpr int kMaxPitchPeriods24kHz =
+    kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
+
+// Collection of inverted lags.
+class InvertedLagsIndex {
+ public:
+  InvertedLagsIndex() : num_entries_(0) {}
+  // Adds an inverted lag to the index. Cannot add more than
+  // `kMaxPitchPeriods24kHz` values.
+  void Append(int inverted_lag) {
+    RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
+    inverted_lags_[num_entries_++] = inverted_lag;
+  }
+  const int* data() const { return inverted_lags_.data(); }
+  int size() const { return num_entries_; }
+
+ private:
+  std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
+  int num_entries_;
+};
+
 // Computes the auto correlation coefficients for the inverted lags in the
-// closed interval `inverted_lags`.
+// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
+// the inverted lags for the computed auto correlation values.
 void ComputeAutoCorrelation(
     Range inverted_lags,
     rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
-    rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation) {
+    rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
+    InvertedLagsIndex& inverted_lags_index) {
   // Check valid range.
   RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
+  // Trick to avoid zero initialization of `auto_correlation`.
+  // Needed by the pseudo-interpolation.
+  if (inverted_lags.min > 0) {
+    auto_correlation[inverted_lags.min - 1] = 0.f;
+  }
+  if (inverted_lags.max < kInitialNumLags24kHz - 1) {
+    auto_correlation[inverted_lags.max + 1] = 0.f;
+  }
   // Check valid `inverted_lag` indexes.
   RTC_DCHECK_GE(inverted_lags.min, 0);
-  RTC_DCHECK_LT(inverted_lags.max, auto_correlation.size());
+  RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
   for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
        ++inverted_lag) {
     auto_correlation[inverted_lag] =
         ComputeAutoCorrelation(inverted_lag, pitch_buffer);
+    inverted_lags_index.Append(inverted_lag);
   }
 }
 
-int ComputePitchPeriod24kHz(
+// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
+// 48 kHz.
+int ComputePitchPeriod48kHz(
     rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
+    rtc::ArrayView<const int> inverted_lags,
     rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
     rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
   static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
@@ -165,8 +187,7 @@
   int best_inverted_lag = 0;     // Pitch period.
   float best_numerator = -1.f;   // Pitch strength numerator.
   float best_denominator = 0.f;  // Pitch strength denominator.
-  for (int inverted_lag = 0; inverted_lag < kInitialNumLags24kHz;
-       ++inverted_lag) {
+  for (int inverted_lag : inverted_lags) {
     // A pitch candidate must have positive correlation.
     if (auto_correlation[inverted_lag] > 0.f) {
       // Auto-correlation energy normalized by frame energy.
@@ -181,7 +202,19 @@
       }
     }
   }
-  return best_inverted_lag;
+  // Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
+  // 48 kHz pitch period.
+  if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
+    // Cannot apply pseudo-interpolation at the boundaries.
+    return best_inverted_lag * 2;
+  }
+  int offset = GetPitchPseudoInterpolationOffset(
+      auto_correlation[best_inverted_lag + 1],
+      auto_correlation[best_inverted_lag],
+      auto_correlation[best_inverted_lag - 1]);
+  // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
+  // be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
+  return 2 * best_inverted_lag + offset;
 }
 
 // Returns an alternative pitch period for `pitch_period` given a `multiplier`
@@ -332,10 +365,10 @@
     rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
     rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
     CandidatePitchPeriods pitch_candidates) {
-  // Compute the auto-correlation terms only for neighbors of the given pitch
-  // candidates (similar to what is done in ComputePitchAutoCorrelation(), but
-  // for a few lag values).
-  std::array<float, kInitialNumLags24kHz> auto_correlation{};
+  // Compute the auto-correlation terms only for neighbors of the two pitch
+  // candidates (best and second best).
+  std::array<float, kInitialNumLags24kHz> auto_correlation;
+  InvertedLagsIndex inverted_lags_index;
   // Create two inverted lag ranges so that `r1` precedes `r2`.
   const bool swap_candidates =
       pitch_candidates.best > pitch_candidates.second_best;
@@ -351,18 +384,17 @@
   RTC_DCHECK_LE(r1.max, r2.max);
   if (r1.max + 1 >= r2.min) {
     // Overlapping or adjacent ranges.
-    ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation);
+    ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
+                           inverted_lags_index);
   } else {
     // Disjoint ranges.
-    ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation);
-    ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation);
+    ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
+                           inverted_lags_index);
+    ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
+                           inverted_lags_index);
   }
-  // Find best pitch at 24 kHz.
-  const int pitch_candidate_24kHz =
-      ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
-  // Pseudo-interpolation.
-  return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
-                                                auto_correlation);
+  return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
+                                 auto_correlation, y_energy);
 }
 
 PitchInfo ComputeExtendedPitchPeriod48kHz(
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index fc715c6..152d569 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -128,9 +128,9 @@
 TEST_P(ExtendedPitchPeriodSearchParametrizaion,
        PeriodBitExactnessGainWithinTolerance) {
   PitchTestData test_data;
-  std::vector<float> y_energy(kMaxPitch24kHz + 1);
-  rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
-                                                          kMaxPitch24kHz + 1);
+  std::vector<float> y_energy(kRefineNumLags24kHz);
+  rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
+                                                           kRefineNumLags24kHz);
   ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
                                          y_energy_view);
   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.