AEC3: Add support in the echo subtractor for handling multiple channels

This CL adds support in the echo subtractor for handling multiple
capture and render channels.

The changes have passed bitexactness tests for a substantial set
of mono recordings.

Bug: webrtc:10913
Change-Id: Ib448c9edf172ebc31e8c28db7b2f2a389a53adb9
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/155168
Commit-Queue: Per Åhgren <peah@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29389}
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc
index 00fa884..6a0f531 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc
@@ -19,6 +19,8 @@
 #if defined(WEBRTC_ARCH_X86_FAMILY)
 #include <emmintrin.h>
 #endif
+#include <math.h>
+
 #include <algorithm>
 #include <functional>
 
@@ -30,207 +32,255 @@
 namespace aec3 {
 
 // Computes and stores the frequency response of the filter.
-void UpdateFrequencyResponse(
-    rtc::ArrayView<const FftData> H,
+void ComputeFrequencyResponse(
+    size_t num_partitions,
+    const std::vector<std::vector<FftData>>& H,
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
-  RTC_DCHECK_EQ(H.size(), H2->size());
-  for (size_t k = 0; k < H.size(); ++k) {
-    std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(),
-                   (*H2)[k].begin(),
-                   [](float a, float b) { return a * a + b * b; });
+  for (auto& H2_ch : *H2) {
+    H2_ch.fill(0.f);
+  }
+
+  const size_t num_render_channels = H[0].size();
+  RTC_DCHECK_EQ(H.size(), H2->capacity());
+  for (size_t p = 0; p < num_partitions; ++p) {
+    RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) {
+        float tmp =
+            H[p][ch].re[j] * H[p][ch].re[j] + H[p][ch].im[j] * H[p][ch].im[j];
+        (*H2)[p][j] = std::max((*H2)[p][j], tmp);
+      }
+    }
   }
 }
 
 #if defined(WEBRTC_HAS_NEON)
 // Computes and stores the frequency response of the filter.
-void UpdateFrequencyResponse_NEON(
-    rtc::ArrayView<const FftData> H,
+void ComputeFrequencyResponse_Neon(
+    size_t num_partitions,
+    const std::vector<std::vector<FftData>>& H,
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
-  RTC_DCHECK_EQ(H.size(), H2->size());
-  for (size_t k = 0; k < H.size(); ++k) {
-    for (size_t j = 0; j < kFftLengthBy2; j += 4) {
-      const float32x4_t re = vld1q_f32(&H[k].re[j]);
-      const float32x4_t im = vld1q_f32(&H[k].im[j]);
-      float32x4_t H2_k_j = vmulq_f32(re, re);
-      H2_k_j = vmlaq_f32(H2_k_j, im, im);
-      vst1q_f32(&(*H2)[k][j], H2_k_j);
+  for (auto& H2_ch : *H2) {
+    H2_ch.fill(0.f);
+  }
+
+  const size_t num_render_channels = H[0].size();
+  RTC_DCHECK_EQ(H.size(), H2->capacity());
+  for (size_t p = 0; p < num_partitions; ++p) {
+    RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      for (size_t j = 0; j < kFftLengthBy2; j += 4) {
+        const float32x4_t re = vld1q_f32(&H[p][ch].re[j]);
+        const float32x4_t im = vld1q_f32(&H[p][ch].im[j]);
+        float32x4_t H2_new = vmulq_f32(re, re);
+        H2_new = vmlaq_f32(H2_new, im, im);
+        float32x4_t H2_p_j = vld1q_f32(&(*H2)[p][j]);
+        H2_p_j = vmaxq_f32(H2_p_j, H2_new);
+        vst1q_f32(&(*H2)[p][j], H2_p_j);
+      }
+      float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
+                     H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
+      (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
     }
-    (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] +
-                              H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2];
   }
 }
 #endif
 
 #if defined(WEBRTC_ARCH_X86_FAMILY)
 // Computes and stores the frequency response of the filter.
-void UpdateFrequencyResponse_SSE2(
-    rtc::ArrayView<const FftData> H,
+void ComputeFrequencyResponse_Sse2(
+    size_t num_partitions,
+    const std::vector<std::vector<FftData>>& H,
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
-  RTC_DCHECK_EQ(H.size(), H2->size());
-  for (size_t k = 0; k < H.size(); ++k) {
-    for (size_t j = 0; j < kFftLengthBy2; j += 4) {
-      const __m128 re = _mm_loadu_ps(&H[k].re[j]);
-      const __m128 re2 = _mm_mul_ps(re, re);
-      const __m128 im = _mm_loadu_ps(&H[k].im[j]);
-      const __m128 im2 = _mm_mul_ps(im, im);
-      const __m128 H2_k_j = _mm_add_ps(re2, im2);
-      _mm_storeu_ps(&(*H2)[k][j], H2_k_j);
+  for (auto& H2_ch : *H2) {
+    H2_ch.fill(0.f);
+  }
+
+  const size_t num_render_channels = H[0].size();
+  RTC_DCHECK_EQ(H.size(), H2->capacity());
+  // constexpr __mmmask8 kMaxMask = static_cast<__mmmask8>(256u);
+  for (size_t p = 0; p < num_partitions; ++p) {
+    RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size());
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      for (size_t j = 0; j < kFftLengthBy2; j += 4) {
+        const __m128 re = _mm_loadu_ps(&H[p][ch].re[j]);
+        const __m128 re2 = _mm_mul_ps(re, re);
+        const __m128 im = _mm_loadu_ps(&H[p][ch].im[j]);
+        const __m128 im2 = _mm_mul_ps(im, im);
+        const __m128 H2_new = _mm_add_ps(re2, im2);
+        __m128 H2_k_j = _mm_loadu_ps(&(*H2)[p][j]);
+        H2_k_j = _mm_max_ps(H2_k_j, H2_new);
+        _mm_storeu_ps(&(*H2)[p][j], H2_k_j);
+      }
+      float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] +
+                     H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2];
+      (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new);
     }
-    (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] +
-                              H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2];
   }
 }
 #endif
 
-
 // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
 void AdaptPartitions(const RenderBuffer& render_buffer,
                      const FftData& G,
-                     rtc::ArrayView<FftData> H) {
+                     size_t num_partitions,
+                     std::vector<std::vector<FftData>>* H) {
   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
       render_buffer.GetFftBuffer();
   size_t index = render_buffer.Position();
-  for (auto& H_j : H) {
-    const FftData& X = render_buffer_data[index][/*channel=*/0];
-    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
-      H_j.re[k] += X.re[k] * G.re[k] + X.im[k] * G.im[k];
-      H_j.im[k] += X.re[k] * G.im[k] - X.im[k] * G.re[k];
+  const size_t num_render_channels = render_buffer_data[index].size();
+  for (size_t p = 0; p < num_partitions; ++p) {
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      const FftData& X_p_ch = render_buffer_data[index][ch];
+      FftData& H_p_ch = (*H)[p][ch];
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        H_p_ch.re[k] += X_p_ch.re[k] * G.re[k] + X_p_ch.im[k] * G.im[k];
+        H_p_ch.im[k] += X_p_ch.re[k] * G.im[k] - X_p_ch.im[k] * G.re[k];
+      }
     }
-
     index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
   }
 }
 
 #if defined(WEBRTC_HAS_NEON)
-// Adapts the filter partitions. (NEON variant)
-void AdaptPartitions_NEON(const RenderBuffer& render_buffer,
+// Adapts the filter partitions. (Neon variant)
+void AdaptPartitions_Neon(const RenderBuffer& render_buffer,
                           const FftData& G,
-                          rtc::ArrayView<FftData> H) {
+                          size_t num_partitions,
+                          std::vector<std::vector<FftData>>* H) {
   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
       render_buffer.GetFftBuffer();
-  const int lim1 =
-      std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
-  const int lim2 = H.size();
-  constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
-  FftData* H_j = &H[0];
-  const std::vector<FftData>* X_channels =
-      &render_buffer_data[render_buffer.Position()];
-  int limit = lim1;
-  int j = 0;
-  do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
-        const float32x4_t G_re = vld1q_f32(&G.re[k]);
-        const float32x4_t G_im = vld1q_f32(&G.im[k]);
-        const float32x4_t X_re = vld1q_f32(&X.re[k]);
-        const float32x4_t X_im = vld1q_f32(&X.im[k]);
-        const float32x4_t H_re = vld1q_f32(&H_j->re[k]);
-        const float32x4_t H_im = vld1q_f32(&H_j->im[k]);
-        const float32x4_t a = vmulq_f32(X_re, G_re);
-        const float32x4_t e = vmlaq_f32(a, X_im, G_im);
-        const float32x4_t c = vmulq_f32(X_re, G_im);
-        const float32x4_t f = vmlsq_f32(c, X_im, G_re);
-        const float32x4_t g = vaddq_f32(H_re, e);
-        const float32x4_t h = vaddq_f32(H_im, f);
+  const size_t num_render_channels = render_buffer_data[0].size();
+  const size_t lim1 = std::min(
+      render_buffer_data.size() - render_buffer.Position(), num_partitions);
+  const size_t lim2 = num_partitions;
+  constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
 
-        vst1q_f32(&H_j->re[k], g);
-        vst1q_f32(&H_j->im[k], h);
+  size_t X_partition = render_buffer.Position();
+  size_t limit = lim1;
+  size_t p = 0;
+  do {
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        FftData& H_p_ch = (*H)[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+        for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
+          const float32x4_t G_re = vld1q_f32(&G.re[k]);
+          const float32x4_t G_im = vld1q_f32(&G.im[k]);
+          const float32x4_t X_re = vld1q_f32(&X.re[k]);
+          const float32x4_t X_im = vld1q_f32(&X.im[k]);
+          const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]);
+          const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]);
+          const float32x4_t a = vmulq_f32(X_re, G_re);
+          const float32x4_t e = vmlaq_f32(a, X_im, G_im);
+          const float32x4_t c = vmulq_f32(X_re, G_im);
+          const float32x4_t f = vmlsq_f32(c, X_im, G_re);
+          const float32x4_t g = vaddq_f32(H_re, e);
+          const float32x4_t h = vaddq_f32(H_im, f);
+          vst1q_f32(&H_p_ch.re[k], g);
+          vst1q_f32(&H_p_ch.im[k], h);
+        }
       }
     }
 
-    X_channels = &render_buffer_data[0];
+    X_partition = 0;
     limit = lim2;
-  } while (j < lim2);
+  } while (p < lim2);
 
-  H_j = &H[0];
-  X_channels = &render_buffer_data[render_buffer.Position()];
+  X_partition = render_buffer.Position();
   limit = lim1;
-  j = 0;
+  p = 0;
   do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      H_j->re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
-                                X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
-      H_j->im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
-                                X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
-    }
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        FftData& H_p_ch = (*H)[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
 
-    X_channels = &render_buffer_data[0];
+        H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
+                                    X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
+        H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
+                                    X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
+      }
+    }
+    X_partition = 0;
     limit = lim2;
-  } while (j < lim2);
+  } while (p < lim2);
 }
 #endif
 
 #if defined(WEBRTC_ARCH_X86_FAMILY)
 // Adapts the filter partitions. (SSE2 variant)
-void AdaptPartitions_SSE2(const RenderBuffer& render_buffer,
+void AdaptPartitions_Sse2(const RenderBuffer& render_buffer,
                           const FftData& G,
-                          rtc::ArrayView<FftData> H) {
+                          size_t num_partitions,
+                          std::vector<std::vector<FftData>>* H) {
   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
       render_buffer.GetFftBuffer();
-  const int lim1 =
-      std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
-  const int lim2 = H.size();
-  constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
-  FftData* H_j;
-  const std::vector<FftData>* X_channels;
-  int limit;
-  int j;
-  for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
-    const __m128 G_re = _mm_loadu_ps(&G.re[k]);
-    const __m128 G_im = _mm_loadu_ps(&G.im[k]);
+  const size_t num_render_channels = render_buffer_data[0].size();
+  const size_t lim1 = std::min(
+      render_buffer_data.size() - render_buffer.Position(), num_partitions);
+  const size_t lim2 = num_partitions;
+  constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
 
-    H_j = &H[0];
-    X_channels = &render_buffer_data[render_buffer.Position()];
-    limit = lim1;
-    j = 0;
-    do {
-      for (; j < limit; ++j, ++H_j, ++X_channels) {
-        const FftData& X = (*X_channels)[/*channel=*/0];
-        const __m128 X_re = _mm_loadu_ps(&X.re[k]);
-        const __m128 X_im = _mm_loadu_ps(&X.im[k]);
-        const __m128 H_re = _mm_loadu_ps(&H_j->re[k]);
-        const __m128 H_im = _mm_loadu_ps(&H_j->im[k]);
-        const __m128 a = _mm_mul_ps(X_re, G_re);
-        const __m128 b = _mm_mul_ps(X_im, G_im);
-        const __m128 c = _mm_mul_ps(X_re, G_im);
-        const __m128 d = _mm_mul_ps(X_im, G_re);
-        const __m128 e = _mm_add_ps(a, b);
-        const __m128 f = _mm_sub_ps(c, d);
-        const __m128 g = _mm_add_ps(H_re, e);
-        const __m128 h = _mm_add_ps(H_im, f);
-        _mm_storeu_ps(&H_j->re[k], g);
-        _mm_storeu_ps(&H_j->im[k], h);
-      }
-
-      X_channels = &render_buffer_data[0];
-      limit = lim2;
-    } while (j < lim2);
-  }
-
-  H_j = &H[0];
-  X_channels = &render_buffer_data[render_buffer.Position()];
-  limit = lim1;
-  j = 0;
+  size_t X_partition = render_buffer.Position();
+  size_t limit = lim1;
+  size_t p = 0;
   do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      H_j->re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
-                                X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
-      H_j->im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
-                                X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        FftData& H_p_ch = (*H)[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+
+        for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
+          const __m128 G_re = _mm_loadu_ps(&G.re[k]);
+          const __m128 G_im = _mm_loadu_ps(&G.im[k]);
+          const __m128 X_re = _mm_loadu_ps(&X.re[k]);
+          const __m128 X_im = _mm_loadu_ps(&X.im[k]);
+          const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]);
+          const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]);
+          const __m128 a = _mm_mul_ps(X_re, G_re);
+          const __m128 b = _mm_mul_ps(X_im, G_im);
+          const __m128 c = _mm_mul_ps(X_re, G_im);
+          const __m128 d = _mm_mul_ps(X_im, G_re);
+          const __m128 e = _mm_add_ps(a, b);
+          const __m128 f = _mm_sub_ps(c, d);
+          const __m128 g = _mm_add_ps(H_re, e);
+          const __m128 h = _mm_add_ps(H_im, f);
+          _mm_storeu_ps(&H_p_ch.re[k], g);
+          _mm_storeu_ps(&H_p_ch.im[k], h);
+        }
+      }
+    }
+    X_partition = 0;
+    limit = lim2;
+  } while (p < lim2);
+
+  X_partition = render_buffer.Position();
+  limit = lim1;
+  p = 0;
+  do {
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        FftData& H_p_ch = (*H)[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+
+        H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] +
+                                    X.im[kFftLengthBy2] * G.im[kFftLengthBy2];
+        H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] -
+                                    X.im[kFftLengthBy2] * G.re[kFftLengthBy2];
+      }
     }
 
-    X_channels = &render_buffer_data[0];
+    X_partition = 0;
     limit = lim2;
-  } while (j < lim2);
+  } while (p < lim2);
 }
 #endif
 
 // Produces the filter output.
 void ApplyFilter(const RenderBuffer& render_buffer,
-                 rtc::ArrayView<const FftData> H,
+                 size_t num_partitions,
+                 const std::vector<std::vector<FftData>>& H,
                  FftData* S) {
   S->re.fill(0.f);
   S->im.fill(0.f);
@@ -238,184 +288,219 @@
   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
       render_buffer.GetFftBuffer();
   size_t index = render_buffer.Position();
-  for (auto& H_j : H) {
-    const FftData& X = render_buffer_data[index][0];
-    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
-      S->re[k] += X.re[k] * H_j.re[k] - X.im[k] * H_j.im[k];
-      S->im[k] += X.re[k] * H_j.im[k] + X.im[k] * H_j.re[k];
+  const size_t num_render_channels = render_buffer_data[index].size();
+  for (size_t p = 0; p < num_partitions; ++p) {
+    RTC_DCHECK_EQ(num_render_channels, H[p].size());
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      const FftData& X_p_ch = render_buffer_data[index][ch];
+      const FftData& H_p_ch = H[p][ch];
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        S->re[k] += X_p_ch.re[k] * H_p_ch.re[k] - X_p_ch.im[k] * H_p_ch.im[k];
+        S->im[k] += X_p_ch.re[k] * H_p_ch.im[k] + X_p_ch.im[k] * H_p_ch.re[k];
+      }
     }
     index = index < (render_buffer_data.size() - 1) ? index + 1 : 0;
   }
 }
 
 #if defined(WEBRTC_HAS_NEON)
-// Produces the filter output (NEON variant).
-void ApplyFilter_NEON(const RenderBuffer& render_buffer,
-                      rtc::ArrayView<const FftData> H,
+// Produces the filter output (Neon variant).
+void ApplyFilter_Neon(const RenderBuffer& render_buffer,
+                      size_t num_partitions,
+                      const std::vector<std::vector<FftData>>& H,
                       FftData* S) {
+  // const RenderBuffer& render_buffer,
+  //                     rtc::ArrayView<const FftData> H,
+  //                     FftData* S) {
   RTC_DCHECK_GE(H.size(), H.size() - 1);
   S->Clear();
 
   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
       render_buffer.GetFftBuffer();
-  const int lim1 =
-      std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
-  const int lim2 = H.size();
-  constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
-  const FftData* H_j = &H[0];
-  const std::vector<FftData>* X_channels =
-      &render_buffer_data[render_buffer.Position()];
+  const size_t num_render_channels = render_buffer_data[0].size();
+  const size_t lim1 = std::min(
+      render_buffer_data.size() - render_buffer.Position(), num_partitions);
+  const size_t lim2 = num_partitions;
+  constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
 
-  int j = 0;
-  int limit = lim1;
+  size_t X_partition = render_buffer.Position();
+  size_t p = 0;
+  size_t limit = lim1;
   do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
-        const float32x4_t X_re = vld1q_f32(&X.re[k]);
-        const float32x4_t X_im = vld1q_f32(&X.im[k]);
-        const float32x4_t H_re = vld1q_f32(&H_j->re[k]);
-        const float32x4_t H_im = vld1q_f32(&H_j->im[k]);
-        const float32x4_t S_re = vld1q_f32(&S->re[k]);
-        const float32x4_t S_im = vld1q_f32(&S->im[k]);
-        const float32x4_t a = vmulq_f32(X_re, H_re);
-        const float32x4_t e = vmlsq_f32(a, X_im, H_im);
-        const float32x4_t c = vmulq_f32(X_re, H_im);
-        const float32x4_t f = vmlaq_f32(c, X_im, H_re);
-        const float32x4_t g = vaddq_f32(S_re, e);
-        const float32x4_t h = vaddq_f32(S_im, f);
-        vst1q_f32(&S->re[k], g);
-        vst1q_f32(&S->im[k], h);
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        const FftData& H_p_ch = H[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+        for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
+          const float32x4_t X_re = vld1q_f32(&X.re[k]);
+          const float32x4_t X_im = vld1q_f32(&X.im[k]);
+          const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]);
+          const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]);
+          const float32x4_t S_re = vld1q_f32(&S->re[k]);
+          const float32x4_t S_im = vld1q_f32(&S->im[k]);
+          const float32x4_t a = vmulq_f32(X_re, H_re);
+          const float32x4_t e = vmlsq_f32(a, X_im, H_im);
+          const float32x4_t c = vmulq_f32(X_re, H_im);
+          const float32x4_t f = vmlaq_f32(c, X_im, H_re);
+          const float32x4_t g = vaddq_f32(S_re, e);
+          const float32x4_t h = vaddq_f32(S_im, f);
+          vst1q_f32(&S->re[k], g);
+          vst1q_f32(&S->im[k], h);
+        }
       }
     }
     limit = lim2;
-    X_channels = &render_buffer_data[0];
-  } while (j < lim2);
+    X_partition = 0;
+  } while (p < lim2);
 
-  H_j = &H[0];
-  X_channels = &render_buffer_data[render_buffer.Position()];
-  j = 0;
+  X_partition = render_buffer.Position();
+  p = 0;
   limit = lim1;
   do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->re[kFftLengthBy2] -
-                              X.im[kFftLengthBy2] * H_j->im[kFftLengthBy2];
-      S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->im[kFftLengthBy2] +
-                              X.im[kFftLengthBy2] * H_j->re[kFftLengthBy2];
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        const FftData& H_p_ch = H[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+        S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
+                                X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
+        S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
+                                X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
+      }
     }
     limit = lim2;
-    X_channels = &render_buffer_data[0];
-  } while (j < lim2);
+    X_partition = 0;
+  } while (p < lim2);
 }
 #endif
 
 #if defined(WEBRTC_ARCH_X86_FAMILY)
 // Produces the filter output (SSE2 variant).
-void ApplyFilter_SSE2(const RenderBuffer& render_buffer,
-                      rtc::ArrayView<const FftData> H,
+void ApplyFilter_Sse2(const RenderBuffer& render_buffer,
+                      size_t num_partitions,
+                      const std::vector<std::vector<FftData>>& H,
                       FftData* S) {
+  // const RenderBuffer& render_buffer,
+  //                     rtc::ArrayView<const FftData> H,
+  //                     FftData* S) {
   RTC_DCHECK_GE(H.size(), H.size() - 1);
   S->re.fill(0.f);
   S->im.fill(0.f);
 
   rtc::ArrayView<const std::vector<FftData>> render_buffer_data =
       render_buffer.GetFftBuffer();
-  const int lim1 =
-      std::min(render_buffer_data.size() - render_buffer.Position(), H.size());
-  const int lim2 = H.size();
-  constexpr int kNumFourBinBands = kFftLengthBy2 / 4;
-  const FftData* H_j = &H[0];
-  const std::vector<FftData>* X_channels =
-      &render_buffer_data[render_buffer.Position()];
+  const size_t num_render_channels = render_buffer_data[0].size();
+  const size_t lim1 = std::min(
+      render_buffer_data.size() - render_buffer.Position(), num_partitions);
+  const size_t lim2 = num_partitions;
+  constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4;
 
-  int j = 0;
-  int limit = lim1;
+  size_t X_partition = render_buffer.Position();
+  size_t p = 0;
+  size_t limit = lim1;
   do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
-        const __m128 X_re = _mm_loadu_ps(&X.re[k]);
-        const __m128 X_im = _mm_loadu_ps(&X.im[k]);
-        const __m128 H_re = _mm_loadu_ps(&H_j->re[k]);
-        const __m128 H_im = _mm_loadu_ps(&H_j->im[k]);
-        const __m128 S_re = _mm_loadu_ps(&S->re[k]);
-        const __m128 S_im = _mm_loadu_ps(&S->im[k]);
-        const __m128 a = _mm_mul_ps(X_re, H_re);
-        const __m128 b = _mm_mul_ps(X_im, H_im);
-        const __m128 c = _mm_mul_ps(X_re, H_im);
-        const __m128 d = _mm_mul_ps(X_im, H_re);
-        const __m128 e = _mm_sub_ps(a, b);
-        const __m128 f = _mm_add_ps(c, d);
-        const __m128 g = _mm_add_ps(S_re, e);
-        const __m128 h = _mm_add_ps(S_im, f);
-        _mm_storeu_ps(&S->re[k], g);
-        _mm_storeu_ps(&S->im[k], h);
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        const FftData& H_p_ch = H[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+        for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) {
+          const __m128 X_re = _mm_loadu_ps(&X.re[k]);
+          const __m128 X_im = _mm_loadu_ps(&X.im[k]);
+          const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]);
+          const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]);
+          const __m128 S_re = _mm_loadu_ps(&S->re[k]);
+          const __m128 S_im = _mm_loadu_ps(&S->im[k]);
+          const __m128 a = _mm_mul_ps(X_re, H_re);
+          const __m128 b = _mm_mul_ps(X_im, H_im);
+          const __m128 c = _mm_mul_ps(X_re, H_im);
+          const __m128 d = _mm_mul_ps(X_im, H_re);
+          const __m128 e = _mm_sub_ps(a, b);
+          const __m128 f = _mm_add_ps(c, d);
+          const __m128 g = _mm_add_ps(S_re, e);
+          const __m128 h = _mm_add_ps(S_im, f);
+          _mm_storeu_ps(&S->re[k], g);
+          _mm_storeu_ps(&S->im[k], h);
+        }
       }
     }
     limit = lim2;
-    X_channels = &render_buffer_data[0];
-  } while (j < lim2);
+    X_partition = 0;
+  } while (p < lim2);
 
-  H_j = &H[0];
-  X_channels = &render_buffer_data[render_buffer.Position()];
-  j = 0;
+  X_partition = render_buffer.Position();
+  p = 0;
   limit = lim1;
   do {
-    for (; j < limit; ++j, ++H_j, ++X_channels) {
-      const FftData& X = (*X_channels)[/*channel=*/0];
-      S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->re[kFftLengthBy2] -
-                              X.im[kFftLengthBy2] * H_j->im[kFftLengthBy2];
-      S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->im[kFftLengthBy2] +
-                              X.im[kFftLengthBy2] * H_j->re[kFftLengthBy2];
+    for (; p < limit; ++p, ++X_partition) {
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        const FftData& H_p_ch = H[p][ch];
+        const FftData& X = render_buffer_data[X_partition][ch];
+        S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] -
+                                X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2];
+        S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] +
+                                X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2];
+      }
     }
     limit = lim2;
-    X_channels = &render_buffer_data[0];
-  } while (j < lim2);
+    X_partition = 0;
+  } while (p < lim2);
 }
 #endif
 
 }  // namespace aec3
 
+namespace {
+
+// Ensures that the newly added filter partitions after a size increase are set
+// to zero.
+void ZeroFilter(size_t old_size,
+                size_t new_size,
+                std::vector<std::vector<FftData>>* H) {
+  RTC_DCHECK_GE(H->size(), old_size);
+  RTC_DCHECK_GE(H->size(), new_size);
+
+  for (size_t p = old_size; p < new_size; ++p) {
+    RTC_DCHECK_EQ((*H)[p].size(), (*H)[0].size());
+    for (size_t ch = 0; ch < (*H)[0].size(); ++ch) {
+      (*H)[p][ch].Clear();
+    }
+  }
+}
+
+}  // namespace
+
 AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions,
                                      size_t initial_size_partitions,
                                      size_t size_change_duration_blocks,
                                      size_t num_render_channels,
-                                     size_t num_capture_channels,
                                      Aec3Optimization optimization,
                                      ApmDataDumper* data_dumper)
     : data_dumper_(data_dumper),
       fft_(),
       optimization_(optimization),
+      num_render_channels_(num_render_channels),
       max_size_partitions_(max_size_partitions),
       size_change_duration_blocks_(
           static_cast<int>(size_change_duration_blocks)),
       current_size_partitions_(initial_size_partitions),
       target_size_partitions_(initial_size_partitions),
       old_target_size_partitions_(initial_size_partitions),
-      H_(max_size_partitions_) {
+      H_(max_size_partitions_, std::vector<FftData>(num_render_channels_)) {
   RTC_DCHECK(data_dumper_);
   RTC_DCHECK_GE(max_size_partitions, initial_size_partitions);
 
   RTC_DCHECK_LT(0, size_change_duration_blocks_);
   one_by_size_change_duration_blocks_ = 1.f / size_change_duration_blocks_;
 
-  for (auto& H_j : H_) {
-    H_j.Clear();
-  }
+  ZeroFilter(0, max_size_partitions_, &H_);
+
   SetSizePartitions(current_size_partitions_, true);
 }
 
 AdaptiveFirFilter::~AdaptiveFirFilter() = default;
 
 void AdaptiveFirFilter::HandleEchoPathChange() {
-  size_t current_size_partitions = H_.size();
-  H_.resize(max_size_partitions_);
-
-  for (size_t k = current_size_partitions; k < max_size_partitions_; ++k) {
-    H_[k].Clear();
-  }
-  H_.resize(current_size_partitions);
+  // TODO(peah): Check the value and purpose of the code below.
+  ZeroFilter(current_size_partitions_, max_size_partitions_, &H_);
 }
 
 void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) {
@@ -424,24 +509,22 @@
 
   target_size_partitions_ = std::min(max_size_partitions_, size);
   if (immediate_effect) {
+    size_t old_size_partitions_ = current_size_partitions_;
     current_size_partitions_ = old_target_size_partitions_ =
         target_size_partitions_;
-    ResetFilterBuffersToCurrentSize();
+    ZeroFilter(old_size_partitions_, current_size_partitions_, &H_);
+
+    partition_to_constrain_ =
+        std::min(partition_to_constrain_, current_size_partitions_ - 1);
     size_change_counter_ = 0;
   } else {
     size_change_counter_ = size_change_duration_blocks_;
   }
 }
 
-void AdaptiveFirFilter::ResetFilterBuffersToCurrentSize() {
-  H_.resize(current_size_partitions_);
-  RTC_DCHECK_LT(0, current_size_partitions_);
-  partition_to_constrain_ =
-      std::min(partition_to_constrain_, current_size_partitions_ - 1);
-}
-
 void AdaptiveFirFilter::UpdateSize() {
   RTC_DCHECK_GE(size_change_duration_blocks_, size_change_counter_);
+  const size_t old_size_partitions = current_size_partitions_;
   if (size_change_counter_ > 0) {
     --size_change_counter_;
 
@@ -455,11 +538,13 @@
     current_size_partitions_ = average(old_target_size_partitions_,
                                        target_size_partitions_, change_factor);
 
-    ResetFilterBuffersToCurrentSize();
+    partition_to_constrain_ =
+        std::min(partition_to_constrain_, current_size_partitions_ - 1);
   } else {
     current_size_partitions_ = old_target_size_partitions_ =
         target_size_partitions_;
   }
+  ZeroFilter(old_size_partitions, current_size_partitions_, &H_);
   RTC_DCHECK_LE(0, size_change_counter_);
 }
 
@@ -469,16 +554,16 @@
   switch (optimization_) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
     case Aec3Optimization::kSse2:
-      aec3::ApplyFilter_SSE2(render_buffer, H_, S);
+      aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S);
       break;
 #endif
 #if defined(WEBRTC_HAS_NEON)
     case Aec3Optimization::kNeon:
-      aec3::ApplyFilter_NEON(render_buffer, H_, S);
+      aec3::ApplyFilter_Neon(render_buffer, current_size_partitions_, H_, S);
       break;
 #endif
     default:
-      aec3::ApplyFilter(render_buffer, H_, S);
+      aec3::ApplyFilter(render_buffer, current_size_partitions_, H_, S);
   }
 }
 
@@ -503,28 +588,23 @@
 
 void AdaptiveFirFilter::ComputeFrequencyResponse(
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const {
-  RTC_DCHECK_EQ(max_size_partitions_, H2->capacity());
+  RTC_DCHECK_GE(max_size_partitions_, H2->capacity());
 
-  if (H2->size() > H_.size()) {
-    for (size_t k = H_.size(); k < H2->size(); ++k) {
-      (*H2)[k].fill(0.f);
-    }
-  }
-  H2->resize(H_.size());
+  H2->resize(current_size_partitions_);
 
   switch (optimization_) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
     case Aec3Optimization::kSse2:
-      aec3::UpdateFrequencyResponse_SSE2(H_, H2);
+      aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2);
       break;
 #endif
 #if defined(WEBRTC_HAS_NEON)
     case Aec3Optimization::kNeon:
-      aec3::UpdateFrequencyResponse_NEON(H_, H2);
+      aec3::ComputeFrequencyResponse_Neon(current_size_partitions_, H_, H2);
       break;
 #endif
     default:
-      aec3::UpdateFrequencyResponse(H_, H2);
+      aec3::ComputeFrequencyResponse(current_size_partitions_, H_, H2);
   }
 }
 
@@ -537,16 +617,18 @@
   switch (optimization_) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
     case Aec3Optimization::kSse2:
-      aec3::AdaptPartitions_SSE2(render_buffer, G, H_);
+      aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_,
+                                 &H_);
       break;
 #endif
 #if defined(WEBRTC_HAS_NEON)
     case Aec3Optimization::kNeon:
-      aec3::AdaptPartitions_NEON(render_buffer, G, H_);
+      aec3::AdaptPartitions_Neon(render_buffer, G, current_size_partitions_,
+                                 &H_);
       break;
 #endif
     default:
-      aec3::AdaptPartitions(render_buffer, G, H_);
+      aec3::AdaptPartitions(render_buffer, G, current_size_partitions_, &H_);
   }
 }
 
@@ -557,62 +639,91 @@
     std::vector<float>* impulse_response) {
   RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_),
                 impulse_response->capacity());
-
   impulse_response->resize(GetTimeDomainLength(current_size_partitions_));
   std::array<float, kFftLength> h;
-  fft_.Ifft(H_[partition_to_constrain_], &h);
+  // Zero the impulse response segment of the partition being constrained.
+  std::fill(
+      impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2,
+      impulse_response->begin() + (partition_to_constrain_ + 1) * kFftLengthBy2,
+      0.f);
 
-  static constexpr float kScale = 1.0f / kFftLengthBy2;
-  std::for_each(h.begin(), h.begin() + kFftLengthBy2,
-                [](float& a) { a *= kScale; });
-  std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
+  for (size_t ch = 0; ch < num_render_channels_; ++ch) {
+    fft_.Ifft(H_[partition_to_constrain_][ch], &h);
 
-  std::copy(
-      h.begin(), h.begin() + kFftLengthBy2,
-      impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2);
+    static constexpr float kScale = 1.0f / kFftLengthBy2;
+    std::for_each(h.begin(), h.begin() + kFftLengthBy2,
+                  [](float& a) { a *= kScale; });
+    std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
 
-  fft_.Fft(&h, &H_[partition_to_constrain_]);
+    if (ch == 0) {
+      std::copy(
+          h.begin(), h.begin() + kFftLengthBy2,
+          impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2);
+    } else {
+      for (size_t k = 0, j = partition_to_constrain_ * kFftLengthBy2;
+           k < kFftLengthBy2; ++k, ++j) {
+        if (fabsf((*impulse_response)[j]) < fabsf(h[k])) {
+          (*impulse_response)[j] = h[k];
+        }
+      }
+    }
 
-  partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1)
-                                ? partition_to_constrain_ + 1
-                                : 0;
+    fft_.Fft(&h, &H_[partition_to_constrain_][ch]);
+  }
+
+  partition_to_constrain_ =
+      partition_to_constrain_ < (current_size_partitions_ - 1)
+          ? partition_to_constrain_ + 1
+          : 0;
 }
 
 // Constrains the a partiton of the frequency domain filter to be limited in
 // time via setting the relevant time-domain coefficients to zero.
 void AdaptiveFirFilter::Constrain() {
   std::array<float, kFftLength> h;
-  fft_.Ifft(H_[partition_to_constrain_], &h);
+  for (size_t ch = 0; ch < num_render_channels_; ++ch) {
+    fft_.Ifft(H_[partition_to_constrain_][ch], &h);
 
-  static constexpr float kScale = 1.0f / kFftLengthBy2;
-  std::for_each(h.begin(), h.begin() + kFftLengthBy2,
-                [](float& a) { a *= kScale; });
-  std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
+    static constexpr float kScale = 1.0f / kFftLengthBy2;
+    std::for_each(h.begin(), h.begin() + kFftLengthBy2,
+                  [](float& a) { a *= kScale; });
+    std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
 
-  fft_.Fft(&h, &H_[partition_to_constrain_]);
+    fft_.Fft(&h, &H_[partition_to_constrain_][ch]);
+  }
 
-  partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1)
-                                ? partition_to_constrain_ + 1
-                                : 0;
+  partition_to_constrain_ =
+      partition_to_constrain_ < (current_size_partitions_ - 1)
+          ? partition_to_constrain_ + 1
+          : 0;
 }
 
 void AdaptiveFirFilter::ScaleFilter(float factor) {
-  for (auto& H : H_) {
-    for (auto& re : H.re) {
-      re *= factor;
-    }
-    for (auto& im : H.im) {
-      im *= factor;
+  for (auto& H_p : H_) {
+    for (auto& H_p_ch : H_p) {
+      for (auto& re : H_p_ch.re) {
+        re *= factor;
+      }
+      for (auto& im : H_p_ch.im) {
+        im *= factor;
+      }
     }
   }
 }
 
 // Set the filter coefficients.
-void AdaptiveFirFilter::SetFilter(const std::vector<FftData>& H) {
-  const size_t num_partitions = std::min(H_.size(), H.size());
-  for (size_t k = 0; k < num_partitions; ++k) {
-    std::copy(H[k].re.begin(), H[k].re.end(), H_[k].re.begin());
-    std::copy(H[k].im.begin(), H[k].im.end(), H_[k].im.begin());
+void AdaptiveFirFilter::SetFilter(size_t num_partitions,
+                                  const std::vector<std::vector<FftData>>& H) {
+  const size_t min_num_partitions =
+      std::min(current_size_partitions_, num_partitions);
+  for (size_t p = 0; p < min_num_partitions; ++p) {
+    RTC_DCHECK_EQ(H_[p].size(), H[p].size());
+    RTC_DCHECK_EQ(num_render_channels_, H_[p].size());
+
+    for (size_t ch = 0; ch < num_render_channels_; ++ch) {
+      std::copy(H[p][ch].re.begin(), H[p][ch].re.end(), H_[p][ch].re.begin());
+      std::copy(H[p][ch].im.begin(), H[p][ch].im.end(), H_[p][ch].im.begin());
+    }
   }
 }
 
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h
index aec83aa..2f64853 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter.h
+++ b/modules/audio_processing/aec3/adaptive_fir_filter.h
@@ -27,47 +27,56 @@
 namespace webrtc {
 namespace aec3 {
 // Computes and stores the frequency response of the filter.
-void UpdateFrequencyResponse(
-    rtc::ArrayView<const FftData> H,
+void ComputeFrequencyResponse(
+    size_t num_partitions,
+    const std::vector<std::vector<FftData>>& H,
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
 #if defined(WEBRTC_HAS_NEON)
-void UpdateFrequencyResponse_NEON(
-    rtc::ArrayView<const FftData> H,
+void ComputeFrequencyResponse_Neon(
+    size_t num_partitions,
+    const std::vector<std::vector<FftData>>& H,
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
 #endif
 #if defined(WEBRTC_ARCH_X86_FAMILY)
-void UpdateFrequencyResponse_SSE2(
-    rtc::ArrayView<const FftData> H,
+void ComputeFrequencyResponse_Sse2(
+    size_t num_partitions,
+    const std::vector<std::vector<FftData>>& H,
     std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
 #endif
 
 // Adapts the filter partitions.
 void AdaptPartitions(const RenderBuffer& render_buffer,
                      const FftData& G,
-                     rtc::ArrayView<FftData> H);
+                     size_t num_partitions,
+                     std::vector<std::vector<FftData>>* H);
 #if defined(WEBRTC_HAS_NEON)
-void AdaptPartitions_NEON(const RenderBuffer& render_buffer,
+void AdaptPartitions_Neon(const RenderBuffer& render_buffer,
                           const FftData& G,
-                          rtc::ArrayView<FftData> H);
+                          size_t num_partitions,
+                          std::vector<std::vector<FftData>>* H);
 #endif
 #if defined(WEBRTC_ARCH_X86_FAMILY)
-void AdaptPartitions_SSE2(const RenderBuffer& render_buffer,
+void AdaptPartitions_Sse2(const RenderBuffer& render_buffer,
                           const FftData& G,
-                          rtc::ArrayView<FftData> H);
+                          size_t num_partitions,
+                          std::vector<std::vector<FftData>>* H);
 #endif
 
 // Produces the filter output.
 void ApplyFilter(const RenderBuffer& render_buffer,
-                 rtc::ArrayView<const FftData> H,
+                 size_t num_partitions,
+                 const std::vector<std::vector<FftData>>& H,
                  FftData* S);
 #if defined(WEBRTC_HAS_NEON)
-void ApplyFilter_NEON(const RenderBuffer& render_buffer,
-                      rtc::ArrayView<const FftData> H,
+void ApplyFilter_Neon(const RenderBuffer& render_buffer,
+                      size_t num_partitions,
+                      const std::vector<std::vector<FftData>>& H,
                       FftData* S);
 #endif
 #if defined(WEBRTC_ARCH_X86_FAMILY)
-void ApplyFilter_SSE2(const RenderBuffer& render_buffer,
-                      rtc::ArrayView<const FftData> H,
+void ApplyFilter_Sse2(const RenderBuffer& render_buffer,
+                      size_t num_partitions,
+                      const std::vector<std::vector<FftData>>& H,
                       FftData* S);
 #endif
 
@@ -80,7 +89,6 @@
                     size_t initial_size_partitions,
                     size_t size_change_duration_blocks,
                     size_t num_render_channels,
-                    size_t num_capture_channels,
                     Aec3Optimization optimization,
                     ApmDataDumper* data_dumper);
 
@@ -106,7 +114,7 @@
   void HandleEchoPathChange();
 
   // Returns the filter size.
-  size_t SizePartitions() const { return H_.size(); }
+  size_t SizePartitions() const { return current_size_partitions_; }
 
   // Sets the filter size.
   void SetSizePartitions(size_t size, bool immediate_effect);
@@ -119,23 +127,21 @@
   size_t max_filter_size_partitions() const { return max_size_partitions_; }
 
   void DumpFilter(const char* name_frequency_domain) {
-    size_t current_size = H_.size();
-    H_.resize(max_size_partitions_);
-    for (auto& H : H_) {
-      data_dumper_->DumpRaw(name_frequency_domain, H.re);
-      data_dumper_->DumpRaw(name_frequency_domain, H.im);
+    for (size_t p = 0; p < max_size_partitions_; ++p) {
+      data_dumper_->DumpRaw(name_frequency_domain, H_[p][0].re);
+      data_dumper_->DumpRaw(name_frequency_domain, H_[p][0].im);
     }
-    H_.resize(current_size);
   }
 
   // Scale the filter impulse response and spectrum by a factor.
   void ScaleFilter(float factor);
 
   // Set the filter coefficients.
-  void SetFilter(const std::vector<FftData>& H);
+  void SetFilter(size_t num_partitions,
+                 const std::vector<std::vector<FftData>>& H);
 
   // Gets the filter coefficients.
-  const std::vector<FftData>& GetFilter() const { return H_; }
+  const std::vector<std::vector<FftData>>& GetFilter() const { return H_; }
 
  private:
   // Adapts the filter and updates the filter size.
@@ -147,15 +153,13 @@
   // values in the supplied impulse response.
   void ConstrainAndUpdateImpulseResponse(std::vector<float>* impulse_response);
 
-  // Resets the filter buffers to use the current size.
-  void ResetFilterBuffersToCurrentSize();
-
   // Gradually Updates the current filter size towards the target size.
   void UpdateSize();
 
   ApmDataDumper* const data_dumper_;
   const Aec3Fft fft_;
   const Aec3Optimization optimization_;
+  const size_t num_render_channels_;
   const size_t max_size_partitions_;
   const int size_change_duration_blocks_;
   float one_by_size_change_duration_blocks_;
@@ -163,7 +167,7 @@
   size_t target_size_partitions_;
   size_t old_target_size_partitions_;
   int size_change_counter_ = 0;
-  std::vector<FftData> H_;
+  std::vector<std::vector<FftData>> H_;
   size_t partition_to_constrain_ = 0;
 };
 
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
index 36e31eb..6f1635f 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
@@ -42,9 +42,10 @@
 namespace aec3 {
 namespace {
 
-std::string ProduceDebugText(size_t delay) {
+std::string ProduceDebugText(size_t num_render_channels, size_t delay) {
   rtc::StringBuilder ss;
-  ss << ", Delay: " << delay;
+  ss << "delay: " << delay << ", ";
+  ss << "num_render_channels:" << num_render_channels;
   return ss.Release();
 }
 
@@ -54,163 +55,184 @@
 // Verifies that the optimized methods for filter adaptation are similar to
 // their reference counterparts.
 TEST(AdaptiveFirFilter, FilterAdaptationNeonOptimizations) {
-  constexpr size_t kNumRenderChannels = 1;
-  constexpr int kSampleRateHz = 48000;
-  constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
+  for (size_t num_partitions : {2, 5, 12, 30, 50}) {
+    for (size_t num_render_channels : {1, 2, 4, 8}) {
+      constexpr int kSampleRateHz = 48000;
+      constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
-  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
-                                kNumRenderChannels));
-  Random random_generator(42U);
-  std::vector<std::vector<std::vector<float>>> x(
-      kNumBands, std::vector<std::vector<float>>(
-                     kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
-  FftData S_C;
-  FftData S_NEON;
-  FftData G;
-  Aec3Fft fft;
-  std::vector<FftData> H_C(10);
-  std::vector<FftData> H_NEON(10);
-  for (auto& H_j : H_C) {
-    H_j.Clear();
-  }
-  for (auto& H_j : H_NEON) {
-    H_j.Clear();
-  }
+      std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
+          RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
+                                    num_render_channels));
+      Random random_generator(42U);
+      std::vector<std::vector<std::vector<float>>> x(
+          kNumBands,
+          std::vector<std::vector<float>>(num_render_channels,
+                                          std::vector<float>(kBlockSize, 0.f)));
+      FftData S_C;
+      FftData S_Neon;
+      FftData G;
+      Aec3Fft fft;
+      std::vector<std::vector<FftData>> H_C(
+          num_partitions, std::vector<FftData>(num_render_channels));
+      std::vector<std::vector<FftData>> H_Neon(
+          num_partitions, std::vector<FftData>(num_render_channels));
+      for (size_t p = 0; p < num_partitions; ++p) {
+        for (size_t ch = 0; ch < num_render_channels; ++ch) {
+          H_C[p][ch].Clear();
+          H_Neon[p][ch].Clear();
+        }
+      }
 
-  for (size_t k = 0; k < 30; ++k) {
-    for (size_t band = 0; band < x.size(); ++band) {
-      for (size_t channel = 0; channel < x[band].size(); ++channel) {
-        RandomizeSampleVector(&random_generator, x[band][channel]);
+      for (size_t k = 0; k < 30; ++k) {
+        for (size_t band = 0; band < x.size(); ++band) {
+          for (size_t ch = 0; ch < x[band].size(); ++ch) {
+            RandomizeSampleVector(&random_generator, x[band][ch]);
+          }
+        }
+        render_delay_buffer->Insert(x);
+        if (k == 0) {
+          render_delay_buffer->Reset();
+        }
+        render_delay_buffer->PrepareCaptureProcessing();
+      }
+      auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
+
+      for (size_t j = 0; j < G.re.size(); ++j) {
+        G.re[j] = j / 10001.f;
+      }
+      for (size_t j = 1; j < G.im.size() - 1; ++j) {
+        G.im[j] = j / 20001.f;
+      }
+      G.im[0] = 0.f;
+      G.im[G.im.size() - 1] = 0.f;
+
+      AdaptPartitions_Neon(*render_buffer, G, num_partitions, &H_Neon);
+      AdaptPartitions(*render_buffer, G, num_partitions, &H_C);
+      AdaptPartitions_Neon(*render_buffer, G, num_partitions, &H_Neon);
+      AdaptPartitions(*render_buffer, G, num_partitions, &H_C);
+
+      for (size_t p = 0; p < num_partitions; ++p) {
+        for (size_t ch = 0; ch < num_render_channels; ++ch) {
+          for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) {
+            EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Neon[p][ch].re[j]);
+            EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Neon[p][ch].im[j]);
+          }
+        }
+      }
+
+      ApplyFilter_Neon(*render_buffer, num_partitions, H_Neon, &S_Neon);
+      ApplyFilter(*render_buffer, num_partitions, H_C, &S_C);
+      for (size_t j = 0; j < S_C.re.size(); ++j) {
+        EXPECT_NEAR(S_C.re[j], S_Neon.re[j], fabs(S_C.re[j] * 0.00001f));
+        EXPECT_NEAR(S_C.im[j], S_Neon.im[j], fabs(S_C.im[j] * 0.00001f));
+      }
       }
     }
-    render_delay_buffer->Insert(x);
-    if (k == 0) {
-      render_delay_buffer->Reset();
-    }
-    render_delay_buffer->PrepareCaptureProcessing();
-  }
-  auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
-
-  for (size_t j = 0; j < G.re.size(); ++j) {
-    G.re[j] = j / 10001.f;
-  }
-  for (size_t j = 1; j < G.im.size() - 1; ++j) {
-    G.im[j] = j / 20001.f;
-  }
-  G.im[0] = 0.f;
-  G.im[G.im.size() - 1] = 0.f;
-
-  AdaptPartitions_NEON(*render_buffer, G, H_NEON);
-  AdaptPartitions(*render_buffer, G, H_C);
-  AdaptPartitions_NEON(*render_buffer, G, H_NEON);
-  AdaptPartitions(*render_buffer, G, H_C);
-
-  for (size_t l = 0; l < H_C.size(); ++l) {
-    for (size_t j = 0; j < H_C[l].im.size(); ++j) {
-      EXPECT_NEAR(H_C[l].re[j], H_NEON[l].re[j], fabs(H_C[l].re[j] * 0.00001f));
-      EXPECT_NEAR(H_C[l].im[j], H_NEON[l].im[j], fabs(H_C[l].im[j] * 0.00001f));
-    }
-  }
-
-  ApplyFilter_NEON(*render_buffer, H_NEON, &S_NEON);
-  ApplyFilter(*render_buffer, H_C, &S_C);
-  for (size_t j = 0; j < S_C.re.size(); ++j) {
-    EXPECT_NEAR(S_C.re[j], S_NEON.re[j], fabs(S_C.re[j] * 0.00001f));
-    EXPECT_NEAR(S_C.im[j], S_NEON.im[j], fabs(S_C.re[j] * 0.00001f));
   }
 }
 
 // Verifies that the optimized method for frequency response computation is
 // bitexact to the reference counterpart.
-TEST(AdaptiveFirFilter, UpdateFrequencyResponseNeonOptimization) {
-  const size_t kNumPartitions = 12;
-  std::vector<FftData> H(kNumPartitions);
-  std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
-  std::vector<std::array<float, kFftLengthBy2Plus1>> H2_NEON(kNumPartitions);
+TEST(AdaptiveFirFilter, ComputeFrequencyResponseNeonOptimization) {
+  for (size_t num_partitions : {2, 5, 12, 30, 50}) {
+    for (size_t num_render_channels : {1, 2, 4, 8}) {
+      std::vector<std::vector<FftData>> H(
+          num_partitions, std::vector<FftData>(num_render_channels));
+      std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions);
+      std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Neon(
+          num_partitions);
 
-  for (size_t j = 0; j < H.size(); ++j) {
-    for (size_t k = 0; k < H[j].re.size(); ++k) {
-      H[j].re[k] = k + j / 3.f;
-      H[j].im[k] = j + k / 7.f;
-    }
-  }
+      for (size_t p = 0; p < num_partitions; ++p) {
+        for (size_t ch = 0; ch < num_render_channels; ++ch) {
+          for (size_t k = 0; k < H[p][ch].re.size(); ++k) {
+            H[p][ch].re[k] = k + p / 3.f + ch;
+            H[p][ch].im[k] = p + k / 7.f - ch;
+          }
+        }
+      }
 
-  UpdateFrequencyResponse(H, &H2);
-  UpdateFrequencyResponse_NEON(H, &H2_NEON);
+      ComputeFrequencyResponse(num_partitions, H, &H2);
+      ComputeFrequencyResponse_Neon(num_partitions, H, &H2_Neon);
 
-  for (size_t j = 0; j < H2.size(); ++j) {
-    for (size_t k = 0; k < H[j].re.size(); ++k) {
-      EXPECT_FLOAT_EQ(H2[j][k], H2_NEON[j][k]);
+      for (size_t p = 0; p < num_partitions; ++p) {
+        for (size_t k = 0; k < H2[p].size(); ++k) {
+          EXPECT_FLOAT_EQ(H2[p][k], H2_Neon[p][k]);
+        }
+      }
     }
   }
 }
-
 #endif
 
 #if defined(WEBRTC_ARCH_X86_FAMILY)
 // Verifies that the optimized methods for filter adaptation are bitexact to
 // their reference counterparts.
 TEST(AdaptiveFirFilter, FilterAdaptationSse2Optimizations) {
-  constexpr size_t kNumRenderChannels = 1;
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
   bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
   if (use_sse2) {
-    std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-        RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
-                                  kNumRenderChannels));
-    Random random_generator(42U);
-    std::vector<std::vector<std::vector<float>>> x(
-        kNumBands,
-        std::vector<std::vector<float>>(kNumRenderChannels,
-                                        std::vector<float>(kBlockSize, 0.f)));
-    FftData S_C;
-    FftData S_SSE2;
-    FftData G;
-    Aec3Fft fft;
-    std::vector<FftData> H_C(10);
-    std::vector<FftData> H_SSE2(10);
-    for (auto& H_j : H_C) {
-      H_j.Clear();
-    }
-    for (auto& H_j : H_SSE2) {
-      H_j.Clear();
-    }
-
-    for (size_t k = 0; k < 500; ++k) {
-      for (size_t band = 0; band < x.size(); ++band) {
-        for (size_t channel = 0; channel < x[band].size(); ++channel) {
-          RandomizeSampleVector(&random_generator, x[band][channel]);
+    for (size_t num_partitions : {2, 5, 12, 30, 50}) {
+      for (size_t num_render_channels : {1, 2, 4, 8}) {
+        std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
+            RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
+                                      num_render_channels));
+        Random random_generator(42U);
+        std::vector<std::vector<std::vector<float>>> x(
+            kNumBands,
+            std::vector<std::vector<float>>(
+                num_render_channels, std::vector<float>(kBlockSize, 0.f)));
+        FftData S_C;
+        FftData S_Sse2;
+        FftData G;
+        Aec3Fft fft;
+        std::vector<std::vector<FftData>> H_C(
+            num_partitions, std::vector<FftData>(num_render_channels));
+        std::vector<std::vector<FftData>> H_Sse2(
+            num_partitions, std::vector<FftData>(num_render_channels));
+        for (size_t p = 0; p < num_partitions; ++p) {
+          for (size_t ch = 0; ch < num_render_channels; ++ch) {
+            H_C[p][ch].Clear();
+            H_Sse2[p][ch].Clear();
+          }
         }
-      }
-      render_delay_buffer->Insert(x);
-      if (k == 0) {
-        render_delay_buffer->Reset();
-      }
-      render_delay_buffer->PrepareCaptureProcessing();
-      auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
 
-      ApplyFilter_SSE2(*render_buffer, H_SSE2, &S_SSE2);
-      ApplyFilter(*render_buffer, H_C, &S_C);
-      for (size_t j = 0; j < S_C.re.size(); ++j) {
-        EXPECT_FLOAT_EQ(S_C.re[j], S_SSE2.re[j]);
-        EXPECT_FLOAT_EQ(S_C.im[j], S_SSE2.im[j]);
-      }
+        for (size_t k = 0; k < 500; ++k) {
+          for (size_t band = 0; band < x.size(); ++band) {
+            for (size_t ch = 0; ch < x[band].size(); ++ch) {
+              RandomizeSampleVector(&random_generator, x[band][ch]);
+            }
+          }
+          render_delay_buffer->Insert(x);
+          if (k == 0) {
+            render_delay_buffer->Reset();
+          }
+          render_delay_buffer->PrepareCaptureProcessing();
+          auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
 
-      std::for_each(G.re.begin(), G.re.end(),
-                    [&](float& a) { a = random_generator.Rand<float>(); });
-      std::for_each(G.im.begin(), G.im.end(),
-                    [&](float& a) { a = random_generator.Rand<float>(); });
+          ApplyFilter_Sse2(*render_buffer, num_partitions, H_Sse2, &S_Sse2);
+          ApplyFilter(*render_buffer, num_partitions, H_C, &S_C);
+          for (size_t j = 0; j < S_C.re.size(); ++j) {
+            EXPECT_FLOAT_EQ(S_C.re[j], S_Sse2.re[j]);
+            EXPECT_FLOAT_EQ(S_C.im[j], S_Sse2.im[j]);
+          }
 
-      AdaptPartitions_SSE2(*render_buffer, G, H_SSE2);
-      AdaptPartitions(*render_buffer, G, H_C);
+          std::for_each(G.re.begin(), G.re.end(),
+                        [&](float& a) { a = random_generator.Rand<float>(); });
+          std::for_each(G.im.begin(), G.im.end(),
+                        [&](float& a) { a = random_generator.Rand<float>(); });
 
-      for (size_t k = 0; k < H_C.size(); ++k) {
-        for (size_t j = 0; j < H_C[k].re.size(); ++j) {
-          EXPECT_FLOAT_EQ(H_C[k].re[j], H_SSE2[k].re[j]);
-          EXPECT_FLOAT_EQ(H_C[k].im[j], H_SSE2[k].im[j]);
+          AdaptPartitions_Sse2(*render_buffer, G, num_partitions, &H_Sse2);
+          AdaptPartitions(*render_buffer, G, num_partitions, &H_C);
+
+          for (size_t p = 0; p < num_partitions; ++p) {
+            for (size_t ch = 0; ch < num_render_channels; ++ch) {
+              for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) {
+                EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Sse2[p][ch].re[j]);
+                EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Sse2[p][ch].im[j]);
+              }
+            }
+          }
         }
       }
     }
@@ -219,27 +241,34 @@
 
 // Verifies that the optimized method for frequency response computation is
 // bitexact to the reference counterpart.
-TEST(AdaptiveFirFilter, UpdateFrequencyResponseSse2Optimization) {
+TEST(AdaptiveFirFilter, ComputeFrequencyResponseSse2Optimization) {
   bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
   if (use_sse2) {
-    const size_t kNumPartitions = 12;
-    std::vector<FftData> H(kNumPartitions);
-    std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
-    std::vector<std::array<float, kFftLengthBy2Plus1>> H2_SSE2(kNumPartitions);
+    for (size_t num_partitions : {2, 5, 12, 30, 50}) {
+      for (size_t num_render_channels : {1, 2, 4, 8}) {
+        std::vector<std::vector<FftData>> H(
+            num_partitions, std::vector<FftData>(num_render_channels));
+        std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions);
+        std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Sse2(
+            num_partitions);
 
-    for (size_t j = 0; j < H.size(); ++j) {
-      for (size_t k = 0; k < H[j].re.size(); ++k) {
-        H[j].re[k] = k + j / 3.f;
-        H[j].im[k] = j + k / 7.f;
-      }
-    }
+        for (size_t p = 0; p < num_partitions; ++p) {
+          for (size_t ch = 0; ch < num_render_channels; ++ch) {
+            for (size_t k = 0; k < H[p][ch].re.size(); ++k) {
+              H[p][ch].re[k] = k + p / 3.f + ch;
+              H[p][ch].im[k] = p + k / 7.f - ch;
+            }
+          }
+        }
 
-    UpdateFrequencyResponse(H, &H2);
-    UpdateFrequencyResponse_SSE2(H, &H2_SSE2);
+        ComputeFrequencyResponse(num_partitions, H, &H2);
+        ComputeFrequencyResponse_Sse2(num_partitions, H, &H2_Sse2);
 
-    for (size_t j = 0; j < H2.size(); ++j) {
-      for (size_t k = 0; k < H[j].re.size(); ++k) {
-        EXPECT_FLOAT_EQ(H2[j][k], H2_SSE2[j][k]);
+        for (size_t p = 0; p < num_partitions; ++p) {
+          for (size_t k = 0; k < H2[p].size(); ++k) {
+            EXPECT_FLOAT_EQ(H2[p][k], H2_Sse2[p][k]);
+          }
+        }
       }
     }
   }
@@ -250,14 +279,14 @@
 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
 // Verifies that the check for non-null data dumper works.
 TEST(AdaptiveFirFilter, NullDataDumper) {
-  EXPECT_DEATH(
-      AdaptiveFirFilter(9, 9, 250, 1, 1, DetectOptimization(), nullptr), "");
+  EXPECT_DEATH(AdaptiveFirFilter(9, 9, 250, 1, DetectOptimization(), nullptr),
+               "");
 }
 
 // Verifies that the check for non-null filter output works.
 TEST(AdaptiveFirFilter, NullFilterOutput) {
   ApmDataDumper data_dumper(42);
-  AdaptiveFirFilter filter(9, 9, 250, 1, 1, DetectOptimization(), &data_dumper);
+  AdaptiveFirFilter filter(9, 9, 250, 1, DetectOptimization(), &data_dumper);
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
       RenderDelayBuffer::Create(EchoCanceller3Config(), 48000, 1));
   EXPECT_DEATH(filter.Filter(*render_delay_buffer->GetRenderBuffer(), nullptr),
@@ -271,7 +300,7 @@
 TEST(AdaptiveFirFilter, FilterStatisticsAccess) {
   ApmDataDumper data_dumper(42);
   Aec3Optimization optimization = DetectOptimization();
-  AdaptiveFirFilter filter(9, 9, 250, 1, 1, optimization, &data_dumper);
+  AdaptiveFirFilter filter(9, 9, 250, 1, optimization, &data_dumper);
   std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
       filter.max_filter_size_partitions(),
       std::array<float, kFftLengthBy2Plus1>());
@@ -288,7 +317,7 @@
 TEST(AdaptiveFirFilter, FilterSize) {
   ApmDataDumper data_dumper(42);
   for (size_t filter_size = 1; filter_size < 5; ++filter_size) {
-    AdaptiveFirFilter filter(filter_size, filter_size, 250, 1, 1,
+    AdaptiveFirFilter filter(filter_size, filter_size, 250, 1,
                              DetectOptimization(), &data_dumper);
     EXPECT_EQ(filter_size, filter.SizePartitions());
   }
@@ -297,115 +326,146 @@
 // Verifies that the filter is being able to properly filter a signal and to
 // adapt its coefficients.
 TEST(AdaptiveFirFilter, FilterAndAdapt) {
-  constexpr size_t kNumRenderChannels = 1;
-  constexpr size_t kNumCaptureChannels = 1;
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
+  constexpr size_t kNumBlocksToProcessPerRenderChannel = 1000;
+  constexpr size_t kNumCaptureChannels = 1;
 
-  constexpr size_t kNumBlocksToProcess = 1000;
-  ApmDataDumper data_dumper(42);
-  EchoCanceller3Config config;
-  AdaptiveFirFilter filter(config.filter.main.length_blocks,
-                           config.filter.main.length_blocks,
-                           config.filter.config_change_duration_blocks, 1, 1,
-                           DetectOptimization(), &data_dumper);
-  std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
-      filter.max_filter_size_partitions(),
-      std::array<float, kFftLengthBy2Plus1>());
-  std::vector<float> h(GetTimeDomainLength(filter.max_filter_size_partitions()),
-                       0.f);
-  Aec3Fft fft;
-  config.delay.default_delay = 1;
-  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
-  ShadowFilterUpdateGain gain(config.filter.shadow,
-                              config.filter.config_change_duration_blocks);
-  Random random_generator(42U);
-  std::vector<std::vector<std::vector<float>>> x(
-      kNumBands, std::vector<std::vector<float>>(
-                     kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
-  std::vector<float> n(kBlockSize, 0.f);
-  std::vector<float> y(kBlockSize, 0.f);
-  AecState aec_state(EchoCanceller3Config{}, kNumCaptureChannels);
-  RenderSignalAnalyzer render_signal_analyzer(config);
-  absl::optional<DelayEstimate> delay_estimate;
-  std::vector<float> e(kBlockSize, 0.f);
-  std::array<float, kFftLength> s_scratch;
-  std::vector<SubtractorOutput> output(kNumCaptureChannels);
-  FftData S;
-  FftData G;
-  FftData E;
-  std::array<float, kFftLengthBy2Plus1> Y2;
-  std::array<float, kFftLengthBy2Plus1> E2_main;
-  std::array<float, kFftLengthBy2Plus1> E2_shadow;
-  // [B,A] = butter(2,100/8000,'high')
-  constexpr CascadedBiQuadFilter::BiQuadCoefficients
-      kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
-                                     {-1.94448f, 0.94598f}};
-  Y2.fill(0.f);
-  E2_main.fill(0.f);
-  E2_shadow.fill(0.f);
-  for (auto& subtractor_output : output) {
-    subtractor_output.Reset();
-  }
+  for (size_t num_render_channels : {1, 2, 3, 6, 8}) {
+    ApmDataDumper data_dumper(42);
+    EchoCanceller3Config config;
 
-  constexpr float kScale = 1.0f / kFftLengthBy2;
-
-  for (size_t delay_samples : {0, 64, 150, 200, 301}) {
-    DelayBuffer<float> delay_buffer(delay_samples);
-    CascadedBiQuadFilter x_hp_filter(kHighPassFilterCoefficients, 1);
-    CascadedBiQuadFilter y_hp_filter(kHighPassFilterCoefficients, 1);
-
-    SCOPED_TRACE(ProduceDebugText(delay_samples));
-    for (size_t j = 0; j < kNumBlocksToProcess; ++j) {
-      RandomizeSampleVector(&random_generator, x[0][0]);
-      delay_buffer.Delay(x[0][0], y);
-
-      RandomizeSampleVector(&random_generator, n);
-      static constexpr float kNoiseScaling = 1.f / 100.f;
-      std::transform(y.begin(), y.end(), n.begin(), y.begin(),
-                     [](float a, float b) { return a + b * kNoiseScaling; });
-
-      x_hp_filter.Process(x[0][0]);
-      y_hp_filter.Process(y);
-
-      render_delay_buffer->Insert(x);
-      if (j == 0) {
-        render_delay_buffer->Reset();
-      }
-      render_delay_buffer->PrepareCaptureProcessing();
-      auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
-
-      render_signal_analyzer.Update(*render_buffer,
-                                    aec_state.FilterDelayBlocks());
-
-      filter.Filter(*render_buffer, &S);
-      fft.Ifft(S, &s_scratch);
-      std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2,
-                     e.begin(),
-                     [&](float a, float b) { return a - b * kScale; });
-      std::for_each(e.begin(), e.end(),
-                    [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); });
-      fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E);
-      for (size_t k = 0; k < kBlockSize; ++k) {
-        output[0].s_main[k] = kScale * s_scratch[k + kFftLengthBy2];
-      }
-
-      std::array<float, kFftLengthBy2Plus1> render_power;
-      render_buffer->SpectralSum(filter.SizePartitions(), &render_power);
-      gain.Compute(render_power, render_signal_analyzer, E,
-                   filter.SizePartitions(), false, &G);
-      filter.Adapt(*render_buffer, G, &h);
-      aec_state.HandleEchoPathChange(EchoPathVariability(
-          false, EchoPathVariability::DelayAdjustment::kNone, false));
-
-      filter.ComputeFrequencyResponse(&H2);
-      aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2,
-                       output);
+    if (num_render_channels == 33) {
+      config.filter.main = {13, 0.00005f, 0.0005f, 0.0001f, 2.f, 20075344.f};
+      config.filter.shadow = {13, 0.1f, 20075344.f};
+      config.filter.main_initial = {12, 0.005f, 0.5f, 0.001f, 2.f, 20075344.f};
+      config.filter.shadow_initial = {12, 0.7f, 20075344.f};
     }
-    // Verify that the filter is able to perform well.
-    EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f),
-              std::inner_product(y.begin(), y.end(), y.begin(), 0.f));
+
+    AdaptiveFirFilter filter(
+        config.filter.main.length_blocks, config.filter.main.length_blocks,
+        config.filter.config_change_duration_blocks, num_render_channels,
+        DetectOptimization(), &data_dumper);
+    std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
+        filter.max_filter_size_partitions(),
+        std::array<float, kFftLengthBy2Plus1>());
+    std::vector<float> h(
+        GetTimeDomainLength(filter.max_filter_size_partitions()), 0.f);
+    Aec3Fft fft;
+    config.delay.default_delay = 1;
+    std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
+        RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
+    ShadowFilterUpdateGain gain(config.filter.shadow,
+                                config.filter.config_change_duration_blocks);
+    Random random_generator(42U);
+    std::vector<std::vector<std::vector<float>>> x(
+        kNumBands,
+        std::vector<std::vector<float>>(num_render_channels,
+                                        std::vector<float>(kBlockSize, 0.f)));
+    std::vector<float> n(kBlockSize, 0.f);
+    std::vector<float> y(kBlockSize, 0.f);
+    AecState aec_state(EchoCanceller3Config{}, kNumCaptureChannels);
+    RenderSignalAnalyzer render_signal_analyzer(config);
+    absl::optional<DelayEstimate> delay_estimate;
+    std::vector<float> e(kBlockSize, 0.f);
+    std::array<float, kFftLength> s_scratch;
+    std::vector<SubtractorOutput> output(kNumCaptureChannels);
+    FftData S;
+    FftData G;
+    FftData E;
+    std::array<float, kFftLengthBy2Plus1> Y2;
+    std::array<float, kFftLengthBy2Plus1> E2_main;
+    std::array<float, kFftLengthBy2Plus1> E2_shadow;
+    // [B,A] = butter(2,100/8000,'high')
+    constexpr CascadedBiQuadFilter::BiQuadCoefficients
+        kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
+                                       {-1.94448f, 0.94598f}};
+    Y2.fill(0.f);
+    E2_main.fill(0.f);
+    E2_shadow.fill(0.f);
+    for (auto& subtractor_output : output) {
+      subtractor_output.Reset();
+    }
+
+    constexpr float kScale = 1.0f / kFftLengthBy2;
+
+    for (size_t delay_samples : {0, 64, 150, 200, 301}) {
+      std::vector<DelayBuffer<float>> delay_buffer(
+          num_render_channels, DelayBuffer<float>(delay_samples));
+      std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
+          num_render_channels);
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        x_hp_filter[ch] = std::make_unique<CascadedBiQuadFilter>(
+            kHighPassFilterCoefficients, 1);
+      }
+      CascadedBiQuadFilter y_hp_filter(kHighPassFilterCoefficients, 1);
+
+      SCOPED_TRACE(ProduceDebugText(num_render_channels, delay_samples));
+      const size_t num_blocks_to_process =
+          kNumBlocksToProcessPerRenderChannel * num_render_channels;
+      for (size_t j = 0; j < num_blocks_to_process; ++j) {
+        std::fill(y.begin(), y.end(), 0.f);
+        for (size_t ch = 0; ch < num_render_channels; ++ch) {
+          RandomizeSampleVector(&random_generator, x[0][ch]);
+          std::array<float, kBlockSize> y_channel;
+          delay_buffer[ch].Delay(x[0][ch], y_channel);
+          for (size_t k = 0; k < y.size(); ++k) {
+            y[k] += y_channel[k] / num_render_channels;
+          }
+        }
+
+        RandomizeSampleVector(&random_generator, n);
+        const float noise_scaling = 1.f / 100.f / num_render_channels;
+        for (size_t k = 0; k < y.size(); ++k) {
+          y[k] += n[k] * noise_scaling;
+        }
+
+        for (size_t ch = 0; ch < num_render_channels; ++ch) {
+          x_hp_filter[ch]->Process(x[0][ch]);
+        }
+        y_hp_filter.Process(y);
+
+        render_delay_buffer->Insert(x);
+        if (j == 0) {
+          render_delay_buffer->Reset();
+        }
+        render_delay_buffer->PrepareCaptureProcessing();
+        auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
+
+        render_signal_analyzer.Update(*render_buffer,
+                                      aec_state.FilterDelayBlocks());
+
+        filter.Filter(*render_buffer, &S);
+        fft.Ifft(S, &s_scratch);
+        std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2,
+                       e.begin(),
+                       [&](float a, float b) { return a - b * kScale; });
+        std::for_each(e.begin(), e.end(), [](float& a) {
+          a = rtc::SafeClamp(a, -32768.f, 32767.f);
+        });
+        fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E);
+        for (auto& o : output) {
+          for (size_t k = 0; k < kBlockSize; ++k) {
+            o.s_main[k] = kScale * s_scratch[k + kFftLengthBy2];
+          }
+        }
+
+        std::array<float, kFftLengthBy2Plus1> render_power;
+        render_buffer->SpectralSum(filter.SizePartitions(), &render_power);
+        gain.Compute(render_power, render_signal_analyzer, E,
+                     filter.SizePartitions(), false, &G);
+        filter.Adapt(*render_buffer, G, &h);
+        aec_state.HandleEchoPathChange(EchoPathVariability(
+            false, EchoPathVariability::DelayAdjustment::kNone, false));
+
+        filter.ComputeFrequencyResponse(&H2);
+        aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2,
+                         output);
+      }
+      // Verify that the filter is able to perform well.
+      EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f),
+                std::inner_product(y.begin(), y.end(), y.begin(), 0.f));
+    }
   }
 }
 }  // namespace aec3
diff --git a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
index 94aa039..7abbb79 100644
--- a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
+++ b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
@@ -13,6 +13,7 @@
 #include <algorithm>
 #include <numeric>
 
+#include "modules/audio_processing/aec3/aec_state.h"
 #include "rtc_base/random.h"
 #include "rtc_base/system/arch.h"
 #include "system_wrappers/include/cpu_features_wrapper.h"
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index 2df9cfd..c33b39c 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -386,9 +386,9 @@
 
   // Update the AEC state information.
   // TODO(bugs.webrtc.org/10913): Take all subtractors into account.
-  aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(),
-                    subtractor_.FilterImpulseResponse(), *render_buffer, E2[0],
-                    Y2[0], subtractor_output);
+  aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse()[0],
+                    subtractor_.FilterImpulseResponse()[0], *render_buffer,
+                    E2[0], Y2[0], subtractor_output);
 
   // Choose the linear output.
   const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y;
diff --git a/modules/audio_processing/aec3/main_filter_update_gain.cc b/modules/audio_processing/aec3/main_filter_update_gain.cc
index c2cfd2c..43f37b0 100644
--- a/modules/audio_processing/aec3/main_filter_update_gain.cc
+++ b/modules/audio_processing/aec3/main_filter_update_gain.cc
@@ -80,7 +80,7 @@
   const auto& E2_main = subtractor_output.E2_main;
   const auto& E2_shadow = subtractor_output.E2_shadow;
   FftData* G = gain_fft;
-  auto X2 = render_power;
+  const auto& X2 = render_power;
 
   ++call_counter_;
 
@@ -100,43 +100,40 @@
     std::array<float, kFftLengthBy2Plus1> mu;
     // mu = H_error / (0.5* H_error* X2 + n * E2).
     for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
-      mu[k] = X2[k] > current_config_.noise_gate
-                  ? H_error_[k] / (0.5f * H_error_[k] * X2[k] +
-                                   size_partitions * E2_main[k])
-                  : 0.f;
+      if (X2[k] > current_config_.noise_gate) {  // Strictly above the gate.
+        mu[k] = H_error_[k] /
+                (0.5f * H_error_[k] * X2[k] + size_partitions * E2_main[k]);
+      } else {
+        mu[k] = 0.f;
+      }
     }
 
     // Avoid updating the filter close to narrow bands in the render signals.
     render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu);
 
     // H_error = H_error - 0.5 * mu * X2 * H_error.
-    for (size_t k = 0; k < H_error_.size(); ++k) {
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
       H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k];
     }
 
     // G = mu * E.
-    std::transform(mu.begin(), mu.end(), E_main.re.begin(), G->re.begin(),
-                   std::multiplies<float>());
-    std::transform(mu.begin(), mu.end(), E_main.im.begin(), G->im.begin(),
-                   std::multiplies<float>());
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+      G->re[k] = mu[k] * E_main.re[k];
+      G->im[k] = mu[k] * E_main.im[k];
+    }
   }
 
   // H_error = H_error + factor * erl.
-  std::array<float, kFftLengthBy2Plus1> H_error_increase;
-  std::transform(E2_shadow.begin(), E2_shadow.end(), E2_main.begin(),
-                 H_error_increase.begin(), [&](float a, float b) {
-                   return a >= b ? current_config_.leakage_converged
-                                 : current_config_.leakage_diverged;
-                 });
-  std::transform(erl.begin(), erl.end(), H_error_increase.begin(),
-                 H_error_increase.begin(), std::multiplies<float>());
-  std::transform(H_error_.begin(), H_error_.end(), H_error_increase.begin(),
-                 H_error_.begin(), [&](float a, float b) {
-                   float error = a + b;
-                   error = std::max(error, current_config_.error_floor);
-                   error = std::min(error, current_config_.error_ceil);
-                   return error;
-                 });
+  for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+    if (E2_shadow[k] >= E2_main[k]) {
+      H_error_[k] += current_config_.leakage_converged * erl[k];
+    } else {
+      H_error_[k] += current_config_.leakage_diverged * erl[k];
+    }
+
+    H_error_[k] = std::max(H_error_[k], current_config_.error_floor);
+    H_error_[k] = std::min(H_error_[k], current_config_.error_ceil);
+  }
 
   data_dumper_->DumpRaw("aec3_main_gain_H_error", H_error_);
 }
diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
index 20714ce..1a9e792 100644
--- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
+++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
@@ -54,11 +54,11 @@
   AdaptiveFirFilter main_filter(config.filter.main.length_blocks,
                                 config.filter.main.length_blocks,
                                 config.filter.config_change_duration_blocks, 1,
-                                1, optimization, &data_dumper);
+                                optimization, &data_dumper);
   AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks,
                                   config.filter.shadow.length_blocks,
                                   config.filter.config_change_duration_blocks,
-                                  1, 1, optimization, &data_dumper);
+                                  1, optimization, &data_dumper);
   std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
       main_filter.max_filter_size_partitions(),
       std::array<float, kFftLengthBy2Plus1>());
diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain.cc b/modules/audio_processing/aec3/shadow_filter_update_gain.cc
index e27437a..51ead2e 100644
--- a/modules/audio_processing/aec3/shadow_filter_update_gain.cc
+++ b/modules/audio_processing/aec3/shadow_filter_update_gain.cc
@@ -28,8 +28,6 @@
 }
 
 void ShadowFilterUpdateGain::HandleEchoPathChange() {
-  // TODO(peah): Check whether this counter should instead be initialized to a
-  // large value.
   poor_signal_excitation_counter_ = 0;
   call_counter_ = 0;
 }
@@ -60,19 +58,23 @@
 
   // Compute mu.
   std::array<float, kFftLengthBy2Plus1> mu;
-  auto X2 = render_power;
-  std::transform(X2.begin(), X2.end(), mu.begin(), [&](float a) {
-    return a > current_config_.noise_gate ? current_config_.rate / a : 0.f;
-  });
+  const auto& X2 = render_power;
+  for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+    if (X2[k] > current_config_.noise_gate) {
+      mu[k] = current_config_.rate / X2[k];
+    } else {
+      mu[k] = 0.f;
+    }
+  }
 
   // Avoid updating the filter close to narrow bands in the render signals.
   render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu);
 
   // G = mu * E * X2.
-  std::transform(mu.begin(), mu.end(), E_shadow.re.begin(), G->re.begin(),
-                 std::multiplies<float>());
-  std::transform(mu.begin(), mu.end(), E_shadow.im.begin(), G->im.begin(),
-                 std::multiplies<float>());
+  for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+    G->re[k] = mu[k] * E_shadow.re[k];
+    G->im[k] = mu[k] * E_shadow.im[k];
+  }
 }
 
 void ShadowFilterUpdateGain::UpdateCurrentConfig() {
diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc
index 605f570..a73a539 100644
--- a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc
+++ b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc
@@ -44,11 +44,11 @@
   AdaptiveFirFilter main_filter(config.filter.main.length_blocks,
                                 config.filter.main.length_blocks,
                                 config.filter.config_change_duration_blocks, 1,
-                                1, DetectOptimization(), &data_dumper);
+                                DetectOptimization(), &data_dumper);
   AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks,
                                   config.filter.shadow.length_blocks,
                                   config.filter.config_change_duration_blocks,
-                                  1, 1, DetectOptimization(), &data_dumper);
+                                  1, DetectOptimization(), &data_dumper);
   Aec3Fft fft;
 
   constexpr int kSampleRateHz = 48000;
diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc
index 0c52ed6..5e99565 100644
--- a/modules/audio_processing/aec3/subtractor.cc
+++ b/modules/audio_processing/aec3/subtractor.cc
@@ -89,13 +89,13 @@
         config_.filter.main.length_blocks,
         config_.filter.main_initial.length_blocks,
         config.filter.config_change_duration_blocks, num_render_channels,
-        num_capture_channels, optimization, data_dumper_);
+        optimization, data_dumper_);
 
     shadow_filter_[ch] = std::make_unique<AdaptiveFirFilter>(
         config_.filter.shadow.length_blocks,
         config_.filter.shadow_initial.length_blocks,
         config.filter.config_change_duration_blocks, num_render_channels,
-        num_capture_channels, optimization, data_dumper_);
+        optimization, data_dumper_);
     G_main_[ch] = std::make_unique<MainFilterUpdateGain>(
         config_.filter.main_initial,
         config_.filter.config_change_duration_blocks);
@@ -162,14 +162,12 @@
   RTC_DCHECK_EQ(num_capture_channels_, capture.size());
 
   // Compute the render powers.
+  const bool same_filter_sizes =
+      main_filter_[0]->SizePartitions() == shadow_filter_[0]->SizePartitions();
   std::array<float, kFftLengthBy2Plus1> X2_main;
   std::array<float, kFftLengthBy2Plus1> X2_shadow_data;
-  std::array<float, kFftLengthBy2Plus1>& X2_shadow =
-      main_filter_[0]->SizePartitions() == shadow_filter_[0]->SizePartitions()
-          ? X2_main
-          : X2_shadow_data;
-  if (main_filter_[0]->SizePartitions() ==
-      shadow_filter_[0]->SizePartitions()) {
+  auto& X2_shadow = same_filter_sizes ? X2_main : X2_shadow_data;
+  if (same_filter_sizes) {
     render_buffer.SpectralSum(main_filter_[0]->SizePartitions(), &X2_main);
   } else if (main_filter_[0]->SizePartitions() >
              shadow_filter_[0]->SizePartitions()) {
@@ -256,7 +254,8 @@
                              aec_state.SaturatedCapture(), &G);
     } else {
       poor_shadow_filter_counter_[ch] = 0;
-      shadow_filter_[ch]->SetFilter(main_filter_[ch]->GetFilter());
+      shadow_filter_[ch]->SetFilter(main_filter_[ch]->SizePartitions(),
+                                    main_filter_[ch]->GetFilter());
       G_shadow_[ch]->Compute(X2_shadow, render_signal_analyzer, E_main,
                              shadow_filter_[ch]->SizePartitions(),
                              aec_state.SaturatedCapture(), &G);
diff --git a/modules/audio_processing/aec3/subtractor.h b/modules/audio_processing/aec3/subtractor.h
index c5fb765..01d2eef 100644
--- a/modules/audio_processing/aec3/subtractor.h
+++ b/modules/audio_processing/aec3/subtractor.h
@@ -59,26 +59,24 @@
   void ExitInitialState();
 
   // Returns the block-wise frequency responses for the main adaptive filters.
-  // TODO(bugs.webrtc.org/10913): Return the frequency responses for all capture
-  // channels.
-  const std::vector<std::array<float, kFftLengthBy2Plus1>>&
+  const std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>&
   FilterFrequencyResponse() const {
-    return main_frequency_response_[0];
+    return main_frequency_response_;
   }
 
   // Returns the estimates of the impulse responses for the main adaptive
   // filters.
-  // TODO(bugs.webrtc.org/10913): Return the impulse responses for all capture
-  // channels.
-  const std::vector<float>& FilterImpulseResponse() const {
-    return main_impulse_response_[0];
+  const std::vector<std::vector<float>>& FilterImpulseResponse() const {
+    return main_impulse_response_;
   }
 
   void DumpFilters() {
-    size_t current_size = main_impulse_response_[0].size();
-    main_impulse_response_[0].resize(main_impulse_response_[0].capacity());
-    data_dumper_->DumpRaw("aec3_subtractor_h_main", main_impulse_response_[0]);
-    main_impulse_response_[0].resize(current_size);
+    data_dumper_->DumpRaw(
+        "aec3_subtractor_h_main",
+        rtc::ArrayView<const float>(
+            main_impulse_response_[0].data(),
+            GetTimeDomainLength(
+                main_filter_[0]->max_filter_size_partitions())));
 
     main_filter_[0]->DumpFilter("aec3_subtractor_H_main");
     shadow_filter_[0]->DumpFilter("aec3_subtractor_H_shadow");
diff --git a/modules/audio_processing/aec3/subtractor_unittest.cc b/modules/audio_processing/aec3/subtractor_unittest.cc
index b5635f4..23e7ead 100644
--- a/modules/audio_processing/aec3/subtractor_unittest.cc
+++ b/modules/audio_processing/aec3/subtractor_unittest.cc
@@ -11,12 +11,14 @@
 #include "modules/audio_processing/aec3/subtractor.h"
 
 #include <algorithm>
+#include <memory>
 #include <numeric>
 #include <string>
 
 #include "modules/audio_processing/aec3/aec_state.h"
 #include "modules/audio_processing/aec3/render_delay_buffer.h"
 #include "modules/audio_processing/test/echo_canceller_test_tools.h"
+#include "modules/audio_processing/utility/cascaded_biquad_filter.h"
 #include "rtc_base/random.h"
 #include "rtc_base/strings/string_builder.h"
 #include "test/gtest.h"
@@ -24,51 +26,104 @@
 namespace webrtc {
 namespace {
 
-float RunSubtractorTest(int num_blocks_to_process,
-                        int delay_samples,
-                        int main_filter_length_blocks,
-                        int shadow_filter_length_blocks,
-                        bool uncorrelated_inputs,
-                        const std::vector<int>& blocks_with_echo_path_changes) {
+std::vector<float> RunSubtractorTest(
+    size_t num_render_channels,
+    size_t num_capture_channels,
+    int num_blocks_to_process,
+    int delay_samples,
+    int main_filter_length_blocks,
+    int shadow_filter_length_blocks,
+    bool uncorrelated_inputs,
+    const std::vector<int>& blocks_with_echo_path_changes) {
   ApmDataDumper data_dumper(42);
-  constexpr size_t kNumChannels = 1;
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
   EchoCanceller3Config config;
   config.filter.main.length_blocks = main_filter_length_blocks;
   config.filter.shadow.length_blocks = shadow_filter_length_blocks;
 
-  Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization());
+  Subtractor subtractor(config, num_render_channels, num_capture_channels,
+                        &data_dumper, DetectOptimization());
   absl::optional<DelayEstimate> delay_estimate;
   std::vector<std::vector<std::vector<float>>> x(
       kNumBands, std::vector<std::vector<float>>(
-                     kNumChannels, std::vector<float>(kBlockSize, 0.f)));
-  std::vector<std::vector<float>> y(1, std::vector<float>(kBlockSize, 0.f));
+                     num_render_channels, std::vector<float>(kBlockSize, 0.f)));
+  std::vector<std::vector<float>> y(num_capture_channels,
+                                    std::vector<float>(kBlockSize, 0.f));
   std::array<float, kBlockSize> x_old;
-  std::array<SubtractorOutput, 1> output;
+  std::vector<SubtractorOutput> output(num_capture_channels);
   config.delay.default_delay = 1;
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
+      RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
   RenderSignalAnalyzer render_signal_analyzer(config);
   Random random_generator(42U);
   Aec3Fft fft;
   std::array<float, kFftLengthBy2Plus1> Y2;
   std::array<float, kFftLengthBy2Plus1> E2_main;
   std::array<float, kFftLengthBy2Plus1> E2_shadow;
-  AecState aec_state(config, kNumChannels);
+  AecState aec_state(config, num_capture_channels);
   x_old.fill(0.f);
   Y2.fill(0.f);
   E2_main.fill(0.f);
   E2_shadow.fill(0.f);
 
-  DelayBuffer<float> delay_buffer(delay_samples);
-  for (int k = 0; k < num_blocks_to_process; ++k) {
-    RandomizeSampleVector(&random_generator, x[0][0]);
-    if (uncorrelated_inputs) {
-      RandomizeSampleVector(&random_generator, y[0]);
-    } else {
-      delay_buffer.Delay(x[0][0], y[0]);
+  std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer(
+      num_capture_channels);
+  for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) {
+    delay_buffer[capture_ch].resize(num_render_channels);
+    for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
+      delay_buffer[capture_ch][render_ch] =
+          std::make_unique<DelayBuffer<float>>(delay_samples);
     }
+  }
+
+  // [B,A] = butter(2,100/8000,'high')
+  constexpr CascadedBiQuadFilter::BiQuadCoefficients
+      kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f},
+                                     {-1.94448f, 0.94598f}};
+  std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter(
+      num_render_channels);
+  for (size_t ch = 0; ch < num_render_channels; ++ch) {
+    x_hp_filter[ch] =
+        std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
+  }
+  std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter(
+      num_capture_channels);
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    y_hp_filter[ch] =
+        std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1);
+  }
+
+  for (int k = 0; k < num_blocks_to_process; ++k) {
+    for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) {
+      RandomizeSampleVector(&random_generator, x[0][render_ch]);
+    }
+    if (uncorrelated_inputs) {
+      for (size_t capture_ch = 0; capture_ch < num_capture_channels;
+           ++capture_ch) {
+        RandomizeSampleVector(&random_generator, y[capture_ch]);
+      }
+    } else {
+      for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+        // Rebuild the echo from scratch so earlier blocks do not leak into it.
+        std::fill(y[ch].begin(), y[ch].end(), 0.f);
+        for (size_t render_ch = 0; render_ch < num_render_channels;
+             ++render_ch) {
+          std::array<float, kBlockSize> y_channel;
+          delay_buffer[ch][render_ch]->Delay(x[0][render_ch], y_channel);
+          for (size_t i = 0; i < y_channel.size(); ++i) {
+            y[ch][i] += y_channel[i] / num_render_channels;
+          }
+        }
+      }
+    }
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      x_hp_filter[ch]->Process(x[0][ch]);
+    }
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      y_hp_filter[ch]->Process(y[ch]);
+    }
+
     render_delay_buffer->Insert(x);
     if (k == 0) {
       render_delay_buffer->Reset();
@@ -90,28 +145,37 @@
 
     aec_state.HandleEchoPathChange(EchoPathVariability(
         false, EchoPathVariability::DelayAdjustment::kNone, false));
-    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
-                     subtractor.FilterImpulseResponse(),
+    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0],
+                     subtractor.FilterImpulseResponse()[0],
                      *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
                      output);
   }
 
-  const float output_power =
-      std::inner_product(output[0].e_main.begin(), output[0].e_main.end(),
-                         output[0].e_main.begin(), 0.f);
-  const float y_power =
-      std::inner_product(y[0].begin(), y[0].end(), y[0].begin(), 0.f);
-  if (y_power == 0.f) {
-    ADD_FAILURE();
-    return -1.0;
+  std::vector<float> results(num_capture_channels);
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    const float output_power =
+        std::inner_product(output[ch].e_main.begin(), output[ch].e_main.end(),
+                           output[ch].e_main.begin(), 0.f);
+    const float y_power =
+        std::inner_product(y[ch].begin(), y[ch].end(), y[ch].begin(), 0.f);
+    // On a silent capture signal, report -1 instead of dividing by zero.
+    if (y_power == 0.f) {
+      ADD_FAILURE();
+    }
+    results[ch] = y_power == 0.f ? -1.f : output_power / y_power;
   }
-  return output_power / y_power;
+  return results;
 }
 
-std::string ProduceDebugText(size_t delay, int filter_length_blocks) {
+std::string ProduceDebugText(size_t num_render_channels,
+                             size_t num_capture_channels,
+                             size_t delay,
+                             int filter_length_blocks) {
   rtc::StringBuilder ss;
-  ss << "Delay: " << delay << ", ";
-  ss << "Length: " << filter_length_blocks;
+  ss << "delay: " << delay << ", ";
+  ss << "filter_length_blocks:" << filter_length_blocks << ", ";
+  ss << "num_render_channels:" << num_render_channels << ", ";
+  ss << "num_capture_channels:" << num_capture_channels;
   return ss.Release();
 }
 
@@ -150,17 +214,32 @@
   std::vector<int> blocks_with_echo_path_changes;
   for (size_t filter_length_blocks : {12, 20, 30}) {
     for (size_t delay_samples : {0, 64, 150, 200, 301}) {
-      SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks));
+      SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
+      std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
+          1, 1, 2500, delay_samples, filter_length_blocks, filter_length_blocks,
+          false, blocks_with_echo_path_changes);
 
-      float echo_to_nearend_power = RunSubtractorTest(
-          400, delay_samples, filter_length_blocks, filter_length_blocks, false,
-          blocks_with_echo_path_changes);
-
-      // Use different criteria to take overmodelling into account.
-      if (filter_length_blocks == 12) {
+      for (float echo_to_nearend_power : echo_to_nearend_powers) {
         EXPECT_GT(0.1f, echo_to_nearend_power);
-      } else {
-        EXPECT_GT(1.f, echo_to_nearend_power);
+      }
+    }
+  }
+}
+
+// Verifies that the subtractor converges on correlated multi-channel data.
+TEST(Subtractor, ConvergenceMultiChannel) {
+  std::vector<int> blocks_with_echo_path_changes;
+  for (size_t num_render_channels : {1, 2, 4, 8}) {
+    for (size_t num_capture_channels : {1, 2, 4}) {
+      SCOPED_TRACE(
+          ProduceDebugText(num_render_channels, num_capture_channels, 64, 20));
+      size_t num_blocks_to_process = 2500 * num_render_channels;
+      std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
+          num_render_channels, num_capture_channels, num_blocks_to_process, 64,
+          20, 20, false, blocks_with_echo_path_changes);
+      // One power ratio per capture channel; each must show convergence.
+      for (float echo_to_nearend_power : echo_to_nearend_powers) {
+        EXPECT_GT(0.1f, echo_to_nearend_power);
       }
     }
   }
@@ -170,18 +249,22 @@
 // is longer than the shadow filter.
 TEST(Subtractor, MainFilterLongerThanShadowFilter) {
   std::vector<int> blocks_with_echo_path_changes;
-  float echo_to_nearend_power =
-      RunSubtractorTest(400, 64, 20, 15, false, blocks_with_echo_path_changes);
-  EXPECT_GT(0.5f, echo_to_nearend_power);
+  std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
+      1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes);
+  for (float echo_to_nearend_power : echo_to_nearend_powers) {  // Per channel.
+    EXPECT_GT(0.5f, echo_to_nearend_power);  // Ratio must be below 0.5.
+  }
 }
 
 // Verifies that the subtractor is able to handle the case when the shadow
 // filter is longer than the main filter.
 TEST(Subtractor, ShadowFilterLongerThanMainFilter) {
   std::vector<int> blocks_with_echo_path_changes;
-  float echo_to_nearend_power =
-      RunSubtractorTest(400, 64, 15, 20, false, blocks_with_echo_path_changes);
-  EXPECT_GT(0.5f, echo_to_nearend_power);
+  std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
+      1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes);
+  for (float echo_to_nearend_power : echo_to_nearend_powers) {  // Per channel.
+    EXPECT_GT(0.5f, echo_to_nearend_power);  // Ratio must be below 0.5.
+  }
 }
 
 // Verifies that the subtractor does not converge on uncorrelated signals.
@@ -189,12 +272,33 @@
   std::vector<int> blocks_with_echo_path_changes;
   for (size_t filter_length_blocks : {12, 20, 30}) {
     for (size_t delay_samples : {0, 64, 150, 200, 301}) {
-      SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks));
+      SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks));
 
-      float echo_to_nearend_power = RunSubtractorTest(
-          300, delay_samples, filter_length_blocks, filter_length_blocks, true,
-          blocks_with_echo_path_changes);
-      EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1);
+      std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
+          1, 1, 3000, delay_samples, filter_length_blocks, filter_length_blocks,
+          true, blocks_with_echo_path_changes);
+      for (float echo_to_nearend_power : echo_to_nearend_powers) {
+        EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1);
+      }
+    }
+  }
+}
+
+// Verifies non-convergence on uncorrelated signals for multiple channels.
+TEST(Subtractor, NonConvergenceOnUncorrelatedSignalsMultiChannel) {
+  std::vector<int> blocks_with_echo_path_changes;
+  for (size_t num_render_channels : {1, 2, 4}) {
+    for (size_t num_capture_channels : {1, 2, 4}) {
+      SCOPED_TRACE(
+          ProduceDebugText(num_render_channels, num_capture_channels, 64, 20));
+      size_t num_blocks_to_process = 5000 * num_render_channels;
+      std::vector<float> echo_to_nearend_powers = RunSubtractorTest(
+          num_render_channels, num_capture_channels, num_blocks_to_process, 64,
+          20, 20, true, blocks_with_echo_path_changes);
+      for (float echo_to_nearend_power : echo_to_nearend_powers) {
+        EXPECT_LT(.8f, echo_to_nearend_power);  // Ratio must exceed 0.8.
+        EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f);  // And be close to 1.
+      }
     }
   }
 }
diff --git a/modules/audio_processing/aec3/suppression_gain_unittest.cc b/modules/audio_processing/aec3/suppression_gain_unittest.cc
index 490c7ec..465227c 100644
--- a/modules/audio_processing/aec3/suppression_gain_unittest.cc
+++ b/modules/audio_processing/aec3/suppression_gain_unittest.cc
@@ -97,14 +97,14 @@
 
   // Ensure that the gain is no longer forced to zero.
   for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) {
-    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
-                     subtractor.FilterImpulseResponse(),
+    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0],
+                     subtractor.FilterImpulseResponse()[0],
                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
   }
 
   for (int k = 0; k < 100; ++k) {
-    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
-                     subtractor.FilterImpulseResponse(),
+    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0],
+                     subtractor.FilterImpulseResponse()[0],
                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
                              &high_bands_gain, &g);
@@ -120,8 +120,8 @@
   N2.fill(0.f);
 
   for (int k = 0; k < 100; ++k) {
-    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
-                     subtractor.FilterImpulseResponse(),
+    aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0],
+                     subtractor.FilterImpulseResponse()[0],
                      *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
                              &high_bands_gain, &g);
diff --git a/modules/audio_processing/test/echo_canceller_test_tools.h b/modules/audio_processing/test/echo_canceller_test_tools.h
index bab7f27..0d70cd3 100644
--- a/modules/audio_processing/test/echo_canceller_test_tools.h
+++ b/modules/audio_processing/test/echo_canceller_test_tools.h
@@ -15,7 +15,6 @@
 #include <vector>
 
 #include "api/array_view.h"
-#include "rtc_base/constructor_magic.h"
 #include "rtc_base/random.h"
 
 namespace webrtc {
@@ -41,7 +40,6 @@
  private:
   std::vector<T> buffer_;
   size_t next_insert_index_ = 0;
-  RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(DelayBuffer);
 };
 
 }  // namespace webrtc