AEC3: Add multichannel support to the residual echo estimator

This CL adds support for multichannel in the residual echo
estimator code. It also adds placeholder functionality in
the surrounding code to ensure that the residual echo
estimator receives the require inputs.

The changes in the CL has been shown to be bitexact on a
large set of mono recordings.

Bug: webrtc:10913
Change-Id: I726128ca928648b1dcf36c5f479eb243f3ff3f96
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/155361
Commit-Queue: Per Åhgren <peah@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29400}
diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc
index 97c27d5..4b30d30 100644
--- a/modules/audio_processing/aec3/aec_state.cc
+++ b/modules/audio_processing/aec3/aec_state.cc
@@ -65,7 +65,7 @@
       transparent_state_(config_),
       filter_quality_state_(config_),
       erl_estimator_(2 * kNumBlocksPerSecond),
-      erle_estimator_(2 * kNumBlocksPerSecond, config_),
+      erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels),
       filter_analyzer_(config_),
       echo_audibility_(
           config_.echo_audibility.use_stationarity_properties_at_init),
@@ -214,7 +214,7 @@
   reverb_model_estimator_.Dump(data_dumper_.get());
   data_dumper_->DumpRaw("aec3_erl", Erl());
   data_dumper_->DumpRaw("aec3_erl_time_domain", ErlTimeDomain());
-  data_dumper_->DumpRaw("aec3_erle", Erle());
+  data_dumper_->DumpRaw("aec3_erle", Erle()[0]);
   data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate());
   data_dumper_->DumpRaw("aec3_transparent_mode", TransparentMode());
   data_dumper_->DumpRaw("aec3_filter_delay", filter_analyzer_.DelayBlocks());
diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h
index 1229732..f860987 100644
--- a/modules/audio_processing/aec3/aec_state.h
+++ b/modules/audio_processing/aec3/aec_state.h
@@ -68,12 +68,12 @@
 
   // Returns whether the stationary properties of the signals are used in the
   // aec.
-  bool UseStationaryProperties() const {
+  bool UseStationarityProperties() const {
     return config_.echo_audibility.use_stationarity_properties;
   }
 
   // Returns the ERLE.
-  const std::array<float, kFftLengthBy2Plus1>& Erle() const {
+  rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle() const {
     return erle_estimator_.Erle();
   }
 
diff --git a/modules/audio_processing/aec3/aec_state_unittest.cc b/modules/audio_processing/aec3/aec_state_unittest.cc
index ccf953a..5997ab1 100644
--- a/modules/audio_processing/aec3/aec_state_unittest.cc
+++ b/modules/audio_processing/aec3/aec_state_unittest.cc
@@ -170,7 +170,7 @@
   {
     // Note that the render spectrum is built so it does not have energy in
     // the odd bands but just in the even bands.
-    const auto& erle = state.Erle();
+    const auto& erle = state.Erle()[0];
     EXPECT_EQ(erle[0], erle[1]);
     constexpr size_t kLowFrequencyLimit = 32;
     for (size_t k = 2; k < kLowFrequencyLimit; k = k + 2) {
@@ -195,7 +195,7 @@
 
   ASSERT_TRUE(state.UsableLinearEstimate());
   {
-    const auto& erle = state.Erle();
+    const auto& erle = state.Erle()[0];
     EXPECT_EQ(erle[0], erle[1]);
     constexpr size_t kLowFrequencyLimit = 32;
     for (size_t k = 1; k < kLowFrequencyLimit; ++k) {
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index c33b39c..31736bf 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -152,7 +152,7 @@
   std::vector<std::unique_ptr<ComfortNoiseGenerator>> cngs_;
   SuppressionFilter suppression_filter_;
   RenderSignalAnalyzer render_signal_analyzer_;
-  std::vector<std::unique_ptr<ResidualEchoEstimator>> residual_echo_estimators_;
+  ResidualEchoEstimator residual_echo_estimator_;
   bool echo_leakage_detected_ = false;
   AecState aec_state_;
   EchoRemoverMetrics metrics_;
@@ -201,7 +201,7 @@
                           sample_rate_hz_,
                           num_capture_channels_),
       render_signal_analyzer_(config_),
-      residual_echo_estimators_(num_capture_channels_),
+      residual_echo_estimator_(config_, num_render_channels),
       aec_state_(config_, num_capture_channels_),
       e_old_(num_capture_channels_),
       y_old_(num_capture_channels_),
@@ -222,8 +222,6 @@
 
   uint32_t cng_seed = 42;
   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
-    residual_echo_estimators_[ch] =
-        std::make_unique<ResidualEchoEstimator>(config_);
     suppression_gains_[ch] = std::make_unique<SuppressionGain>(
         config_, optimization_, sample_rate_hz);
     cngs_[ch] =
@@ -400,11 +398,11 @@
   std::array<float, kFftLengthBy2Plus1> G;
   G.fill(1.f);
 
-  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
-    // Estimate the residual echo power.
-    residual_echo_estimators_[ch]->Estimate(aec_state_, *render_buffer,
-                                            S2_linear[ch], Y2[ch], &R2[ch]);
+  // Estimate the residual echo power.
+  residual_echo_estimator_.Estimate(aec_state_, *render_buffer, S2_linear, Y2,
+                                    R2);
 
+  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
     // Estimate the comfort noise.
     cngs_[ch]->Compute(aec_state_, Y2[ch], &comfort_noise[ch],
                        &high_band_comfort_noise[ch]);
@@ -462,8 +460,6 @@
       "aec3_X2",
       render_buffer->Spectrum(aec_state_.FilterDelayBlocks(), /*channel=*/0));
   data_dumper_->DumpRaw("aec3_R2", R2[0]);
-  data_dumper_->DumpRaw("aec3_R2_reverb",
-                        residual_echo_estimators_[0]->GetReverbPowerSpectrum());
   data_dumper_->DumpRaw("aec3_filter_delay", aec_state_.FilterDelayBlocks());
   data_dumper_->DumpRaw("aec3_capture_saturation",
                         aec_state_.SaturatedCapture() ? 1 : 0);
diff --git a/modules/audio_processing/aec3/echo_remover_metrics.cc b/modules/audio_processing/aec3/echo_remover_metrics.cc
index 4590f85..4ab05f8 100644
--- a/modules/audio_processing/aec3/echo_remover_metrics.cc
+++ b/modules/audio_processing/aec3/echo_remover_metrics.cc
@@ -70,7 +70,7 @@
   if (++block_counter_ <= kMetricsCollectionBlocks) {
     aec3::UpdateDbMetric(aec_state.Erl(), &erl_);
     erl_time_domain_.UpdateInstant(aec_state.ErlTimeDomain());
-    aec3::UpdateDbMetric(aec_state.Erle(), &erle_);
+    aec3::UpdateDbMetric(aec_state.Erle()[0], &erle_);
     erle_time_domain_.UpdateInstant(aec_state.FullBandErleLog2());
     aec3::UpdateDbMetric(comfort_noise_spectrum, &comfort_noise_);
     aec3::UpdateDbMetric(suppressor_gain, &suppressor_gain_);
diff --git a/modules/audio_processing/aec3/erle_estimator.cc b/modules/audio_processing/aec3/erle_estimator.cc
index 656a9c7..17bb79d 100644
--- a/modules/audio_processing/aec3/erle_estimator.cc
+++ b/modules/audio_processing/aec3/erle_estimator.cc
@@ -16,12 +16,13 @@
 namespace webrtc {
 
 ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks_,
-                             const EchoCanceller3Config& config)
+                             const EchoCanceller3Config& config,
+                             size_t num_capture_channels)
     : startup_phase_length_blocks__(startup_phase_length_blocks_),
       use_signal_dependent_erle_(config.erle.num_sections > 1),
       fullband_erle_estimator_(config.erle.min, config.erle.max_l),
-      subband_erle_estimator_(config),
-      signal_dependent_erle_estimator_(config) {
+      subband_erle_estimator_(config, num_capture_channels),
+      signal_dependent_erle_estimator_(config, num_capture_channels) {
   Reset(true);
 }
 
diff --git a/modules/audio_processing/aec3/erle_estimator.h b/modules/audio_processing/aec3/erle_estimator.h
index 126774d..7f882ca 100644
--- a/modules/audio_processing/aec3/erle_estimator.h
+++ b/modules/audio_processing/aec3/erle_estimator.h
@@ -33,7 +33,8 @@
 class ErleEstimator {
  public:
   ErleEstimator(size_t startup_phase_length_blocks_,
-                const EchoCanceller3Config& config);
+                const EchoCanceller3Config& config,
+                size_t num_capture_channels);
   ~ErleEstimator();
 
   // Resets the fullband ERLE estimator and the subbands ERLE estimators.
@@ -50,10 +51,11 @@
               bool onset_detection);
 
   // Returns the most recent subband ERLE estimates.
-  const std::array<float, kFftLengthBy2Plus1>& Erle() const {
+  rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle() const {
     return use_signal_dependent_erle_ ? signal_dependent_erle_estimator_.Erle()
                                       : subband_erle_estimator_.Erle();
   }
+
   // Returns the subband ERLE that are estimated during onsets. Used
   // for logging/testing.
   rtc::ArrayView<const float> ErleOnsets() const {
diff --git a/modules/audio_processing/aec3/erle_estimator_unittest.cc b/modules/audio_processing/aec3/erle_estimator_unittest.cc
index e2af48b..e8f99bc 100644
--- a/modules/audio_processing/aec3/erle_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/erle_estimator_unittest.cc
@@ -113,22 +113,23 @@
   std::array<float, kFftLengthBy2Plus1> X2;
   std::array<float, kFftLengthBy2Plus1> E2;
   std::array<float, kFftLengthBy2Plus1> Y2;
-  constexpr size_t kNumChannels = 1;
+  constexpr size_t kNumRenderChannels = 1;
+  constexpr size_t kNumCaptureChannels = 1;
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
   EchoCanceller3Config config;
   std::vector<std::vector<std::vector<float>>> x(
       kNumBands, std::vector<std::vector<float>>(
-                     kNumChannels, std::vector<float>(kBlockSize, 0.f)));
+                     kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
   std::vector<std::array<float, kFftLengthBy2Plus1>> filter_frequency_response(
       config.filter.main.length_blocks);
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
+      RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
 
   GetFilterFreq(filter_frequency_response, config.delay.delay_headroom_samples);
 
-  ErleEstimator estimator(0, config);
+  ErleEstimator estimator(0, config, kNumCaptureChannels);
 
   FormFarendTimeFrame(&x);
   render_delay_buffer->Insert(x);
@@ -142,7 +143,7 @@
     estimator.Update(*render_delay_buffer->GetRenderBuffer(),
                      filter_frequency_response, X2, Y2, E2, true, true);
   }
-  VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
+  VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()),
              config.erle.max_l, config.erle.max_h);
 
   FormNearendFrame(&x, &X2, &E2, &Y2);
@@ -154,12 +155,13 @@
     estimator.Update(*render_delay_buffer->GetRenderBuffer(),
                      filter_frequency_response, X2, Y2, E2, true, true);
   }
-  VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
+  VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()),
              config.erle.max_l, config.erle.max_h);
 }
 
 TEST(ErleEstimator, VerifyErleTrackingOnOnsets) {
-  constexpr size_t kNumChannels = 1;
+  constexpr size_t kNumRenderChannels = 1;
+  constexpr size_t kNumCaptureChannels = 1;
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
   std::array<float, kFftLengthBy2Plus1> X2;
@@ -168,16 +170,16 @@
   EchoCanceller3Config config;
   std::vector<std::vector<std::vector<float>>> x(
       kNumBands, std::vector<std::vector<float>>(
-                     kNumChannels, std::vector<float>(kBlockSize, 0.f)));
+                     kNumRenderChannels, std::vector<float>(kBlockSize, 0.f)));
   std::vector<std::array<float, kFftLengthBy2Plus1>> filter_frequency_response(
       config.filter.main.length_blocks);
 
   std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
+      RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels));
 
   GetFilterFreq(filter_frequency_response, config.delay.delay_headroom_samples);
 
-  ErleEstimator estimator(0, config);
+  ErleEstimator estimator(0, config, kNumCaptureChannels);
 
   FormFarendTimeFrame(&x);
   render_delay_buffer->Insert(x);
@@ -215,7 +217,7 @@
                      filter_frequency_response, X2, Y2, E2, true, true);
   }
   // Verifies that during ne activity, Erle converges to the Erle for onsets.
-  VerifyErle(estimator.Erle(), std::pow(2.f, estimator.FullbandErleLog2()),
+  VerifyErle(estimator.Erle()[0], std::pow(2.f, estimator.FullbandErleLog2()),
              config.erle.min, config.erle.min);
 }
 
diff --git a/modules/audio_processing/aec3/render_reverb_model.cc b/modules/audio_processing/aec3/render_reverb_model.cc
index 1c6a7e8..0410a9a 100644
--- a/modules/audio_processing/aec3/render_reverb_model.cc
+++ b/modules/audio_processing/aec3/render_reverb_model.cc
@@ -36,10 +36,14 @@
   int idx_past = spectrum_buffer.IncIndex(idx_at_delay);
   const auto& X2 = spectrum_buffer.buffer[idx_at_delay][/*channel=*/0];
   RTC_DCHECK_EQ(X2.size(), reverb_power_spectrum.size());
-  std::copy(X2.begin(), X2.end(), reverb_power_spectrum.begin());
-  render_reverb_.AddReverbNoFreqShaping(
-      spectrum_buffer.buffer[idx_past][/*channel=*/0], 1.0f, reverb_decay,
-      reverb_power_spectrum);
+  render_reverb_.UpdateReverbNoFreqShaping(
+      spectrum_buffer.buffer[idx_past][/*channel=*/0], 1.0f, reverb_decay);
+
+  rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb_power =
+      render_reverb_.reverb();
+  for (size_t k = 0; k < X2.size(); ++k) {
+    reverb_power_spectrum[k] = X2[k] + reverb_power[k];
+  }
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/render_reverb_model.h b/modules/audio_processing/aec3/render_reverb_model.h
index a52351c..8859a90 100644
--- a/modules/audio_processing/aec3/render_reverb_model.h
+++ b/modules/audio_processing/aec3/render_reverb_model.h
@@ -37,7 +37,7 @@
   // Gets the reverberation spectrum that was added to the render spectrum for
   // computing the reverberation render spectrum.
   rtc::ArrayView<const float> GetReverbContributionPowerSpectrum() const {
-    return render_reverb_.GetPowerSpectrum();
+    return render_reverb_.reverb();
   }
 
  private:
diff --git a/modules/audio_processing/aec3/residual_echo_estimator.cc b/modules/audio_processing/aec3/residual_echo_estimator.cc
index e615d36..07197e3 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator.cc
+++ b/modules/audio_processing/aec3/residual_echo_estimator.cc
@@ -43,10 +43,114 @@
   *idx_stop = spectrum_buffer.OffsetIndex(spectrum_buffer.read, window_end + 1);
 }
 
+// Estimates the residual echo power based on the echo return loss enhancement
+// (ERLE) and the linear power estimate.
+void LinearEstimate(
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear,
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle,
+    rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) {
+  RTC_DCHECK_EQ(S2_linear.size(), erle.size());
+  RTC_DCHECK_EQ(S2_linear.size(), R2.size());
+
+  const size_t num_capture_channels = R2.size();
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+      RTC_DCHECK_LT(0.f, erle[ch][k]);
+      R2[ch][k] = S2_linear[ch][k] / erle[ch][k];
+    }
+  }
+}
+
+// Estimates the residual echo power based on an uncertainty estimate of the
+// echo return loss enhancement (ERLE) and the linear power estimate.
+void LinearEstimate(
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear,
+    float erle_uncertainty,
+    rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) {
+  RTC_DCHECK_EQ(S2_linear.size(), R2.size());
+
+  const size_t num_capture_channels = R2.size();
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+      R2[ch][k] = S2_linear[ch][k] * erle_uncertainty;
+    }
+  }
+}
+
+// Estimates the residual echo power based on the estimate of the echo path
+// gain.
+void NonLinearEstimate(
+    float echo_path_gain,
+    const std::array<float, kFftLengthBy2Plus1>& X2,
+    rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) {
+  const size_t num_capture_channels = R2.size();
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+      R2[ch][k] = X2[k] * echo_path_gain;
+    }
+  }
+}
+
+// Applies a soft noise gate to the echo generating power.
+void ApplyNoiseGate(const EchoCanceller3Config::EchoModel& config,
+                    rtc::ArrayView<float, kFftLengthBy2Plus1> X2) {
+  for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+    if (config.noise_gate_power > X2[k]) {
+      X2[k] = std::max(0.f, X2[k] - config.noise_gate_slope *
+                                        (config.noise_gate_power - X2[k]));
+    }
+  }
+}
+
+// Estimates the echo generating signal power as gated maximal power over a
+// time window.
+void EchoGeneratingPower(size_t num_render_channels,
+                         const SpectrumBuffer& spectrum_buffer,
+                         const EchoCanceller3Config::EchoModel& echo_model,
+                         int filter_delay_blocks,
+                         rtc::ArrayView<float, kFftLengthBy2Plus1> X2) {
+  int idx_stop;
+  int idx_start;
+  GetRenderIndexesToAnalyze(spectrum_buffer, echo_model, filter_delay_blocks,
+                            &idx_start, &idx_stop);
+
+  std::fill(X2.begin(), X2.end(), 0.f);
+  if (num_render_channels == 1) {
+    for (int k = idx_start; k != idx_stop; k = spectrum_buffer.IncIndex(k)) {
+      for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) {
+        X2[j] = std::max(X2[j], spectrum_buffer.buffer[k][/*channel=*/0][j]);
+      }
+    }
+  } else {
+    for (int k = idx_start; k != idx_stop; k = spectrum_buffer.IncIndex(k)) {
+      std::array<float, kFftLengthBy2Plus1> render_power;
+      render_power.fill(0.f);
+      for (size_t ch = 0; ch < num_render_channels; ++ch) {
+        const auto& channel_power = spectrum_buffer.buffer[k][ch];
+        for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) {
+          render_power[j] += channel_power[j];
+        }
+      }
+      for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) {
+        X2[j] = std::max(X2[j], render_power[j]);
+      }
+    }
+  }
+}
+
+// Chooses the echo path gain to use.
+float GetEchoPathGain(const AecState& aec_state,
+                      const EchoCanceller3Config::EpStrength& config) {
+  float gain_amplitude =
+      aec_state.TransparentMode() ? 0.01f : config.default_gain;
+  return gain_amplitude * gain_amplitude;
+}
+
 }  // namespace
 
-ResidualEchoEstimator::ResidualEchoEstimator(const EchoCanceller3Config& config)
-    : config_(config) {
+ResidualEchoEstimator::ResidualEchoEstimator(const EchoCanceller3Config& config,
+                                             size_t num_render_channels)
+    : config_(config), num_render_channels_(num_render_channels) {
   Reset();
 }
 
@@ -55,72 +159,78 @@
 void ResidualEchoEstimator::Estimate(
     const AecState& aec_state,
     const RenderBuffer& render_buffer,
-    const std::array<float, kFftLengthBy2Plus1>& S2_linear,
-    const std::array<float, kFftLengthBy2Plus1>& Y2,
-    std::array<float, kFftLengthBy2Plus1>* R2) {
-  RTC_DCHECK(R2);
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear,
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
+    rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) {
+  RTC_DCHECK_EQ(R2.size(), Y2.size());
+  RTC_DCHECK_EQ(R2.size(), S2_linear.size());
+
+  const size_t num_capture_channels = R2.size();
 
   // Estimate the power of the stationary noise in the render signal.
-  RenderNoisePower(render_buffer, &X2_noise_floor_, &X2_noise_floor_counter_);
+  UpdateRenderNoisePower(render_buffer);
 
   // Estimate the residual echo power.
   if (aec_state.UsableLinearEstimate()) {
-    LinearEstimate(S2_linear, aec_state.Erle(), aec_state.ErleUncertainty(),
-                   R2);
-
     // When there is saturated echo, assume the same spectral content as is
     // present in the microphone signal.
     if (aec_state.SaturatedEcho()) {
-      std::copy(Y2.begin(), Y2.end(), R2->begin());
+      for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+        std::copy(Y2[ch].begin(), Y2[ch].end(), R2[ch].begin());
+      }
+    } else {
+      absl::optional<float> erle_uncertainty = aec_state.ErleUncertainty();
+      if (erle_uncertainty) {
+        LinearEstimate(S2_linear, *erle_uncertainty, R2);
+      } else {
+        LinearEstimate(S2_linear, aec_state.Erle(), R2);
+      }
     }
 
-    // Adds the estimated unmodelled echo power to the residual echo power
-    // estimate.
-    echo_reverb_.AddReverb(
-        render_buffer.Spectrum(aec_state.FilterLengthBlocks() + 1,
-                               /*channel=*/0),
-        aec_state.GetReverbFrequencyResponse(), aec_state.ReverbDecay(), *R2);
+    AddReverb(ReverbType::kLinear, aec_state, render_buffer, R2);
   } else {
-    // Estimate the echo generating signal power.
-    std::array<float, kFftLengthBy2Plus1> X2;
-
-    EchoGeneratingPower(render_buffer.GetSpectrumBuffer(), config_.echo_model,
-                        aec_state.FilterDelayBlocks(),
-                        !aec_state.UseStationaryProperties(), &X2);
-
-    // Subtract the stationary noise power to avoid stationary noise causing
-    // excessive echo suppression.
-    std::transform(X2.begin(), X2.end(), X2_noise_floor_.begin(), X2.begin(),
-                   [&](float a, float b) {
-                     return std::max(
-                         0.f, a - config_.echo_model.stationary_gate_slope * b);
-                   });
-
-    float echo_path_gain;
-    echo_path_gain =
-        aec_state.TransparentMode() ? 0.01f : config_.ep_strength.default_gain;
-    NonLinearEstimate(echo_path_gain, X2, R2);
+    const float echo_path_gain =
+        GetEchoPathGain(aec_state, config_.ep_strength);
 
     // When there is saturated echo, assume the same spectral content as is
     // present in the microphone signal.
     if (aec_state.SaturatedEcho()) {
-      std::copy(Y2.begin(), Y2.end(), R2->begin());
+      for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+        std::copy(Y2[ch].begin(), Y2[ch].end(), R2[ch].begin());
+      }
+    } else {
+      // Estimate the echo generating signal power.
+      std::array<float, kFftLengthBy2Plus1> X2;
+      EchoGeneratingPower(num_render_channels_,
+                          render_buffer.GetSpectrumBuffer(), config_.echo_model,
+                          aec_state.FilterDelayBlocks(), X2);
+      if (!aec_state.UseStationarityProperties()) {
+        ApplyNoiseGate(config_.echo_model, X2);
+      }
+
+      // Subtract the stationary noise power to avoid stationary noise causing
+      // excessive echo suppression.
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        X2[k] -= config_.echo_model.stationary_gate_slope * X2_noise_floor_[k];
+        X2[k] = std::max(0.f, X2[k]);
+      }
+
+      NonLinearEstimate(echo_path_gain, X2, R2);
     }
 
-    if (!(aec_state.TransparentMode())) {
-      echo_reverb_.AddReverbNoFreqShaping(
-          render_buffer.Spectrum(aec_state.FilterDelayBlocks() + 1,
-                                 /*channel=*/0),
-          echo_path_gain * echo_path_gain, aec_state.ReverbDecay(), *R2);
+    if (!aec_state.TransparentMode()) {
+      AddReverb(ReverbType::kNonLinear, aec_state, render_buffer, R2);
     }
   }
 
-  if (aec_state.UseStationaryProperties()) {
+  if (aec_state.UseStationarityProperties()) {
     // Scale the echo according to echo audibility.
     std::array<float, kFftLengthBy2Plus1> residual_scaling;
     aec_state.GetResidualEchoScaling(residual_scaling);
-    for (size_t k = 0; k < R2->size(); ++k) {
-      (*R2)[k] *= residual_scaling[k];
+    for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        R2[ch][k] *= residual_scaling[k];
+      }
     }
   }
 }
@@ -131,94 +241,97 @@
   X2_noise_floor_.fill(config_.echo_model.min_noise_floor_power);
 }
 
-void ResidualEchoEstimator::LinearEstimate(
-    const std::array<float, kFftLengthBy2Plus1>& S2_linear,
-    const std::array<float, kFftLengthBy2Plus1>& erle,
-    absl::optional<float> erle_uncertainty,
-    std::array<float, kFftLengthBy2Plus1>* R2) {
-  if (erle_uncertainty) {
-    for (size_t k = 0; k < R2->size(); ++k) {
-      (*R2)[k] = S2_linear[k] * *erle_uncertainty;
-    }
+void ResidualEchoEstimator::UpdateRenderNoisePower(
+    const RenderBuffer& render_buffer) {
+  std::array<float, kFftLengthBy2Plus1> render_power_data;
+  rtc::ArrayView<const float> render_power;
+  if (num_render_channels_ == 1) {
+    render_power = render_buffer.Spectrum(0, /*channel=*/0);
   } else {
-    std::transform(erle.begin(), erle.end(), S2_linear.begin(), R2->begin(),
-                   [](float a, float b) {
-                     RTC_DCHECK_LT(0.f, a);
-                     return b / a;
-                   });
-  }
-}
-
-void ResidualEchoEstimator::NonLinearEstimate(
-    float echo_path_gain,
-    const std::array<float, kFftLengthBy2Plus1>& X2,
-    std::array<float, kFftLengthBy2Plus1>* R2) {
-  // Compute preliminary residual echo.
-  std::transform(X2.begin(), X2.end(), R2->begin(), [echo_path_gain](float a) {
-    return a * echo_path_gain * echo_path_gain;
-  });
-}
-
-void ResidualEchoEstimator::EchoGeneratingPower(
-    const SpectrumBuffer& spectrum_buffer,
-    const EchoCanceller3Config::EchoModel& echo_model,
-    int filter_delay_blocks,
-    bool apply_noise_gating,
-    std::array<float, kFftLengthBy2Plus1>* X2) const {
-  int idx_stop, idx_start;
-
-  RTC_DCHECK(X2);
-  GetRenderIndexesToAnalyze(spectrum_buffer, config_.echo_model,
-                            filter_delay_blocks, &idx_start, &idx_stop);
-
-  X2->fill(0.f);
-  for (int k = idx_start; k != idx_stop; k = spectrum_buffer.IncIndex(k)) {
-    std::transform(X2->begin(), X2->end(),
-                   spectrum_buffer.buffer[k][/*channel=*/0].begin(),
-                   X2->begin(),
-                   [](float a, float b) { return std::max(a, b); });
-  }
-
-  if (apply_noise_gating) {
-    // Apply soft noise gate.
-    std::for_each(X2->begin(), X2->end(), [&](float& a) {
-      if (config_.echo_model.noise_gate_power > a) {
-        a = std::max(0.f, a - config_.echo_model.noise_gate_slope *
-                                  (config_.echo_model.noise_gate_power - a));
+    render_power_data.fill(0.f);
+    for (size_t ch = 0; ch < num_render_channels_; ++ch) {
+      const auto& channel_power = render_buffer.Spectrum(0, ch);
+      RTC_DCHECK_EQ(channel_power.size(), kFftLengthBy2Plus1);
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        render_power_data[k] += channel_power[k];
       }
-    });
+    }
+    render_power = render_power_data;
   }
-}
-
-void ResidualEchoEstimator::RenderNoisePower(
-    const RenderBuffer& render_buffer,
-    std::array<float, kFftLengthBy2Plus1>* X2_noise_floor,
-    std::array<int, kFftLengthBy2Plus1>* X2_noise_floor_counter) const {
-  RTC_DCHECK(X2_noise_floor);
-  RTC_DCHECK(X2_noise_floor_counter);
-
-  const auto render_power = render_buffer.Spectrum(0, /*channel=*/0);
-  RTC_DCHECK_EQ(X2_noise_floor->size(), render_power.size());
-  RTC_DCHECK_EQ(X2_noise_floor_counter->size(), render_power.size());
+  RTC_DCHECK_EQ(render_power.size(), kFftLengthBy2Plus1);
 
   // Estimate the stationary noise power in a minimum statistics manner.
-  for (size_t k = 0; k < render_power.size(); ++k) {
+  for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
     // Decrease rapidly.
-    if (render_power[k] < (*X2_noise_floor)[k]) {
-      (*X2_noise_floor)[k] = render_power[k];
-      (*X2_noise_floor_counter)[k] = 0;
+    if (render_power[k] < X2_noise_floor_[k]) {
+      X2_noise_floor_[k] = render_power[k];
+      X2_noise_floor_counter_[k] = 0;
     } else {
       // Increase in a delayed, leaky manner.
-      if ((*X2_noise_floor_counter)[k] >=
+      if (X2_noise_floor_counter_[k] >=
           static_cast<int>(config_.echo_model.noise_floor_hold)) {
-        (*X2_noise_floor)[k] =
-            std::max((*X2_noise_floor)[k] * 1.1f,
-                     config_.echo_model.min_noise_floor_power);
+        X2_noise_floor_[k] = std::max(X2_noise_floor_[k] * 1.1f,
+                                      config_.echo_model.min_noise_floor_power);
       } else {
-        ++(*X2_noise_floor_counter)[k];
+        ++X2_noise_floor_counter_[k];
       }
     }
   }
 }
 
+// Adds the estimated power of the reverb to the residual echo power.
+void ResidualEchoEstimator::AddReverb(
+    ReverbType reverb_type,
+    const AecState& aec_state,
+    const RenderBuffer& render_buffer,
+    rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) {
+  const size_t num_capture_channels = R2.size();
+
+  // Choose reverb partition based on what type of echo power model is used.
+  const size_t first_reverb_partition = reverb_type == ReverbType::kLinear
+                                            ? aec_state.FilterLengthBlocks() + 1
+                                            : aec_state.FilterDelayBlocks() + 1;
+
+  // Compute render power for the reverb.
+  std::array<float, kFftLengthBy2Plus1> render_power_data;
+  rtc::ArrayView<const float> render_power;
+  if (num_render_channels_ == 1) {
+    render_power =
+        render_buffer.Spectrum(first_reverb_partition, /*channel=*/0);
+  } else {
+    render_power_data.fill(0.f);
+    for (size_t ch = 0; ch < num_render_channels_; ++ch) {
+      const auto& channel_power =
+          render_buffer.Spectrum(first_reverb_partition, ch);
+      RTC_DCHECK_EQ(channel_power.size(), kFftLengthBy2Plus1);
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        render_power_data[k] += channel_power[k];
+      }
+    }
+    render_power = render_power_data;
+  }
+  RTC_DCHECK_EQ(render_power.size(), kFftLengthBy2Plus1);
+
+  // Update the reverb estimate.
+  if (reverb_type == ReverbType::kLinear) {
+    echo_reverb_.UpdateReverb(render_power,
+                              aec_state.GetReverbFrequencyResponse(),
+                              aec_state.ReverbDecay());
+  } else {
+    const float echo_path_gain =
+        GetEchoPathGain(aec_state, config_.ep_strength);
+    echo_reverb_.UpdateReverbNoFreqShaping(render_power, echo_path_gain,
+                                           aec_state.ReverbDecay());
+  }
+
+  // Add the reverb power.
+  rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb_power =
+      echo_reverb_.reverb();
+  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+      R2[ch][k] += reverb_power[k];
+    }
+  }
+}
+
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/residual_echo_estimator.h b/modules/audio_processing/aec3/residual_echo_estimator.h
index e340918..5c14bdb 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator.h
+++ b/modules/audio_processing/aec3/residual_echo_estimator.h
@@ -22,63 +22,47 @@
 #include "modules/audio_processing/aec3/reverb_model.h"
 #include "modules/audio_processing/aec3/spectrum_buffer.h"
 #include "rtc_base/checks.h"
-#include "rtc_base/constructor_magic.h"
 
 namespace webrtc {
 
 class ResidualEchoEstimator {
  public:
-  explicit ResidualEchoEstimator(const EchoCanceller3Config& config);
+  ResidualEchoEstimator(const EchoCanceller3Config& config,
+                        size_t num_render_channels);
   ~ResidualEchoEstimator();
 
-  void Estimate(const AecState& aec_state,
-                const RenderBuffer& render_buffer,
-                const std::array<float, kFftLengthBy2Plus1>& S2_linear,
-                const std::array<float, kFftLengthBy2Plus1>& Y2,
-                std::array<float, kFftLengthBy2Plus1>* R2);
+  ResidualEchoEstimator(const ResidualEchoEstimator&) = delete;
+  ResidualEchoEstimator& operator=(const ResidualEchoEstimator&) = delete;
 
-  // Returns the reverberant power spectrum contributions to the echo residual.
-  rtc::ArrayView<const float> GetReverbPowerSpectrum() const {
-    return echo_reverb_.GetPowerSpectrum();
-  }
+  void Estimate(
+      const AecState& aec_state,
+      const RenderBuffer& render_buffer,
+      rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear,
+      rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
+      rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2);
 
  private:
+  enum class ReverbType { kLinear, kNonLinear };
+
   // Resets the state.
   void Reset();
 
-  // Estimates the residual echo power based on the echo return loss enhancement
-  // (ERLE) and the linear power estimate.
-  void LinearEstimate(const std::array<float, kFftLengthBy2Plus1>& S2_linear,
-                      const std::array<float, kFftLengthBy2Plus1>& erle,
-                      absl::optional<float> erle_uncertainty,
-                      std::array<float, kFftLengthBy2Plus1>* R2);
-
-  // Estimates the residual echo power based on the estimate of the echo path
-  // gain.
-  void NonLinearEstimate(float echo_path_gain,
-                         const std::array<float, kFftLengthBy2Plus1>& X2,
-                         std::array<float, kFftLengthBy2Plus1>* R2);
-
-  // Estimates the echo generating signal power as gated maximal power over a
-  // time window.
-  void EchoGeneratingPower(const SpectrumBuffer& spectrum_buffer,
-                           const EchoCanceller3Config::EchoModel& echo_model,
-                           int filter_delay_blocks,
-                           bool apply_noise_gating,
-                           std::array<float, kFftLengthBy2Plus1>* X2) const;
-
   // Updates estimate for the power of the stationary noise component in the
   // render signal.
-  void RenderNoisePower(
-      const RenderBuffer& render_buffer,
-      std::array<float, kFftLengthBy2Plus1>* X2_noise_floor,
-      std::array<int, kFftLengthBy2Plus1>* X2_noise_floor_counter) const;
+  void UpdateRenderNoisePower(const RenderBuffer& render_buffer);
+
+  // Adds the estimated unmodelled echo power to the residual echo power
+  // estimate.
+  void AddReverb(ReverbType reverb_type,
+                 const AecState& aec_state,
+                 const RenderBuffer& render_buffer,
+                 rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2);
 
   const EchoCanceller3Config config_;
+  const size_t num_render_channels_;
   std::array<float, kFftLengthBy2Plus1> X2_noise_floor_;
   std::array<int, kFftLengthBy2Plus1> X2_noise_floor_counter_;
   ReverbModel echo_reverb_;
-  RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(ResidualEchoEstimator);
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc
index 2823cae..55f634b 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc
@@ -20,98 +20,73 @@
 
 namespace webrtc {
 
-#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
+TEST(ResidualEchoEstimator, BasicTest) {
+  for (size_t num_render_channels : {1, 2, 4}) {
+    for (size_t num_capture_channels : {1, 2, 4}) {
+      constexpr int kSampleRateHz = 48000;
+      constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
-// Verifies that the check for non-null output residual echo power works.
-TEST(ResidualEchoEstimator, NullResidualEchoPowerOutput) {
-  EchoCanceller3Config config;
-  AecState aec_state(config, 1);
-  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, 48000, 1));
-  std::vector<std::array<float, kFftLengthBy2Plus1>> H2;
-  std::array<float, kFftLengthBy2Plus1> S2_linear;
-  std::array<float, kFftLengthBy2Plus1> Y2;
-  EXPECT_DEATH(ResidualEchoEstimator(EchoCanceller3Config{})
-                   .Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(),
-                             S2_linear, Y2, nullptr),
-               "");
-}
+      EchoCanceller3Config config;
+      ResidualEchoEstimator estimator(config, num_render_channels);
+      AecState aec_state(config, num_render_channels);
+      std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
+          RenderDelayBuffer::Create(config, kSampleRateHz,
+                                    num_render_channels));
 
-#endif
+      std::array<float, kFftLengthBy2Plus1> E2_main;
+      std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear(
+          num_capture_channels);
+      std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
+          num_capture_channels);
+      std::vector<std::array<float, kFftLengthBy2Plus1>> R2(
+          num_capture_channels);
+      std::vector<std::vector<std::vector<float>>> x(
+          kNumBands,
+          std::vector<std::vector<float>>(num_render_channels,
+                                          std::vector<float>(kBlockSize, 0.f)));
+      std::vector<std::array<float, kFftLengthBy2Plus1>> H2(10);
+      Random random_generator(42U);
+      std::vector<SubtractorOutput> output(num_render_channels);
+      std::array<float, kBlockSize> y;
+      absl::optional<DelayEstimate> delay_estimate;
 
-// TODO(peah): This test is broken in the sense that it not at all tests what it
-// seems to test. Enable the test once that is adressed.
-TEST(ResidualEchoEstimator, DISABLED_BasicTest) {
-  constexpr size_t kNumChannels = 1;
-  constexpr int kSampleRateHz = 48000;
-  constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
+      for (auto& H2_k : H2) {
+        H2_k.fill(0.01f);
+      }
+      H2[2].fill(10.f);
+      H2[2][0] = 0.1f;
 
-  EchoCanceller3Config config;
-  config.ep_strength.default_len = 0.f;
-  ResidualEchoEstimator estimator(config);
-  AecState aec_state(config, kNumChannels);
-  std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
-      RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
+      std::vector<float> h(
+          GetTimeDomainLength(config.filter.main.length_blocks), 0.f);
 
-  std::array<float, kFftLengthBy2Plus1> E2_main;
-  std::array<float, kFftLengthBy2Plus1> E2_shadow;
-  std::array<float, kFftLengthBy2Plus1> S2_linear;
-  std::array<float, kFftLengthBy2Plus1> S2_fallback;
-  std::array<float, kFftLengthBy2Plus1> Y2;
-  std::array<float, kFftLengthBy2Plus1> R2;
-  EchoPathVariability echo_path_variability(
-      false, EchoPathVariability::DelayAdjustment::kNone, false);
-  std::vector<std::vector<std::vector<float>>> x(
-      kNumBands, std::vector<std::vector<float>>(
-                     kNumChannels, std::vector<float>(kBlockSize, 0.f)));
-  std::vector<std::array<float, kFftLengthBy2Plus1>> H2(10);
-  Random random_generator(42U);
-  std::vector<SubtractorOutput> output(kNumChannels);
-  std::array<float, kBlockSize> y;
-  Aec3Fft fft;
-  absl::optional<DelayEstimate> delay_estimate;
+      for (auto& subtractor_output : output) {
+        subtractor_output.Reset();
+        subtractor_output.s_main.fill(100.f);
+      }
+      y.fill(0.f);
 
-  for (auto& H2_k : H2) {
-    H2_k.fill(0.01f);
-  }
-  H2[2].fill(10.f);
-  H2[2][0] = 0.1f;
+      constexpr float kLevel = 10.f;
+      E2_main.fill(kLevel);
+      S2_linear[0].fill(kLevel);
+      Y2[0].fill(kLevel);
 
-  std::vector<float> h(GetTimeDomainLength(config.filter.main.length_blocks),
-                       0.f);
+      for (int k = 0; k < 1993; ++k) {
+        RandomizeSampleVector(&random_generator, x[0][0]);
+        render_delay_buffer->Insert(x);
+        if (k == 0) {
+          render_delay_buffer->Reset();
+        }
+        render_delay_buffer->PrepareCaptureProcessing();
 
-  for (auto& subtractor_output : output) {
-    subtractor_output.Reset();
-    subtractor_output.s_main.fill(100.f);
-  }
-  y.fill(0.f);
+        aec_state.Update(delay_estimate, H2, h,
+                         *render_delay_buffer->GetRenderBuffer(), E2_main,
+                         Y2[0], output);
 
-  constexpr float kLevel = 10.f;
-  E2_shadow.fill(kLevel);
-  E2_main.fill(kLevel);
-  S2_linear.fill(kLevel);
-  S2_fallback.fill(kLevel);
-  Y2.fill(kLevel);
-
-  for (int k = 0; k < 1993; ++k) {
-    RandomizeSampleVector(&random_generator, x[0][0]);
-    std::for_each(x[0][0].begin(), x[0][0].end(), [](float& a) { a /= 30.f; });
-    render_delay_buffer->Insert(x);
-    if (k == 0) {
-      render_delay_buffer->Reset();
+        estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(),
+                           S2_linear, Y2, R2);
+      }
     }
-    render_delay_buffer->PrepareCaptureProcessing();
-
-    aec_state.HandleEchoPathChange(echo_path_variability);
-    aec_state.Update(delay_estimate, H2, h,
-                     *render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
-                     output);
-
-    estimator.Estimate(aec_state, *render_delay_buffer->GetRenderBuffer(),
-                       S2_linear, Y2, &R2);
   }
-  std::for_each(R2.begin(), R2.end(),
-                [&](float a) { EXPECT_NEAR(kLevel, a, 0.1f); });
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/reverb_model.cc b/modules/audio_processing/aec3/reverb_model.cc
index ca65960..e4f3507 100644
--- a/modules/audio_processing/aec3/reverb_model.cc
+++ b/modules/audio_processing/aec3/reverb_model.cc
@@ -29,34 +29,7 @@
   reverb_.fill(0.);
 }
 
-void ReverbModel::AddReverbNoFreqShaping(
-    rtc::ArrayView<const float> power_spectrum,
-    float power_spectrum_scaling,
-    float reverb_decay,
-    rtc::ArrayView<float> reverb_power_spectrum) {
-  UpdateReverbContributionsNoFreqShaping(power_spectrum, power_spectrum_scaling,
-                                         reverb_decay);
-
-  // Add the power of the echo reverb to the residual echo power.
-  std::transform(reverb_power_spectrum.begin(), reverb_power_spectrum.end(),
-                 reverb_.begin(), reverb_power_spectrum.begin(),
-                 std::plus<float>());
-}
-
-void ReverbModel::AddReverb(rtc::ArrayView<const float> power_spectrum,
-                            rtc::ArrayView<const float> power_spectrum_scaling,
-                            float reverb_decay,
-                            rtc::ArrayView<float> reverb_power_spectrum) {
-  UpdateReverbContributions(power_spectrum, power_spectrum_scaling,
-                            reverb_decay);
-
-  // Add the power of the echo reverb to the residual echo power.
-  std::transform(reverb_power_spectrum.begin(), reverb_power_spectrum.end(),
-                 reverb_.begin(), reverb_power_spectrum.begin(),
-                 std::plus<float>());
-}
-
-void ReverbModel::UpdateReverbContributionsNoFreqShaping(
+void ReverbModel::UpdateReverbNoFreqShaping(
     rtc::ArrayView<const float> power_spectrum,
     float power_spectrum_scaling,
     float reverb_decay) {
@@ -69,9 +42,9 @@
   }
 }
 
-void ReverbModel::UpdateReverbContributions(
-    rtc::ArrayView<const float>& power_spectrum,
-    rtc::ArrayView<const float>& power_spectrum_scaling,
+void ReverbModel::UpdateReverb(
+    rtc::ArrayView<const float> power_spectrum,
+    rtc::ArrayView<const float> power_spectrum_scaling,
     float reverb_decay) {
   if (reverb_decay > 0) {
     // Update the estimate of the reverberant power.
diff --git a/modules/audio_processing/aec3/reverb_model.h b/modules/audio_processing/aec3/reverb_model.h
index 56e2266..5ba5485 100644
--- a/modules/audio_processing/aec3/reverb_model.h
+++ b/modules/audio_processing/aec3/reverb_model.h
@@ -28,37 +28,27 @@
   // Resets the state.
   void Reset();
 
-  // The methods AddReverbNoFreqShaping and AddReverb add the reverberation
-  // contribution to an input/output power spectrum
-  // Before applying the exponential reverberant model, the input power spectrum
-  // is pre-scaled. Use the method AddReverb when a different scaling should be
-  // applied per frequency and AddReverb_no_freq_shape if the same scaling
-  // should be used for all the frequencies.
-  void AddReverbNoFreqShaping(rtc::ArrayView<const float> power_spectrum,
-                              float power_spectrum_scaling,
-                              float reverb_decay,
-                              rtc::ArrayView<float> reverb_power_spectrum);
+  // Returns the reverb.
+  rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb() const {
+    return reverb_;
+  }
 
-  void AddReverb(rtc::ArrayView<const float> power_spectrum,
-                 rtc::ArrayView<const float> freq_response_tail,
-                 float reverb_decay,
-                 rtc::ArrayView<float> reverb_power_spectrum);
+  // The methods UpdateReverbNoFreqShaping and UpdateReverb update the
+  // estimate of the reverberation contribution to an input/output power
+  // spectrum. Before applying the exponential reverberant model, the input
+  // power spectrum is pre-scaled. Use the method UpdateReverb when a different
+  // scaling should be applied per frequency and UpdateReverb_no_freq_shape if
+  // the same scaling should be used for all the frequencies.
+  void UpdateReverbNoFreqShaping(rtc::ArrayView<const float> power_spectrum,
+                                 float power_spectrum_scaling,
+                                 float reverb_decay);
 
-  // Updates the reverberation contributions without applying any shaping of the
-  // spectrum.
-  void UpdateReverbContributionsNoFreqShaping(
-      rtc::ArrayView<const float> power_spectrum,
-      float power_spectrum_scaling,
-      float reverb_decay);
-
-  // Returns the current power spectrum reverberation contributions.
-  rtc::ArrayView<const float> GetPowerSpectrum() const { return reverb_; }
+  // Update the reverb based on new data.
+  void UpdateReverb(rtc::ArrayView<const float> power_spectrum,
+                    rtc::ArrayView<const float> power_spectrum_scaling,
+                    float reverb_decay);
 
  private:
-  // Updates the reverberation contributions.
-  void UpdateReverbContributions(rtc::ArrayView<const float>& power_spectrum,
-                                 rtc::ArrayView<const float>& freq_resp_tail,
-                                 float reverb_decay);
 
   std::array<float, kFftLengthBy2Plus1> reverb_;
 };
diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc b/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc
index e603675..d3c07a1 100644
--- a/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc
+++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc
@@ -118,7 +118,8 @@
 }  // namespace
 
 SignalDependentErleEstimator::SignalDependentErleEstimator(
-    const EchoCanceller3Config& config)
+    const EchoCanceller3Config& config,
+    size_t num_capture_channels)
     : min_erle_(config.erle.min),
       num_sections_(config.erle.num_sections),
       num_blocks_(config.filter.main.length_blocks),
@@ -130,6 +131,7 @@
       section_boundaries_blocks_(SetSectionsBoundaries(delay_headroom_blocks_,
                                                        num_blocks_,
                                                        num_sections_)),
+      erle_(num_capture_channels),
       S2_section_accum_(num_sections_),
       erle_estimators_(num_sections_),
       correction_factors_(num_sections_) {
@@ -142,10 +144,12 @@
 SignalDependentErleEstimator::~SignalDependentErleEstimator() = default;
 
 void SignalDependentErleEstimator::Reset() {
-  erle_.fill(min_erle_);
-  for (auto& erle : erle_estimators_) {
+  for (auto& erle : erle_) {
     erle.fill(min_erle_);
   }
+  for (auto& erle_estimator : erle_estimators_) {
+    erle_estimator.fill(min_erle_);
+  }
   erle_ref_.fill(min_erle_);
   for (auto& factor : correction_factors_) {
     factor.fill(1.0f);
@@ -166,7 +170,7 @@
     rtc::ArrayView<const float> X2,
     rtc::ArrayView<const float> Y2,
     rtc::ArrayView<const float> E2,
-    rtc::ArrayView<const float> average_erle,
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle,
     bool converged_filter) {
   RTC_DCHECK_GT(num_sections_, 1);
 
@@ -187,8 +191,8 @@
   for (size_t k = 0; k < kFftLengthBy2; ++k) {
     float correction_factor =
         correction_factors_[n_active_sections[k]][band_to_subband_[k]];
-    erle_[k] = rtc::SafeClamp(average_erle[k] * correction_factor, min_erle_,
-                              max_erle_[band_to_subband_[k]]);
+    erle_[0][k] = rtc::SafeClamp(average_erle[0][k] * correction_factor,
+                                 min_erle_, max_erle_[band_to_subband_[k]]);
   }
 }
 
diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator.h b/modules/audio_processing/aec3/signal_dependent_erle_estimator.h
index d8b56c2..da0b8ab 100644
--- a/modules/audio_processing/aec3/signal_dependent_erle_estimator.h
+++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator.h
@@ -29,25 +29,29 @@
 // this class receive as an input.
 class SignalDependentErleEstimator {
  public:
-  explicit SignalDependentErleEstimator(const EchoCanceller3Config& config);
+  SignalDependentErleEstimator(const EchoCanceller3Config& config,
+                               size_t num_capture_channels);
 
   ~SignalDependentErleEstimator();
 
   void Reset();
 
   // Returns the Erle per frequency subband.
-  const std::array<float, kFftLengthBy2Plus1>& Erle() const { return erle_; }
+  rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle() const {
+    return erle_;
+  }
 
   // Updates the Erle estimate. The Erle that is passed as an input is required
   // to be an estimation of the average Erle achieved by the linear filter.
-  void Update(const RenderBuffer& render_buffer,
-              const std::vector<std::array<float, kFftLengthBy2Plus1>>&
-                  filter_frequency_response,
-              rtc::ArrayView<const float> X2,
-              rtc::ArrayView<const float> Y2,
-              rtc::ArrayView<const float> E2,
-              rtc::ArrayView<const float> average_erle,
-              bool converged_filter);
+  void Update(
+      const RenderBuffer& render_buffer,
+      const std::vector<std::array<float, kFftLengthBy2Plus1>>&
+          filter_frequency_response,
+      rtc::ArrayView<const float> X2,
+      rtc::ArrayView<const float> Y2,
+      rtc::ArrayView<const float> E2,
+      rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle,
+      bool converged_filter);
 
   void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const;
 
@@ -80,7 +84,7 @@
   const std::array<size_t, kFftLengthBy2Plus1> band_to_subband_;
   const std::array<float, kSubbands> max_erle_;
   const std::vector<size_t> section_boundaries_blocks_;
-  std::array<float, kFftLengthBy2Plus1> erle_;
+  std::vector<std::array<float, kFftLengthBy2Plus1>> erle_;
   std::vector<std::array<float, kFftLengthBy2Plus1>> S2_section_accum_;
   std::vector<std::array<float, kSubbands>> erle_estimators_;
   std::array<float, kSubbands> erle_ref_;
diff --git a/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc b/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc
index 7baa8f0..ccc2ef3 100644
--- a/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc
@@ -112,6 +112,7 @@
 }  // namespace
 
 TEST(SignalDependentErleEstimator, SweepSettings) {
+  const size_t kNumCaptureChannels = 1;
   EchoCanceller3Config cfg;
   size_t max_length_blocks = 50;
   for (size_t blocks = 0; blocks < max_length_blocks; blocks = blocks + 10) {
@@ -124,9 +125,12 @@
         cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize;
         cfg.erle.num_sections = num_sections;
         if (EchoCanceller3Config::Validate(&cfg)) {
-          SignalDependentErleEstimator s(cfg);
-          std::array<float, kFftLengthBy2Plus1> average_erle;
-          average_erle.fill(cfg.erle.max_l);
+          SignalDependentErleEstimator s(cfg, kNumCaptureChannels);
+          std::array<std::array<float, kFftLengthBy2Plus1>, kNumCaptureChannels>
+              average_erle;
+          for (auto& e : average_erle) {
+            e.fill(cfg.erle.max_l);
+          }
           TestInputs inputs(cfg);
           for (size_t n = 0; n < 10; ++n) {
             inputs.Update();
@@ -140,6 +144,7 @@
 }
 
 TEST(SignalDependentErleEstimator, LongerRun) {
+  const size_t kNumCaptureChannels = 1;
   EchoCanceller3Config cfg;
   cfg.filter.main.length_blocks = 2;
   cfg.filter.main_initial.length_blocks = 1;
@@ -147,9 +152,12 @@
   cfg.delay.hysteresis_limit_blocks = 0;
   cfg.erle.num_sections = 2;
   EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true);
-  std::array<float, kFftLengthBy2Plus1> average_erle;
-  average_erle.fill(cfg.erle.max_l);
-  SignalDependentErleEstimator s(cfg);
+  std::array<std::array<float, kFftLengthBy2Plus1>, kNumCaptureChannels>
+      average_erle;
+  for (auto& e : average_erle) {
+    e.fill(cfg.erle.max_l);
+  }
+  SignalDependentErleEstimator s(cfg, kNumCaptureChannels);
   TestInputs inputs(cfg);
   for (size_t n = 0; n < 200; ++n) {
     inputs.Update();
diff --git a/modules/audio_processing/aec3/subband_erle_estimator.cc b/modules/audio_processing/aec3/subband_erle_estimator.cc
index 82f3dab..137b055 100644
--- a/modules/audio_processing/aec3/subband_erle_estimator.cc
+++ b/modules/audio_processing/aec3/subband_erle_estimator.cc
@@ -40,17 +40,21 @@
 
 }  // namespace
 
-SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config)
+SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config,
+                                           size_t num_capture_channels)
     : min_erle_(config.erle.min),
       max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)),
-      use_min_erle_during_onsets_(EnableMinErleDuringOnsets()) {
+      use_min_erle_during_onsets_(EnableMinErleDuringOnsets()),
+      erle_(num_capture_channels) {
   Reset();
 }
 
 SubbandErleEstimator::~SubbandErleEstimator() = default;
 
 void SubbandErleEstimator::Reset() {
-  erle_.fill(min_erle_);
+  for (auto& erle : erle_) {
+    erle.fill(min_erle_);
+  }
   erle_onsets_.fill(min_erle_);
   coming_onset_.fill(true);
   hold_counters_.fill(0);
@@ -74,8 +78,10 @@
     DecreaseErlePerBandForLowRenderSignals();
   }
 
-  erle_[0] = erle_[1];
-  erle_[kFftLengthBy2] = erle_[kFftLengthBy2 - 1];
+  for (auto& erle : erle_) {
+    erle[0] = erle[1];
+    erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1];
+  }
 }
 
 void SubbandErleEstimator::Dump(
@@ -116,11 +122,12 @@
   for (size_t k = 1; k < kFftLengthBy2; ++k) {
     if (is_erle_updated[k]) {
       float alpha = 0.05f;
-      if (new_erle[k] < erle_[k]) {
+      if (new_erle[k] < erle_[0][k]) {
         alpha = accum_spectra_.low_render_energy_[k] ? 0.f : 0.1f;
       }
-      erle_[k] = rtc::SafeClamp(erle_[k] + alpha * (new_erle[k] - erle_[k]),
-                                min_erle_, max_erle_[k]);
+      erle_[0][k] =
+          rtc::SafeClamp(erle_[0][k] + alpha * (new_erle[k] - erle_[0][k]),
+                         min_erle_, max_erle_[k]);
     }
   }
 }
@@ -129,9 +136,9 @@
   for (size_t k = 1; k < kFftLengthBy2; ++k) {
     hold_counters_[k]--;
     if (hold_counters_[k] <= (kBlocksForOnsetDetection - kBlocksToHoldErle)) {
-      if (erle_[k] > erle_onsets_[k]) {
-        erle_[k] = std::max(erle_onsets_[k], 0.97f * erle_[k]);
-        RTC_DCHECK_LE(min_erle_, erle_[k]);
+      if (erle_[0][k] > erle_onsets_[k]) {
+        erle_[0][k] = std::max(erle_onsets_[k], 0.97f * erle_[0][k]);
+        RTC_DCHECK_LE(min_erle_, erle_[0][k]);
       }
       if (hold_counters_[k] <= 0) {
         coming_onset_[k] = true;
diff --git a/modules/audio_processing/aec3/subband_erle_estimator.h b/modules/audio_processing/aec3/subband_erle_estimator.h
index 0a22d61..18bab7d 100644
--- a/modules/audio_processing/aec3/subband_erle_estimator.h
+++ b/modules/audio_processing/aec3/subband_erle_estimator.h
@@ -27,7 +27,8 @@
 // Estimates the echo return loss enhancement for each frequency subband.
 class SubbandErleEstimator {
  public:
-  explicit SubbandErleEstimator(const EchoCanceller3Config& config);
+  SubbandErleEstimator(const EchoCanceller3Config& config,
+                       size_t num_capture_channels);
   ~SubbandErleEstimator();
 
   // Resets the ERLE estimator.
@@ -41,7 +42,9 @@
               bool onset_detection);
 
   // Returns the ERLE estimate.
-  const std::array<float, kFftLengthBy2Plus1>& Erle() const { return erle_; }
+  rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle() const {
+    return erle_;
+  }
 
   // Returns the ERLE estimate at onsets.
   rtc::ArrayView<const float> ErleOnsets() const { return erle_onsets_; }
@@ -69,7 +72,7 @@
   const std::array<float, kFftLengthBy2Plus1> max_erle_;
   const bool use_min_erle_during_onsets_;
   AccumulatedSpectra accum_spectra_;
-  std::array<float, kFftLengthBy2Plus1> erle_;
+  std::vector<std::array<float, kFftLengthBy2Plus1>> erle_;
   std::array<float, kFftLengthBy2Plus1> erle_onsets_;
   std::array<bool, kFftLengthBy2Plus1> coming_onset_;
   std::array<int, kFftLengthBy2Plus1> hold_counters_;