Optimize MatchedFilter.

Changing to an index for-loop (instead of using std::max_element & std::distance) tracking even & odd elements separately allows the compiler to produce code with less pipeline stall.

Bug: None
Change-Id: Iaa3e820a3a3b61e2eb276f0dac9106c848db1891
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/240061
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Commit-Queue: Christian Schuldt <cschuldt@google.com>
Cr-Commit-Position: refs/heads/main@{#35729}
diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc
index 794381c..faca933 100644
--- a/modules/audio_processing/aec3/matched_filter.cc
+++ b/modules/audio_processing/aec3/matched_filter.cc
@@ -308,6 +308,41 @@
   }
 }
 
+size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {
+  if (h.size() < 2) {
+    return 0;
+  }
+  float max_element1 = h[0] * h[0];
+  float max_element2 = h[1] * h[1];
+  size_t lag_estimate1 = 0;
+  size_t lag_estimate2 = 1;
+  const size_t last_index = h.size() - 1;
+  // Keeping track of even & odd max elements separately typically allows the
+  // compiler to produce more efficient code.
+  for (size_t k = 2; k < last_index; k += 2) {
+    float element1 = h[k] * h[k];
+    float element2 = h[k + 1] * h[k + 1];
+    if (element1 > max_element1) {
+      max_element1 = element1;
+      lag_estimate1 = k;
+    }
+    if (element2 > max_element2) {
+      max_element2 = element2;
+      lag_estimate2 = k + 1;
+    }
+  }
+  if (max_element2 > max_element1) {
+    max_element1 = max_element2;
+    lag_estimate1 = lag_estimate2;
+  }
+  // In case of odd h size, we have not yet checked the last element.
+  float last_element = h[last_index] * h[last_index];
+  if (last_element > max_element1) {
+    return last_index;
+  }
+  return lag_estimate1;
+}
+
 }  // namespace aec3
 
 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
@@ -400,17 +435,15 @@
     }
 
     // Compute anchor for the matched filter error.
-    const float error_sum_anchor =
-        std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
+    float error_sum_anchor = 0.0f;
+    for (size_t k = 0; k < y.size(); ++k) {
+      error_sum_anchor += y[k] * y[k];
+    }
 
     // Estimate the lag in the matched filter as the distance to the portion in
     // the filter that contributes the most to the matched filter output. This
     // is detected as the peak of the matched filter.
-    const size_t lag_estimate = std::distance(
-        filters_[n].begin(),
-        std::max_element(
-            filters_[n].begin(), filters_[n].end(),
-            [](float a, float b) -> bool { return a * a < b * b; }));
+    const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
 
     // Update the lag estimates for the matched filter.
     lag_estimates_[n] = LagEstimate(
diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h
index c6410ab..dd4a678 100644
--- a/modules/audio_processing/aec3/matched_filter.h
+++ b/modules/audio_processing/aec3/matched_filter.h
@@ -74,6 +74,9 @@
                        bool* filters_updated,
                        float* error_sum);
 
+// Find largest peak of squared values in array.
+size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h);
+
 }  // namespace aec3
 
 // Produces recursively updated cross-correlation estimates for several signal
diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc
index 37b51fa..8abfb69 100644
--- a/modules/audio_processing/aec3/matched_filter_unittest.cc
+++ b/modules/audio_processing/aec3/matched_filter_unittest.cc
@@ -176,6 +176,28 @@
 
 #endif
 
+// Verifies that the (optimized) function MaxSquarePeakIndex() produces output
+// equal to the corresponding std-functions.
+TEST(MatchedFilter, MaxSquarePeakIndex) {
+  Random random_generator(42U);
+  constexpr int kMaxLength = 128;
+  constexpr int kNumIterationsPerLength = 256;
+  for (int length = 1; length < kMaxLength; ++length) {
+    std::vector<float> y(length);
+    for (int i = 0; i < kNumIterationsPerLength; ++i) {
+      RandomizeSampleVector(&random_generator, y);
+
+      size_t lag_from_function = MaxSquarePeakIndex(y);
+      size_t lag_from_std = std::distance(
+          y.begin(),
+          std::max_element(y.begin(), y.end(), [](float a, float b) -> bool {
+            return a * a < b * b;
+          }));
+      EXPECT_EQ(lag_from_function, lag_from_std);
+    }
+  }
+}
+
 // Verifies that the matched filter produces proper lag estimates for
 // artificially
 // delayed signals.