AEC3: Analyze multi-channel SubtractorOutput in AecState

Updates SubtractorOutputAnalyzer and AecState::SaturationDetector
to multi-channel.

Bug: webrtc:10913
Change-Id: I39edafdc5d5a4db5cc853cf116d60af0f506b3bf
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/154342
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29355}
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
index 9318c21..36e31eb 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
@@ -298,6 +298,7 @@
 // adapt its coefficients.
 TEST(AdaptiveFirFilter, FilterAndAdapt) {
   constexpr size_t kNumRenderChannels = 1;
+  constexpr size_t kNumCaptureChannels = 1;
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
@@ -325,12 +326,12 @@
                      kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
   std::vector<float> n(kBlockSize, 0.f);
   std::vector<float> y(kBlockSize, 0.f);
-  AecState aec_state(EchoCanceller3Config{});
+  AecState aec_state(EchoCanceller3Config{}, kNumCaptureChannels);
   RenderSignalAnalyzer render_signal_analyzer(config);
   absl::optional<DelayEstimate> delay_estimate;
   std::vector<float> e(kBlockSize, 0.f);
   std::array<float, kFftLength> s_scratch;
-  SubtractorOutput output;
+  std::vector<SubtractorOutput> output(kNumCaptureChannels);
   FftData S;
   FftData G;
   FftData E;
@@ -344,7 +345,9 @@
   Y2.fill(0.f);
   E2_main.fill(0.f);
   E2_shadow.fill(0.f);
-  output.Reset();
+  for (auto& subtractor_output : output) {
+    subtractor_output.Reset();
+  }
 
   constexpr float kScale = 1.0f / kFftLengthBy2;
 
@@ -385,7 +388,7 @@
                     [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); });
       fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E);
       for (size_t k = 0; k < kBlockSize; ++k) {
-        output.s_main[k] = kScale * s_scratch[k + kFftLengthBy2];
+        output[0].s_main[k] = kScale * s_scratch[k + kFftLengthBy2];
       }
 
       std::array<float, kFftLengthBy2Plus1> render_power;
@@ -398,7 +401,7 @@
 
       filter.ComputeFrequencyResponse(&H2);
       aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2,
-                       output, y);
+                       output);
     }
     // Verify that the filter is able to perform well.
     EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f),
diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc
index 8ff2930..97c27d5 100644
--- a/modules/audio_processing/aec3/aec_state.cc
+++ b/modules/audio_processing/aec3/aec_state.cc
@@ -55,7 +55,8 @@
   return absl::nullopt;
 }
 
-AecState::AecState(const EchoCanceller3Config& config)
+AecState::AecState(const EchoCanceller3Config& config,
+                   size_t num_capture_channels)
     : data_dumper_(
           new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
       config_(config),
@@ -68,7 +69,8 @@
       filter_analyzer_(config_),
       echo_audibility_(
           config_.echo_audibility.use_stationarity_properties_at_init),
-      reverb_model_estimator_(config_) {}
+      reverb_model_estimator_(config_),
+      subtractor_output_analyzers_(num_capture_channels) {}
 
 AecState::~AecState() = default;
 
@@ -95,7 +97,9 @@
   } else if (echo_path_variability.gain_change) {
     erle_estimator_.Reset(false);
   }
-  subtractor_output_analyzer_.HandleEchoPathChange();
+  for (auto& analyzer : subtractor_output_analyzers_) {
+    analyzer.HandleEchoPathChange();
+  }
 }
 
 void AecState::Update(
@@ -106,10 +110,13 @@
     const RenderBuffer& render_buffer,
     const std::array<float, kFftLengthBy2Plus1>& E2_main,
     const std::array<float, kFftLengthBy2Plus1>& Y2,
-    const SubtractorOutput& subtractor_output,
-    rtc::ArrayView<const float> y) {
+    rtc::ArrayView<const SubtractorOutput> subtractor_output) {
+  RTC_DCHECK_EQ(subtractor_output.size(), subtractor_output_analyzers_.size());
+
   // Analyze the filter output.
-  subtractor_output_analyzer_.Update(subtractor_output);
+  for (size_t ch = 0; ch < subtractor_output.size(); ++ch) {
+    subtractor_output_analyzers_[ch].Update(subtractor_output[ch]);
+  }
 
   // Analyze the properties of the filter.
   filter_analyzer_.Update(adaptive_filter_impulse_response, render_buffer);
@@ -120,17 +127,22 @@
                         strong_not_saturated_render_blocks_);
   }
 
-  const std::vector<float>& aligned_render_block =
-      render_buffer.Block(-delay_state_.DirectPathFilterDelay())[0][0];
+  const std::vector<std::vector<float>>& aligned_render_block =
+      render_buffer.Block(-delay_state_.DirectPathFilterDelay())[0];
 
   // Update render counters.
-  const float render_energy = std::inner_product(
-      aligned_render_block.begin(), aligned_render_block.end(),
-      aligned_render_block.begin(), 0.f);
-  const bool active_render =
-      render_energy > (config_.render_levels.active_render_limit *
-                       config_.render_levels.active_render_limit) *
-                          kFftLengthBy2;
+  bool active_render = false;
+  for (size_t ch = 0; ch < aligned_render_block.size(); ++ch) {
+    const float render_energy = std::inner_product(
+        aligned_render_block[ch].begin(), aligned_render_block[ch].end(),
+        aligned_render_block[ch].begin(), 0.f);
+    if (render_energy > (config_.render_levels.active_render_limit *
+                         config_.render_levels.active_render_limit) *
+                            kFftLengthBy2) {
+      active_render = true;
+      break;
+    }
+  }
   blocks_with_active_render_ += active_render ? 1 : 0;
   strong_not_saturated_render_blocks_ +=
       active_render && !SaturatedCapture() ? 1 : 0;
@@ -153,16 +165,18 @@
     erle_estimator_.Reset(false);
   }
 
+  // TODO(bugs.webrtc.org/10913): Take all channels into account.
   const auto& X2 = render_buffer.Spectrum(delay_state_.DirectPathFilterDelay(),
                                           /*channel=*/0);
   const auto& X2_input_erle = X2_reverb;
 
   erle_estimator_.Update(render_buffer, adaptive_filter_frequency_response,
                          X2_input_erle, Y2, E2_main,
-                         subtractor_output_analyzer_.ConvergedFilter(),
+                         subtractor_output_analyzers_[0].ConvergedFilter(),
                          config_.erle.onset_detection);
 
-  erl_estimator_.Update(subtractor_output_analyzer_.ConvergedFilter(), X2, Y2);
+  erl_estimator_.Update(subtractor_output_analyzers_[0].ConvergedFilter(), X2,
+                        Y2);
 
   // Detect and flag echo saturation.
   saturation_detector_.Update(aligned_render_block, SaturatedCapture(),
@@ -175,15 +189,15 @@
   // Detect whether the transparent mode should be activated.
   transparent_state_.Update(delay_state_.DirectPathFilterDelay(),
                             filter_analyzer_.Consistent(),
-                            subtractor_output_analyzer_.ConvergedFilter(),
-                            subtractor_output_analyzer_.DivergedFilter(),
+                            subtractor_output_analyzers_[0].ConvergedFilter(),
+                            subtractor_output_analyzers_[0].DivergedFilter(),
                             active_render, SaturatedCapture());
 
   // Analyze the quality of the filter.
-  filter_quality_state_.Update(active_render, TransparentMode(),
-                               SaturatedCapture(),
-                               filter_analyzer_.Consistent(), external_delay,
-                               subtractor_output_analyzer_.ConvergedFilter());
+  filter_quality_state_.Update(
+      active_render, TransparentMode(), SaturatedCapture(),
+      filter_analyzer_.Consistent(), external_delay,
+      subtractor_output_analyzers_[0].ConvergedFilter());
 
   // Update the reverb estimate.
   const bool stationary_block =
@@ -212,9 +226,9 @@
   data_dumper_->DumpRaw("aec3_capture_saturation", SaturatedCapture());
   data_dumper_->DumpRaw("aec3_echo_saturation", SaturatedEcho());
   data_dumper_->DumpRaw("aec3_converged_filter",
-                        subtractor_output_analyzer_.ConvergedFilter());
+                        subtractor_output_analyzers_[0].ConvergedFilter());
   data_dumper_->DumpRaw("aec3_diverged_filter",
-                        subtractor_output_analyzer_.DivergedFilter());
+                        subtractor_output_analyzers_[0].DivergedFilter());
 
   data_dumper_->DumpRaw("aec3_external_delay_avaliable",
                         external_delay ? 1 : 0);
@@ -406,27 +420,36 @@
   usable_linear_estimate_ = usable_linear_estimate_ && !transparent_mode;
 }
 
-
 void AecState::SaturationDetector::Update(
-    rtc::ArrayView<const float> x,
+    rtc::ArrayView<const std::vector<float>> x,
     bool saturated_capture,
     bool usable_linear_estimate,
-    const SubtractorOutput& subtractor_output,
+    rtc::ArrayView<const SubtractorOutput> subtractor_output,
     float echo_path_gain) {
-  saturated_echo_ = saturated_capture;
+  saturated_echo_ = false;
+  if (!saturated_capture) {
+    return;
+  }
+
   if (usable_linear_estimate) {
     constexpr float kSaturationThreshold = 20000.f;
-    saturated_echo_ =
-        saturated_echo_ &&
-        (subtractor_output.s_main_max_abs > kSaturationThreshold ||
-         subtractor_output.s_shadow_max_abs > kSaturationThreshold);
+    for (size_t ch = 0; ch < subtractor_output.size(); ++ch) {
+      saturated_echo_ =
+          saturated_echo_ ||
+          (subtractor_output[ch].s_main_max_abs > kSaturationThreshold ||
+           subtractor_output[ch].s_shadow_max_abs > kSaturationThreshold);
+    }
   } else {
-    const float max_sample = fabs(*std::max_element(
-        x.begin(), x.end(), [](float a, float b) { return a * a < b * b; }));
+    float max_sample = 0.f;
+    for (auto& channel : x) {
+      for (float sample : channel) {
+        max_sample = std::max(max_sample, fabsf(sample));
+      }
+    }
 
     const float kMargin = 10.f;
     float peak_echo_amplitude = max_sample * echo_path_gain * kMargin;
-    saturated_echo_ = saturated_echo_ && peak_echo_amplitude > 32000;
+    saturated_echo_ = saturated_echo_ || peak_echo_amplitude > 32000;
   }
 }
 
diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h
index 43cdb0b..1229732 100644
--- a/modules/audio_processing/aec3/aec_state.h
+++ b/modules/audio_processing/aec3/aec_state.h
@@ -40,7 +40,7 @@
 // Handles the state and the conditions for the echo removal functionality.
 class AecState {
  public:
-  explicit AecState(const EchoCanceller3Config& config);
+  AecState(const EchoCanceller3Config& config, size_t num_capture_channels);
   ~AecState();
 
   // Returns whether the echo subtractor can be used to determine the residual
@@ -129,6 +129,8 @@
   }
 
   // Updates the aec state.
+  // TODO(bugs.webrtc.org/10913): Handle multi-channel adaptive filter response.
+  // TODO(bugs.webrtc.org/10913): Compute multi-channel ERL, ERLE, and reverb.
   void Update(const absl::optional<DelayEstimate>& external_delay,
               const std::vector<std::array<float, kFftLengthBy2Plus1>>&
                   adaptive_filter_frequency_response,
@@ -136,8 +138,7 @@
               const RenderBuffer& render_buffer,
               const std::array<float, kFftLengthBy2Plus1>& E2_main,
               const std::array<float, kFftLengthBy2Plus1>& Y2,
-              const SubtractorOutput& subtractor_output,
-              rtc::ArrayView<const float> y);
+              rtc::ArrayView<const SubtractorOutput> subtractor_output);
 
   // Returns filter length in blocks.
   int FilterLengthBlocks() const {
@@ -275,10 +276,10 @@
     bool SaturatedEcho() const { return saturated_echo_; }
 
     // Updates the detection decision based on new data.
-    void Update(rtc::ArrayView<const float> x,
+    void Update(rtc::ArrayView<const std::vector<float>> x,
                 bool saturated_capture,
                 bool usable_linear_estimate,
-                const SubtractorOutput& subtractor_output,
+                rtc::ArrayView<const SubtractorOutput> subtractor_output,
                 float echo_path_gain);
 
    private:
@@ -295,7 +296,7 @@
   EchoAudibility echo_audibility_;
   ReverbModelEstimator reverb_model_estimator_;
   RenderReverbModel render_reverb_;
-  SubtractorOutputAnalyzer subtractor_output_analyzer_;
+  std::vector<SubtractorOutputAnalyzer> subtractor_output_analyzers_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/aec_state_unittest.cc b/modules/audio_processing/aec3/aec_state_unittest.cc
index 4631eac..ccf953a 100644
--- a/modules/audio_processing/aec3/aec_state_unittest.cc
+++ b/modules/audio_processing/aec3/aec_state_unittest.cc
@@ -13,37 +13,48 @@
 #include "modules/audio_processing/aec3/aec3_fft.h"
 #include "modules/audio_processing/aec3/render_delay_buffer.h"
 #include "modules/audio_processing/logging/apm_data_dumper.h"
+#include "rtc_base/strings/string_builder.h"
 #include "test/gtest.h"
 
 namespace webrtc {
+namespace {
+std::string ProduceDebugText(size_t num_render_channels,
+                             size_t num_capture_channels) {
+  rtc::StringBuilder ss;
+  ss << "Render channels: " << num_render_channels;
+  ss << ", Capture channels: " << num_capture_channels;
+  return ss.Release();
+}
 
-// Verify the general functionality of AecState
-TEST(AecState, NormalUsage) {
-  constexpr size_t kNumChannels = 1;
+void RunNormalUsageTest(size_t num_render_channels,
+                        size_t num_capture_channels) {
+  // TODO(bugs.webrtc.org/10913): Test with different content in different
+  // channels.
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
   ApmDataDumper data_dumper(42);
   EchoCanceller3Config config;
-  AecState state(config);
+  AecState state(config, num_capture_channels);
   absl::optional<DelayEstimate> delay_estimate =
       DelayEstimate(DelayEstimate::Quality::kRefined, 10);
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
+      RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels));
   std::array<float, kFftLengthBy2Plus1> E2_main = {};
   std::array<float, kFftLengthBy2Plus1> Y2 = {};
   std::vector<std::vector<std::vector<float>>> x(
       kNumBands, std::vector<std::vector<float>>(
-                     kNumChannels, std::vector<float>(kBlockSize, 0.f)));
+                     num_render_channels, std::vector<float>(kBlockSize, 0.f)));
   EchoPathVariability echo_path_variability(
       false, EchoPathVariability::DelayAdjustment::kNone, false);
-  SubtractorOutput output;
-  output.Reset();
-  std::array<float, kBlockSize> y;
+  std::vector<std::array<float, kBlockSize>> y(num_capture_channels);
+  std::vector<SubtractorOutput> subtractor_output(num_capture_channels);
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    subtractor_output[ch].Reset();
+    subtractor_output[ch].s_main.fill(100.f);
+    subtractor_output[ch].e_main.fill(100.f);
+    y[ch].fill(1000.f);
+  }
   Aec3Fft fft;
-  output.s_main.fill(100.f);
-  output.e_main.fill(100.f);
-  y.fill(1000.f);
-
   std::vector<std::array<float, kFftLengthBy2Plus1>>
       converged_filter_frequency_response(10);
   for (auto& v : converged_filter_frequency_response) {
@@ -53,52 +64,59 @@
       diverged_filter_frequency_response = converged_filter_frequency_response;
   converged_filter_frequency_response[2].fill(100.f);
   converged_filter_frequency_response[2][0] = 1.f;
-
   std::vector<float> impulse_response(
       GetTimeDomainLength(config.filter.main.length_blocks), 0.f);
 
   // Verify that linear AEC usability is true when the filter is converged
   for (size_t band = 0; band < kNumBands; ++band) {
-    for (size_t channel = 0; channel < kNumChannels; ++channel) {
-      std::fill(x[band][channel].begin(), x[band][channel].end(), 101.f);
+    for (size_t ch = 0; ch < num_render_channels; ++ch) {
+      std::fill(x[band][ch].begin(), x[band][ch].end(), 101.f);
     }
   }
   for (int k = 0; k < 3000; ++k) {
     render_delay_buffer->Insert(x);
-    output.ComputeMetrics(y);
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      subtractor_output[ch].ComputeMetrics(y[ch]);
+    }
     state.Update(delay_estimate, converged_filter_frequency_response,
                  impulse_response, *render_delay_buffer->GetRenderBuffer(),
-                 E2_main, Y2, output, y);
+                 E2_main, Y2, subtractor_output);
   }
   EXPECT_TRUE(state.UsableLinearEstimate());
 
-  // Verify that linear AEC usability becomes false after an echo path change is
-  // reported
-  output.ComputeMetrics(y);
+  // Verify that linear AEC usability becomes false after an echo path
+  // change is reported
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    subtractor_output[ch].ComputeMetrics(y[ch]);
+  }
   state.HandleEchoPathChange(EchoPathVariability(
       false, EchoPathVariability::DelayAdjustment::kBufferReadjustment, false));
   state.Update(delay_estimate, converged_filter_frequency_response,
                impulse_response, *render_delay_buffer->GetRenderBuffer(),
-               E2_main, Y2, output, y);
+               E2_main, Y2, subtractor_output);
   EXPECT_FALSE(state.UsableLinearEstimate());
 
   // Verify that the active render detection works as intended.
   std::fill(x[0][0].begin(), x[0][0].end(), 101.f);
   render_delay_buffer->Insert(x);
-  output.ComputeMetrics(y);
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    subtractor_output[ch].ComputeMetrics(y[ch]);
+  }
   state.HandleEchoPathChange(EchoPathVariability(
       true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false));
   state.Update(delay_estimate, converged_filter_frequency_response,
                impulse_response, *render_delay_buffer->GetRenderBuffer(),
-               E2_main, Y2, output, y);
+               E2_main, Y2, subtractor_output);
   EXPECT_FALSE(state.ActiveRender());
 
   for (int k = 0; k < 1000; ++k) {
     render_delay_buffer->Insert(x);
-    output.ComputeMetrics(y);
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      subtractor_output[ch].ComputeMetrics(y[ch]);
+    }
     state.Update(delay_estimate, converged_filter_frequency_response,
                  impulse_response, *render_delay_buffer->GetRenderBuffer(),
-                 E2_main, Y2, output, y);
+                 E2_main, Y2, subtractor_output);
   }
   EXPECT_TRUE(state.ActiveRender());
 
@@ -121,10 +139,12 @@
 
   Y2.fill(10.f * 10000.f * 10000.f);
   for (size_t k = 0; k < 1000; ++k) {
-    output.ComputeMetrics(y);
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      subtractor_output[ch].ComputeMetrics(y[ch]);
+    }
     state.Update(delay_estimate, converged_filter_frequency_response,
                  impulse_response, *render_delay_buffer->GetRenderBuffer(),
-                 E2_main, Y2, output, y);
+                 E2_main, Y2, subtractor_output);
   }
 
   ASSERT_TRUE(state.UsableLinearEstimate());
@@ -139,15 +159,17 @@
   E2_main.fill(1.f * 10000.f * 10000.f);
   Y2.fill(10.f * E2_main[0]);
   for (size_t k = 0; k < 1000; ++k) {
-    output.ComputeMetrics(y);
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      subtractor_output[ch].ComputeMetrics(y[ch]);
+    }
     state.Update(delay_estimate, converged_filter_frequency_response,
                  impulse_response, *render_delay_buffer->GetRenderBuffer(),
-                 E2_main, Y2, output, y);
+                 E2_main, Y2, subtractor_output);
   }
   ASSERT_TRUE(state.UsableLinearEstimate());
   {
-    // Note that the render spectrum is built so it does not have energy in the
-    // odd bands but just in the even bands.
+    // Note that the render spectrum is built so it does not have energy in
+    // the odd bands but just in the even bands.
     const auto& erle = state.Erle();
     EXPECT_EQ(erle[0], erle[1]);
     constexpr size_t kLowFrequencyLimit = 32;
@@ -163,10 +185,12 @@
   E2_main.fill(1.f * 10000.f * 10000.f);
   Y2.fill(5.f * E2_main[0]);
   for (size_t k = 0; k < 1000; ++k) {
-    output.ComputeMetrics(y);
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      subtractor_output[ch].ComputeMetrics(y[ch]);
+    }
     state.Update(delay_estimate, converged_filter_frequency_response,
                  impulse_response, *render_delay_buffer->GetRenderBuffer(),
-                 E2_main, Y2, output, y);
+                 E2_main, Y2, subtractor_output);
   }
 
   ASSERT_TRUE(state.UsableLinearEstimate());
@@ -184,11 +208,24 @@
   }
 }
 
+}  // namespace
+
+// Verify the general functionality of AecState
+TEST(AecState, NormalUsage) {
+  for (size_t num_render_channels : {1, 2, 8}) {
+    for (size_t num_capture_channels : {1, 2, 8}) {
+      SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels));
+      RunNormalUsageTest(num_render_channels, num_capture_channels);
+    }
+  }
+}
+
 // Verifies the delay for a converged filter is correctly identified.
 TEST(AecState, ConvergedFilterDelay) {
   constexpr int kFilterLengthBlocks = 10;
+  constexpr size_t kNumCaptureChannels = 1;
   EchoCanceller3Config config;
-  AecState state(config);
+  AecState state(config, kNumCaptureChannels);
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
       RenderDelayBuffer::Create(config, 48000, 1));
   absl::optional<DelayEstimate> delay_estimate;
@@ -197,10 +234,12 @@
   std::array<float, kBlockSize> x;
   EchoPathVariability echo_path_variability(
       false, EchoPathVariability::DelayAdjustment::kNone, false);
-  SubtractorOutput output;
-  output.Reset();
+  std::vector<SubtractorOutput> subtractor_output(kNumCaptureChannels);
+  for (auto& output : subtractor_output) {
+    output.Reset();
+    output.s_main.fill(100.f);
+  }
   std::array<float, kBlockSize> y;
-  output.s_main.fill(100.f);
   x.fill(0.f);
   y.fill(0.f);
 
@@ -213,16 +252,17 @@
   std::vector<float> impulse_response(
       GetTimeDomainLength(config.filter.main.length_blocks), 0.f);
 
-  // Verify that the filter delay for a converged filter is properly identified.
+  // Verify that the filter delay for a converged filter is properly
+  // identified.
   for (int k = 0; k < kFilterLengthBlocks; ++k) {
     std::fill(impulse_response.begin(), impulse_response.end(), 0.f);
     impulse_response[k * kBlockSize + 1] = 1.f;
 
     state.HandleEchoPathChange(echo_path_variability);
-    output.ComputeMetrics(y);
+    subtractor_output[0].ComputeMetrics(y);
     state.Update(delay_estimate, frequency_response, impulse_response,
-                 *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, output,
-                 y);
+                 *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
+                 subtractor_output);
   }
 }
 
diff --git a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
index bac30b4..94aa039 100644
--- a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
+++ b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
@@ -37,7 +37,7 @@
   FftData noise;
   EXPECT_DEATH(
       ComfortNoiseGenerator(DetectOptimization(), 42)
-          .Compute(AecState(EchoCanceller3Config{}), N2, nullptr, &noise),
+          .Compute(AecState(EchoCanceller3Config{}, 1), N2, nullptr, &noise),
       "");
 }
 
@@ -46,7 +46,7 @@
   FftData noise;
   EXPECT_DEATH(
       ComfortNoiseGenerator(DetectOptimization(), 42)
-          .Compute(AecState(EchoCanceller3Config{}), N2, &noise, nullptr),
+          .Compute(AecState(EchoCanceller3Config{}, 1), N2, &noise, nullptr),
       "");
 }
 
@@ -54,7 +54,7 @@
 
 TEST(ComfortNoiseGenerator, CorrectLevel) {
   ComfortNoiseGenerator cng(DetectOptimization(), 42);
-  AecState aec_state(EchoCanceller3Config{});
+  AecState aec_state(EchoCanceller3Config{}, 1);
 
   std::array<float, kFftLengthBy2Plus1> N2;
   N2.fill(1000.f * 1000.f);
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index c9a58ec..2df9cfd 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -202,7 +202,7 @@
                           num_capture_channels_),
       render_signal_analyzer_(config_),
       residual_echo_estimators_(num_capture_channels_),
-      aec_state_(config_),
+      aec_state_(config_, num_capture_channels_),
       e_old_(num_capture_channels_),
       y_old_(num_capture_channels_),
       e_heap_(NumChannelsOnHeap(num_capture_channels_)),
@@ -388,7 +388,7 @@
   // TODO(bugs.webrtc.org/10913): Take all subtractors into account.
   aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(),
                     subtractor_.FilterImpulseResponse(), *render_buffer, E2[0],
-                    Y2[0], subtractor_output[0], y0);
+                    Y2[0], subtractor_output);
 
   // Choose the linear output.
   const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y;
diff --git a/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc b/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc
index c16c7ea..30c6611 100644
--- a/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc
+++ b/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc
@@ -138,7 +138,7 @@
 // Verify the general functionality of EchoRemoverMetrics.
 TEST(EchoRemoverMetrics, NormalUsage) {
   EchoRemoverMetrics metrics;
-  AecState aec_state(EchoCanceller3Config{});
+  AecState aec_state(EchoCanceller3Config{}, 1);
   std::array<float, kFftLengthBy2Plus1> comfort_noise_spectrum;
   std::array<float, kFftLengthBy2Plus1> suppressor_gain;
   comfort_noise_spectrum.fill(10.f);
diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
index e78f1cd..20714ce 100644
--- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
+++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
@@ -83,21 +83,23 @@
   config.delay.default_delay = 1;
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
       RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
-  AecState aec_state(config);
+  AecState aec_state(config, kNumChannels);
   RenderSignalAnalyzer render_signal_analyzer(config);
   absl::optional<DelayEstimate> delay_estimate;
   std::array<float, kFftLength> s_scratch;
   std::array<float, kBlockSize> s;
   FftData S;
   FftData G;
-  SubtractorOutput output;
-  output.Reset();
-  FftData& E_main = output.E_main;
+  std::vector<SubtractorOutput> output(kNumChannels);
+  for (auto& subtractor_output : output) {
+    subtractor_output.Reset();
+  }
+  FftData& E_main = output[0].E_main;
   FftData E_shadow;
   std::array<float, kFftLengthBy2Plus1> Y2;
-  std::array<float, kFftLengthBy2Plus1>& E2_main = output.E2_main;
-  std::array<float, kBlockSize>& e_main = output.e_main;
-  std::array<float, kBlockSize>& e_shadow = output.e_shadow;
+  std::array<float, kFftLengthBy2Plus1>& E2_main = output[0].E2_main;
+  std::array<float, kBlockSize>& e_main = output[0].e_main;
+  std::array<float, kBlockSize>& e_shadow = output[0].e_shadow;
   Y2.fill(0.f);
 
   constexpr float kScale = 1.0f / kFftLengthBy2;
@@ -165,8 +167,8 @@
     fft.ZeroPaddedFft(e_shadow, Aec3Fft::Window::kRectangular, &E_shadow);
 
     // Compute spectra for future use.
-    E_main.Spectrum(Aec3Optimization::kNone, output.E2_main);
-    E_shadow.Spectrum(Aec3Optimization::kNone, output.E2_shadow);
+    E_main.Spectrum(Aec3Optimization::kNone, output[0].E2_main);
+    E_shadow.Spectrum(Aec3Optimization::kNone, output[0].E2_shadow);
 
     // Adapt the shadow filter.
     std::array<float, kFftLengthBy2Plus1> render_power;
@@ -182,7 +184,7 @@
 
     std::array<float, kFftLengthBy2Plus1> erl;
     ComputeErl(optimization, H2, erl);
-    main_gain.Compute(render_power, render_signal_analyzer, output, erl,
+    main_gain.Compute(render_power, render_signal_analyzer, output[0], erl,
                       main_filter.SizePartitions(), saturation, &G);
     main_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G, &h);
 
@@ -192,7 +194,7 @@
     main_filter.ComputeFrequencyResponse(&H2);
     aec_state.Update(delay_estimate, H2, h,
                      *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
-                     output, y);
+                     output);
   }
 
   std::copy(e_main.begin(), e_main.end(), e_last_block->begin());
diff --git a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc
index 863f8f8..2823cae 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc
@@ -25,7 +25,7 @@
 // Verifies that the check for non-null output residual echo power works.
 TEST(ResidualEchoEstimator, NullResidualEchoPowerOutput) {
   EchoCanceller3Config config;
-  AecState aec_state(config);
+  AecState aec_state(config, 1);
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
       RenderDelayBuffer::Create(config, 48000, 1));
   std::vector<std::array<float, kFftLengthBy2Plus1>> H2;
@@ -49,7 +49,7 @@
   EchoCanceller3Config config;
   config.ep_strength.default_len = 0.f;
   ResidualEchoEstimator estimator(config);
-  AecState aec_state(config);
+  AecState aec_state(config, kNumChannels);
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
       RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
 
@@ -66,7 +66,7 @@
                      kNumChannels, std::vector<float>(kBlockSize, 0.f)));
   std::vector<std::array<float, kFftLengthBy2Plus1>> H2(10);
   Random random_generator(42U);
-  SubtractorOutput output;
+  std::vector<SubtractorOutput> output(kNumChannels);
   std::array<float, kBlockSize> y;
   Aec3Fft fft;
   absl::optional<DelayEstimate> delay_estimate;
@@ -80,8 +80,10 @@
   std::vector<float> h(GetTimeDomainLength(config.filter.main.length_blocks),
                        0.f);
 
-  output.Reset();
-  output.s_main.fill(100.f);
+  for (auto& subtractor_output : output) {
+    subtractor_output.Reset();
+    subtractor_output.s_main.fill(100.f);
+  }
   y.fill(0.f);
 
   constexpr float kLevel = 10.f;
@@ -103,7 +105,7 @@
     aec_state.HandleEchoPathChange(echo_path_variability);
     aec_state.Update(delay_estimate, H2, h,
                      *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
-                     output, y);
+                     output);
 
     estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(),
                        S2_linear, Y2, &R2);
diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc
index 300f6b1..605f570 100644
--- a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc
+++ b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc
@@ -64,7 +64,6 @@
       std::vector<std::vector<float>>(num_render_channels,
                                       std::vector<float>(kBlockSize, 0.f)));
   std::array<float, kBlockSize> y;
-  AecState aec_state(config);
   RenderSignalAnalyzer render_signal_analyzer(config);
   std::array<float, kFftLength> s;
   FftData S;
diff --git a/modules/audio_processing/aec3/subtractor_output.h b/modules/audio_processing/aec3/subtractor_output.h
index 5f6fd3e..2822b08 100644
--- a/modules/audio_processing/aec3/subtractor_output.h
+++ b/modules/audio_processing/aec3/subtractor_output.h
@@ -19,7 +19,8 @@
 
 namespace webrtc {
 
-// Stores the values being returned from the echo subtractor.
+// Stores the values being returned from the echo subtractor for a single
+// capture channel.
 struct SubtractorOutput {
   SubtractorOutput();
   ~SubtractorOutput();
diff --git a/modules/audio_processing/aec3/subtractor_unittest.cc b/modules/audio_processing/aec3/subtractor_unittest.cc
index daacbd3..b5635f4 100644
--- a/modules/audio_processing/aec3/subtractor_unittest.cc
+++ b/modules/audio_processing/aec3/subtractor_unittest.cc
@@ -55,7 +55,7 @@
   std::array<float, kFftLengthBy2Plus1> Y2;
   std::array<float, kFftLengthBy2Plus1> E2_main;
   std::array<float, kFftLengthBy2Plus1> E2_shadow;
-  AecState aec_state(config);
+  AecState aec_state(config, kNumChannels);
   x_old.fill(0.f);
   Y2.fill(0.f);
   E2_main.fill(0.f);
@@ -93,7 +93,7 @@
     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
                      subtractor.FilterImpulseResponse(),
                      *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
-                     output[0], y[0]);
+                     output);
   }
 
   const float output_power =
@@ -139,7 +139,7 @@
 
   EXPECT_DEATH(
       subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y,
-                         render_signal_analyzer, AecState(config), output),
+                         render_signal_analyzer, AecState(config, 1), output),
       "");
 }
 
diff --git a/modules/audio_processing/aec3/suppression_gain_unittest.cc b/modules/audio_processing/aec3/suppression_gain_unittest.cc
index cfd92be..490c7ec 100644
--- a/modules/audio_processing/aec3/suppression_gain_unittest.cc
+++ b/modules/audio_processing/aec3/suppression_gain_unittest.cc
@@ -42,7 +42,7 @@
   Y.im.fill(0.f);
 
   float high_bands_gain;
-  AecState aec_state(EchoCanceller3Config{});
+  AecState aec_state(EchoCanceller3Config{}, 1);
   EXPECT_DEATH(
       SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000)
           .GetGain(E2, S2, R2, N2,
@@ -71,13 +71,13 @@
   std::array<float, kFftLengthBy2Plus1> R2;
   std::array<float, kFftLengthBy2Plus1> N2;
   std::array<float, kFftLengthBy2Plus1> g;
-  SubtractorOutput output;
+  std::vector<SubtractorOutput> output(kNumChannels);
   std::array<float, kBlockSize> y;
   std::vector<std::vector<std::vector<float>>> x(
       kNumBands, std::vector<std::vector<float>>(
                      kNumChannels, std::vector<float>(kBlockSize, 0.f)));
   EchoCanceller3Config config;
-  AecState aec_state(config);
+  AecState aec_state(config, kNumChannels);
   ApmDataDumper data_dumper(42);
   Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization());
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
@@ -90,22 +90,22 @@
   R2.fill(0.1f);
   S2.fill(0.1f);
   N2.fill(100.f);
-  output.Reset();
+  for (auto& subtractor_output : output) {
+    subtractor_output.Reset();
+  }
   y.fill(0.f);
 
   // Ensure that the gain is no longer forced to zero.
   for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) {
     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
                      subtractor.FilterImpulseResponse(),
-                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output,
-                     y);
+                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
   }
 
   for (int k = 0; k < 100; ++k) {
     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
                      subtractor.FilterImpulseResponse(),
-                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output,
-                     y);
+                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
                              &high_bands_gain, &g);
   }
@@ -122,8 +122,7 @@
   for (int k = 0; k < 100; ++k) {
     aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(),
                      subtractor.FilterImpulseResponse(),
-                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output,
-                     y);
+                     *render_delay_buffer->GetRenderBuffer(), E2, Y2, output);
     suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x,
                              &high_bands_gain, &g);
   }