Finalized the SSE2 optimizations for the matched filter in AEC3

The SSE2 optimizations of the filter core in the matched
filter was only half-done. This CL finalizes those.

In particular:
-It adds finalization of updating of the filter.
-It removes the manual loop unrolling in order to reduce and
simplify the code.

Note that the changes pass the bitexactness tests in an
external AEC3 test suite, and the test
MatchedFilter.TestOptimizations succeed.

BUG=webrtc:6018

Review-Url: https://codereview.webrtc.org/2813563003
Cr-Commit-Position: refs/heads/master@{#17655}
diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc
index 5da902d..7bb5778 100644
--- a/webrtc/modules/audio_processing/aec3/matched_filter.cc
+++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc
@@ -31,50 +31,56 @@
                             rtc::ArrayView<float> h,
                             bool* filters_updated,
                             float* error_sum) {
+  const int h_size = static_cast<int>(h.size());
+  const int x_size = static_cast<int>(x.size());
+  RTC_DCHECK_EQ(0, h_size % 4);
+
   // Process for all samples in the sub-block.
   for (size_t i = 0; i < kSubBlockSize; ++i) {
-    // Apply the matched filter as filter * x.  and compute x * x.
-    float x2_sum = 0.f;
-    float s = 0;
-    size_t x_index = x_start_index;
-    RTC_DCHECK_EQ(0, h.size() % 4);
+    // Apply the matched filter as filter * x, and compute x * x.
 
+    RTC_DCHECK_GT(x_size, x_start_index);
+    const float* x_p = &x[x_start_index];
+    const float* h_p = &h[0];
+
+    // Initialize values for the accumulation.
     __m128 s_128 = _mm_set1_ps(0);
     __m128 x2_sum_128 = _mm_set1_ps(0);
+    float x2_sum = 0.f;
+    float s = 0;
 
-    size_t k = 0;
-    if (h.size() > (x.size() - x_index)) {
-      const size_t limit = x.size() - x_index;
-      for (; (k + 3) < limit; k += 4, x_index += 4) {
-        const __m128 x_k = _mm_loadu_ps(&x[x_index]);
-        const __m128 h_k = _mm_loadu_ps(&h[k]);
+    // Compute loop chunk sizes until, and after, the wraparound of the circular
+    // buffer for x.
+    const int chunk1 =
+        std::min(h_size, static_cast<int>(x_size - x_start_index));
+
+    // Perform the loop in two chunks.
+    const int chunk2 = h_size - chunk1;
+    for (int limit : {chunk1, chunk2}) {
+      // Perform 128 bit vector operations.
+      const int limit_by_4 = limit >> 2;
+      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+        // Load the data into 128 bit vectors.
+        const __m128 x_k = _mm_loadu_ps(x_p);
+        const __m128 h_k = _mm_loadu_ps(h_p);
         const __m128 xx = _mm_mul_ps(x_k, x_k);
+        // Compute and accumulate x * x and h * x.
         x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
         const __m128 hx = _mm_mul_ps(h_k, x_k);
         s_128 = _mm_add_ps(s_128, hx);
       }
 
-      for (; k < limit; ++k, ++x_index) {
-        x2_sum += x[x_index] * x[x_index];
-        s += h[k] * x[x_index];
+      // Perform non-vector operations for any remaining items.
+      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
+        const float x_k = *x_p;
+        x2_sum += x_k * x_k;
+        s += *h_p * x_k;
       }
-      x_index = 0;
+
+      x_p = &x[0];
     }
 
-    for (; k + 3 < h.size(); k += 4, x_index += 4) {
-      const __m128 x_k = _mm_loadu_ps(&x[x_index]);
-      const __m128 h_k = _mm_loadu_ps(&h[k]);
-      const __m128 xx = _mm_mul_ps(x_k, x_k);
-      x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
-      const __m128 hx = _mm_mul_ps(h_k, x_k);
-      s_128 = _mm_add_ps(s_128, hx);
-    }
-
-    for (; k < h.size(); ++k, ++x_index) {
-      x2_sum += x[x_index] * x[x_index];
-      s += h[k] * x[x_index];
-    }
-
+    // Combine the accumulated vector and scalar values.
     float* v = reinterpret_cast<float*>(&x2_sum_128);
     x2_sum += v[0] + v[1] + v[2] + v[3];
     v = reinterpret_cast<float*>(&s_128);
@@ -82,23 +88,47 @@
 
     // Compute the matched filter error.
     const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
-    (*error_sum) += e * e;
+    *error_sum += e * e;
 
     // Update the matched filter estimate in an NLMS manner.
     if (x2_sum > x2_sum_threshold) {
       RTC_DCHECK_LT(0.f, x2_sum);
       const float alpha = 0.7f * e / x2_sum;
+      const __m128 alpha_128 = _mm_set1_ps(alpha);
 
       // filter = filter + 0.7 * (y - filter * x) / x * x.
-      size_t x_index = x_start_index;
-      for (size_t k = 0; k < h.size(); ++k) {
-        h[k] += alpha * x[x_index];
-        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+      float* h_p = &h[0];
+      x_p = &x[x_start_index];
+
+      // Perform the loop in two chunks.
+      for (int limit : {chunk1, chunk2}) {
+        // Perform 128 bit vector operations.
+        const int limit_by_4 = limit >> 2;
+        for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+          // Load the data into 128 bit vectors.
+          __m128 h_k = _mm_loadu_ps(h_p);
+          const __m128 x_k = _mm_loadu_ps(x_p);
+
+          // Compute h = h + alpha * x.
+          const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
+          h_k = _mm_add_ps(h_k, alpha_x);
+
+          // Store the result.
+          _mm_storeu_ps(h_p, h_k);
+        }
+
+        // Perform non-vector operations for any remaining items.
+        for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
+          *h_p += alpha * *x_p;
+        }
+
+        x_p = &x[0];
       }
+
       *filters_updated = true;
     }
 
-    x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
+    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
   }
 }
 #endif
@@ -112,7 +142,7 @@
                        float* error_sum) {
   // Process for all samples in the sub-block.
   for (size_t i = 0; i < kSubBlockSize; ++i) {
-    // Apply the matched filter as filter * x.  and compute x * x.
+    // Apply the matched filter as filter * x, and compute x * x.
     float x2_sum = 0.f;
     float s = 0;
     size_t x_index = x_start_index;
diff --git a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc
index 75b6b57..45965c7 100644
--- a/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc
+++ b/webrtc/modules/audio_processing/aec3/matched_filter_unittest.cc
@@ -74,7 +74,7 @@
       EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
 
       for (size_t j = 0; j < h.size(); ++j) {
-        EXPECT_NEAR(h[j], h_SSE2[j], 0.001f);
+        EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
       }
 
       x_index = (x_index + kSubBlockSize) % x.size();