AEC3: Made EchoAudibility multichannel

This CL corrects the EchoAudibility and StationarityEstimator
code to work properly with multiple channels.

It also renames the FilterDelayBlocks() method to
MinDirectPathFilterDelay() to better reflect what it does: it returns
the minimum of the estimated direct path filter delays.

The changes have been verified to be bit-exact over a large number
of recordings.

Bug: webrtc:10913
Change-Id: I070b531efcdff4c33f70fd5b37fbb556dcebe5b4
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/156565
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29482}
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
index f1a6489..022e860 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
@@ -441,7 +441,7 @@
         auto* const render_buffer = render_delay_buffer->GetRenderBuffer();
 
         render_signal_analyzer.Update(*render_buffer,
-                                      aec_state.FilterDelayBlocks());
+                                      aec_state.MinDirectPathFilterDelay());
 
         filter.Filter(*render_buffer, &S);
         fft.Ifft(S, &s_scratch);
diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc
index 7518e3a3..13b9bcc 100644
--- a/modules/audio_processing/aec3/aec_state.cc
+++ b/modules/audio_processing/aec3/aec_state.cc
@@ -29,7 +29,7 @@
 constexpr size_t kBlocksSinceConvergencedFilterInit = 10000;
 constexpr size_t kBlocksSinceConsistentEstimateInit = 10000;
 
-void UpdateAndComputeReverb(
+void ComputeAvgRenderReverb(
     const SpectrumBuffer& spectrum_buffer,
     int delay_blocks,
     float reverb_decay,
@@ -211,16 +211,16 @@
   strong_not_saturated_render_blocks_ +=
       active_render && !SaturatedCapture() ? 1 : 0;
 
-  std::array<float, kFftLengthBy2Plus1> X2_reverb;
+  std::array<float, kFftLengthBy2Plus1> avg_render_spectrum_with_reverb;
 
-  UpdateAndComputeReverb(render_buffer.GetSpectrumBuffer(),
-                         delay_state_.DirectPathFilterDelays()[0],
-                         ReverbDecay(), &reverb_model_, X2_reverb);
+  ComputeAvgRenderReverb(render_buffer.GetSpectrumBuffer(),
+                         delay_state_.MinDirectPathFilterDelay(), ReverbDecay(),
+                         &avg_render_reverb_, avg_render_spectrum_with_reverb);
 
   if (config_.echo_audibility.use_stationarity_properties) {
     // Update the echo audibility evaluator.
-    echo_audibility_.Update(render_buffer, reverb_model_.reverb(),
-                            delay_state_.DirectPathFilterDelays()[0],
+    echo_audibility_.Update(render_buffer, avg_render_reverb_.reverb(),
+                            delay_state_.MinDirectPathFilterDelay(),
                             delay_state_.ExternalDelayReported());
   }
 
@@ -229,17 +229,15 @@
     erle_estimator_.Reset(false);
   }
 
-  // TODO(bugs.webrtc.org/10913): Take all channels into account.
-  const auto& X2 =
-      render_buffer.Spectrum(delay_state_.DirectPathFilterDelays()[0],
-                             /*channel=*/0);
-  const auto& X2_input_erle = X2_reverb;
-
   erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses[0],
-                         X2_input_erle, Y2[0], E2_main[0],
+                         avg_render_spectrum_with_reverb, Y2[0], E2_main[0],
                          subtractor_output_analyzers_[0].ConvergedFilter(),
                          config_.erle.onset_detection);
 
+  // TODO(bugs.webrtc.org/10913): Take all channels into account.
+  const auto& X2 =
+      render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay(),
+                             /*channel=*/0);
   erl_estimator_.Update(subtractor_output_analyzers_[0].ConvergedFilter(), X2,
                         Y2[0]);
 
@@ -357,6 +355,9 @@
               analyzer_filter_delay_estimates_blocks.end(),
               filter_delays_blocks_.begin());
   }
+
+  min_filter_delay_ = *std::min_element(filter_delays_blocks_.begin(),
+                                        filter_delays_blocks_.end());
 }
 
 AecState::TransparentMode::TransparentMode(const EchoCanceller3Config& config)
diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h
index 79fe13e..71000b4 100644
--- a/modules/audio_processing/aec3/aec_state.h
+++ b/modules/audio_processing/aec3/aec_state.h
@@ -91,8 +91,8 @@
   float ErlTimeDomain() const { return erl_estimator_.ErlTimeDomain(); }
 
   // Returns the delay estimate based on the linear filter.
-  int FilterDelayBlocks() const {
-    return delay_state_.DirectPathFilterDelays()[0];
+  int MinDirectPathFilterDelay() const {
+    return delay_state_.MinDirectPathFilterDelay();
   }
 
   // Returns whether the capture signal is saturated.
@@ -194,6 +194,10 @@
       return filter_delays_blocks_;
     }
 
+    // Returns the minimum of the estimated direct path delays, expressed in
+    // blocks relative to the beginning of the filter.
+    int MinDirectPathFilterDelay() const { return min_filter_delay_; }
+
     // Updates the delay estimates based on new data.
     void Update(
         rtc::ArrayView<const int> analyzer_filter_delay_estimates_blocks,
@@ -204,6 +208,7 @@
     const int delay_headroom_samples_;
     bool external_delay_reported_ = false;
     std::vector<int> filter_delays_blocks_;
+    int min_filter_delay_ = 0;
     absl::optional<DelayEstimate> external_delay_;
   } delay_state_;
 
@@ -308,7 +313,7 @@
   absl::optional<DelayEstimate> external_delay_;
   EchoAudibility echo_audibility_;
   ReverbModelEstimator reverb_model_estimator_;
-  ReverbModel reverb_model_;
+  ReverbModel avg_render_reverb_;
   std::vector<SubtractorOutputAnalyzer> subtractor_output_analyzers_;
 };
 
diff --git a/modules/audio_processing/aec3/echo_audibility.cc b/modules/audio_processing/aec3/echo_audibility.cc
index c534108..db62236 100644
--- a/modules/audio_processing/aec3/echo_audibility.cc
+++ b/modules/audio_processing/aec3/echo_audibility.cc
@@ -29,18 +29,16 @@
 
 EchoAudibility::~EchoAudibility() = default;
 
-void EchoAudibility::Update(
-    const RenderBuffer& render_buffer,
-    rtc::ArrayView<const float> render_reverb_contribution_spectrum,
-    int delay_blocks,
-    bool external_delay_seen) {
+void EchoAudibility::Update(const RenderBuffer& render_buffer,
+                            rtc::ArrayView<const float> average_reverb,
+                            int delay_blocks,
+                            bool external_delay_seen) {
   UpdateRenderNoiseEstimator(render_buffer.GetSpectrumBuffer(),
                              render_buffer.GetBlockBuffer(),
                              external_delay_seen);
 
   if (external_delay_seen || use_render_stationarity_at_init_) {
-    UpdateRenderStationarityFlags(
-        render_buffer, render_reverb_contribution_spectrum, delay_blocks);
+    UpdateRenderStationarityFlags(render_buffer, average_reverb, delay_blocks);
   }
 }
 
@@ -52,18 +50,17 @@
 
 void EchoAudibility::UpdateRenderStationarityFlags(
     const RenderBuffer& render_buffer,
-    rtc::ArrayView<const float> render_reverb_contribution_spectrum,
-    int delay_blocks) {
+    rtc::ArrayView<const float> average_reverb,
+    int min_channel_delay_blocks) {
   const SpectrumBuffer& spectrum_buffer = render_buffer.GetSpectrumBuffer();
-  int idx_at_delay =
-      spectrum_buffer.OffsetIndex(spectrum_buffer.read, delay_blocks);
+  int idx_at_delay = spectrum_buffer.OffsetIndex(spectrum_buffer.read,
+                                                 min_channel_delay_blocks);
 
-  int num_lookahead = render_buffer.Headroom() - delay_blocks + 1;
+  int num_lookahead = render_buffer.Headroom() - min_channel_delay_blocks + 1;
   num_lookahead = std::max(0, num_lookahead);
 
-  render_stationarity_.UpdateStationarityFlags(
-      spectrum_buffer, render_reverb_contribution_spectrum, idx_at_delay,
-      num_lookahead);
+  render_stationarity_.UpdateStationarityFlags(spectrum_buffer, average_reverb,
+                                               idx_at_delay, num_lookahead);
 }
 
 void EchoAudibility::UpdateRenderNoiseEstimator(
@@ -83,14 +80,15 @@
     for (int idx = render_spectrum_write_prev_.value();
          idx != render_spectrum_write_current;
          idx = spectrum_buffer.DecIndex(idx)) {
-      render_stationarity_.UpdateNoiseEstimator(
-          spectrum_buffer.buffer[idx][/*channel=*/0]);
+      render_stationarity_.UpdateNoiseEstimator(spectrum_buffer.buffer[idx]);
     }
   }
   render_spectrum_write_prev_ = render_spectrum_write_current;
 }
 
 bool EchoAudibility::IsRenderTooLow(const BlockBuffer& block_buffer) {
+  const int num_render_channels =
+      static_cast<int>(block_buffer.buffer[0].size());
   bool too_low = false;
   const int render_block_write_current = block_buffer.write;
   if (render_block_write_current == render_block_write_prev_) {
@@ -98,10 +96,16 @@
   } else {
     for (int idx = render_block_write_prev_; idx != render_block_write_current;
          idx = block_buffer.IncIndex(idx)) {
-      auto block = block_buffer.buffer[idx][0][0];
-      auto r = std::minmax_element(block.cbegin(), block.cend());
-      float max_abs = std::max(std::fabs(*r.first), std::fabs(*r.second));
-      if (max_abs < 10) {
+      float max_abs_over_channels = 0.f;
+      for (int ch = 0; ch < num_render_channels; ++ch) {
+        const auto& block = block_buffer.buffer[idx][0][ch];
+        auto r = std::minmax_element(block.cbegin(), block.cend());
+        float max_abs_channel =
+            std::max(std::fabs(*r.first), std::fabs(*r.second));
+        max_abs_over_channels =
+            std::max(max_abs_over_channels, max_abs_channel);
+      }
+      if (max_abs_over_channels < 10.f) {
         too_low = true;  // Discards all blocks if one of them is too low.
         break;
       }
diff --git a/modules/audio_processing/aec3/echo_audibility.h b/modules/audio_processing/aec3/echo_audibility.h
index 0152ea4..1ffc017 100644
--- a/modules/audio_processing/aec3/echo_audibility.h
+++ b/modules/audio_processing/aec3/echo_audibility.h
@@ -28,10 +28,13 @@
   explicit EchoAudibility(bool use_render_stationarity_at_init);
   ~EchoAudibility();
 
+  EchoAudibility(const EchoAudibility&) = delete;
+  EchoAudibility& operator=(const EchoAudibility&) = delete;
+
   // Feed new render data to the echo audibility estimator.
   void Update(const RenderBuffer& render_buffer,
-              rtc::ArrayView<const float> render_reverb_contribution_spectrum,
-              int delay_blocks,
+              rtc::ArrayView<const float> average_reverb,
+              int min_channel_delay_blocks,
               bool external_delay_seen);
   // Get the residual echo scaling.
   void GetResidualEchoScaling(bool filter_has_had_time_to_converge,
@@ -57,10 +60,9 @@
   void Reset();
 
   // Updates the render stationarity flags for the current frame.
-  void UpdateRenderStationarityFlags(
-      const RenderBuffer& render_buffer,
-      rtc::ArrayView<const float> render_reverb_contribution_spectrum,
-      int delay_blocks);
+  void UpdateRenderStationarityFlags(const RenderBuffer& render_buffer,
+                                     rtc::ArrayView<const float> average_reverb,
+                                     int delay_blocks);
 
   // Updates the noise estimator with the new render data since the previous
   // call to this method.
@@ -77,7 +79,6 @@
   bool non_zero_render_seen_;
   const bool use_render_stationarity_at_init_;
   StationarityEstimator render_stationarity_;
-  RTC_DISALLOW_COPY_AND_ASSIGN(EchoAudibility);
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index e6f17c7..7cec47c 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -356,7 +356,7 @@
 
   // Analyze the render signal.
   render_signal_analyzer_.Update(*render_buffer,
-                                 aec_state_.FilterDelayBlocks());
+                                 aec_state_.MinDirectPathFilterDelay());
 
   // State transition.
   if (aec_state_.TransitionTriggered()) {
@@ -457,10 +457,11 @@
   data_dumper_->DumpRaw("aec3_S2_linear", S2_linear[0]);
   data_dumper_->DumpRaw("aec3_Y2", Y2[0]);
   data_dumper_->DumpRaw(
-      "aec3_X2",
-      render_buffer->Spectrum(aec_state_.FilterDelayBlocks(), /*channel=*/0));
+      "aec3_X2", render_buffer->Spectrum(aec_state_.MinDirectPathFilterDelay(),
+                                         /*channel=*/0));
   data_dumper_->DumpRaw("aec3_R2", R2[0]);
-  data_dumper_->DumpRaw("aec3_filter_delay", aec_state_.FilterDelayBlocks());
+  data_dumper_->DumpRaw("aec3_filter_delay",
+                        aec_state_.MinDirectPathFilterDelay());
   data_dumper_->DumpRaw("aec3_capture_saturation",
                         aec_state_.SaturatedCapture() ? 1 : 0);
 }
diff --git a/modules/audio_processing/aec3/echo_remover_metrics.cc b/modules/audio_processing/aec3/echo_remover_metrics.cc
index 4ab05f8..69d2252 100644
--- a/modules/audio_processing/aec3/echo_remover_metrics.cc
+++ b/modules/audio_processing/aec3/echo_remover_metrics.cc
@@ -241,7 +241,8 @@
             static_cast<int>(
                 active_render_count_ > kMetricsCollectionBlocksBy2 ? 1 : 0));
         RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.EchoCanceller.FilterDelay",
-                                    aec_state.FilterDelayBlocks(), 0, 30, 31);
+                                    aec_state.MinDirectPathFilterDelay(), 0, 30,
+                                    31);
         RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.EchoCanceller.CaptureSaturation",
                               static_cast<int>(saturated_capture_ ? 1 : 0));
         break;
diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
index 92334c2..fa3b263 100644
--- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
+++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc
@@ -150,7 +150,7 @@
     render_delay_buffer->PrepareCaptureProcessing();
 
     render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
-                                  aec_state.FilterDelayBlocks());
+                                  aec_state.MinDirectPathFilterDelay());
 
     // Apply the main filter.
     main_filter.Filter(*render_delay_buffer->GetRenderBuffer(), &S);
diff --git a/modules/audio_processing/aec3/residual_echo_estimator.cc b/modules/audio_processing/aec3/residual_echo_estimator.cc
index 07197e3..6a8a3f2 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator.cc
+++ b/modules/audio_processing/aec3/residual_echo_estimator.cc
@@ -203,7 +203,7 @@
       std::array<float, kFftLengthBy2Plus1> X2;
       EchoGeneratingPower(num_render_channels_,
                           render_buffer.GetSpectrumBuffer(), config_.echo_model,
-                          aec_state.FilterDelayBlocks(), X2);
+                          aec_state.MinDirectPathFilterDelay(), X2);
       if (!aec_state.UseStationarityProperties()) {
         ApplyNoiseGate(config_.echo_model, X2);
       }
@@ -288,9 +288,10 @@
   const size_t num_capture_channels = R2.size();
 
   // Choose reverb partition based on what type of echo power model is used.
-  const size_t first_reverb_partition = reverb_type == ReverbType::kLinear
-                                            ? aec_state.FilterLengthBlocks() + 1
-                                            : aec_state.FilterDelayBlocks() + 1;
+  const size_t first_reverb_partition =
+      reverb_type == ReverbType::kLinear
+          ? aec_state.FilterLengthBlocks() + 1
+          : aec_state.MinDirectPathFilterDelay() + 1;
 
   // Compute render power for the reverb.
   std::array<float, kFftLengthBy2Plus1> render_power_data;
diff --git a/modules/audio_processing/aec3/stationarity_estimator.cc b/modules/audio_processing/aec3/stationarity_estimator.cc
index 080d13d..d0c3c9c 100644
--- a/modules/audio_processing/aec3/stationarity_estimator.cc
+++ b/modules/audio_processing/aec3/stationarity_estimator.cc
@@ -45,7 +45,7 @@
 
 // Update just the noise estimator. Usefull until the delay is known
 void StationarityEstimator::UpdateNoiseEstimator(
-    rtc::ArrayView<const float> spectrum) {
+    rtc::ArrayView<const std::vector<float>> spectrum) {
   noise_.Update(spectrum);
   data_dumper_->DumpRaw("aec3_stationarity_noise_spectrum", noise_.Spectrum());
   data_dumper_->DumpRaw("aec3_stationarity_is_block_stationary",
@@ -99,15 +99,20 @@
 
 bool StationarityEstimator::EstimateBandStationarity(
     const SpectrumBuffer& spectrum_buffer,
-    rtc::ArrayView<const float> reverb,
+    rtc::ArrayView<const float> average_reverb,
     const std::array<int, kWindowLength>& indexes,
     size_t band) const {
   constexpr float kThrStationarity = 10.f;
   float acum_power = 0.f;
+  const int num_render_channels =
+      static_cast<int>(spectrum_buffer.buffer[0].size());
+  const float one_by_num_channels = 1.f / num_render_channels;
   for (auto idx : indexes) {
-    acum_power += spectrum_buffer.buffer[idx][/*channel=*/0][band];
+    for (int ch = 0; ch < num_render_channels; ++ch) {
+      acum_power += spectrum_buffer.buffer[idx][ch][band] * one_by_num_channels;
+    }
   }
-  acum_power += reverb[band];
+  acum_power += average_reverb[band];
   float noise = kWindowLength * GetStationarityPowerBand(band);
   RTC_CHECK_LT(0.f, noise);
   bool stationary = acum_power < kThrStationarity * noise;
@@ -163,16 +168,42 @@
 }
 
 void StationarityEstimator::NoiseSpectrum::Update(
-    rtc::ArrayView<const float> spectrum) {
-  RTC_DCHECK_EQ(kFftLengthBy2Plus1, spectrum.size());
+    rtc::ArrayView<const std::vector<float>> spectrum) {
+  RTC_DCHECK_LE(1, spectrum.size());
+  const int num_render_channels = static_cast<int>(spectrum.size());
+
+  std::array<float, kFftLengthBy2Plus1> avg_spectrum_data;
+  rtc::ArrayView<const float> avg_spectrum;
+  RTC_DCHECK_EQ(kFftLengthBy2Plus1, spectrum[0].size());
+  if (num_render_channels == 1) {
+    avg_spectrum = spectrum[0];
+  } else {
+    // For multiple channels, average the channel spectra over all bins (DC
+    // included) before passing them to the noise spectrum estimator.
+    avg_spectrum = avg_spectrum_data;
+    std::copy(spectrum[0].begin(), spectrum[0].end(),
+              avg_spectrum_data.begin());
+    for (int ch = 1; ch < num_render_channels; ++ch) {
+      RTC_DCHECK_EQ(kFftLengthBy2Plus1, spectrum[ch].size());
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        avg_spectrum_data[k] += spectrum[ch][k];
+      }
+    }
+
+    const float one_by_num_channels = 1.f / num_render_channels;
+    for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+      avg_spectrum_data[k] *= one_by_num_channels;
+    }
+  }
+
   ++block_counter_;
   float alpha = GetAlpha();
-  for (size_t k = 0; k < spectrum.size(); ++k) {
+  for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
     if (block_counter_ <= kNBlocksAverageInitPhase) {
-      noise_spectrum_[k] += (1.f / kNBlocksAverageInitPhase) * spectrum[k];
+      noise_spectrum_[k] += (1.f / kNBlocksAverageInitPhase) * avg_spectrum[k];
     } else {
       noise_spectrum_[k] =
-          UpdateBandBySmoothing(spectrum[k], noise_spectrum_[k], alpha);
+          UpdateBandBySmoothing(avg_spectrum[k], noise_spectrum_[k], alpha);
     }
   }
 }
diff --git a/modules/audio_processing/aec3/stationarity_estimator.h b/modules/audio_processing/aec3/stationarity_estimator.h
index 504fea7..5860ef1 100644
--- a/modules/audio_processing/aec3/stationarity_estimator.h
+++ b/modules/audio_processing/aec3/stationarity_estimator.h
@@ -15,6 +15,7 @@
 
 #include <array>
 #include <memory>
+#include <vector>
 
 #include "api/array_view.h"
 #include "modules/audio_processing/aec3/aec3_common.h"  // kFftLengthBy2Plus1...
@@ -35,7 +36,7 @@
   void Reset();
 
   // Update just the noise estimator. Usefull until the delay is known
-  void UpdateNoiseEstimator(rtc::ArrayView<const float> spectrum);
+  void UpdateNoiseEstimator(rtc::ArrayView<const std::vector<float>> spectrum);
 
   // Update the flag indicating whether this current frame is stationary. For
   // getting a more robust estimation, it looks at future and/or past frames.
@@ -61,7 +62,7 @@
   // Get an estimation of the stationarity for the current band by looking
   // at the past/present/future available data.
   bool EstimateBandStationarity(const SpectrumBuffer& spectrum_buffer,
-                                rtc::ArrayView<const float> reverb,
+                                rtc::ArrayView<const float> average_reverb,
                                 const std::array<int, kWindowLength>& indexes,
                                 size_t band) const;
 
@@ -85,7 +86,7 @@
     void Reset();
 
     // Update the noise power spectrum with a new frame.
-    void Update(rtc::ArrayView<const float> spectrum);
+    void Update(rtc::ArrayView<const std::vector<float>> spectrum);
 
     // Get the noise estimation power spectrum.
     rtc::ArrayView<const float> Spectrum() const { return noise_spectrum_; }
diff --git a/modules/audio_processing/aec3/subtractor_unittest.cc b/modules/audio_processing/aec3/subtractor_unittest.cc
index 05faa4f..b59fa7b 100644
--- a/modules/audio_processing/aec3/subtractor_unittest.cc
+++ b/modules/audio_processing/aec3/subtractor_unittest.cc
@@ -135,7 +135,7 @@
     }
     render_delay_buffer->PrepareCaptureProcessing();
     render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(),
-                                  aec_state.FilterDelayBlocks());
+                                  aec_state.MinDirectPathFilterDelay());
 
     // Handle echo path changes.
     if (std::find(blocks_with_echo_path_changes.begin(),