AEC3: Multi channel ERL estimator

The estimator will simply compute the worst value of all combinations
of render and capture signal.

This has the drawback that low-volume or silent render channels may
severely misestimate the ERL.

The changes have been shown to be bitexact over a large dataset.

Bug: webrtc:10913
Change-Id: Id53c3ab81646ac0fab303edafc5e38892d285d8e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/157308
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29542}
diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc
index d35bed5..365ec9e 100644
--- a/modules/audio_processing/aec3/aec_state.cc
+++ b/modules/audio_processing/aec3/aec_state.cc
@@ -230,11 +230,9 @@
                          avg_render_spectrum_with_reverb, Y2, E2_main,
                          subtractor_output_analyzer_.ConvergedFilters());
 
-  // TODO(bugs.webrtc.org/10913): Take all channels into account.
-  const auto& X2 = render_buffer.Spectrum(
-      delay_state_.MinDirectPathFilterDelay())[/*channel=*/0];
-  erl_estimator_.Update(subtractor_output_analyzer_.ConvergedFilters()[0], X2,
-                        Y2[0]);
+  erl_estimator_.Update(
+      subtractor_output_analyzer_.ConvergedFilters(),
+      render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay()), Y2);
 
   // Detect and flag echo saturation.
   saturation_detector_.Update(aligned_render_block, SaturatedCapture(),
diff --git a/modules/audio_processing/aec3/aec_state_unittest.cc b/modules/audio_processing/aec3/aec_state_unittest.cc
index b038770..c068b6e 100644
--- a/modules/audio_processing/aec3/aec_state_unittest.cc
+++ b/modules/audio_processing/aec3/aec_state_unittest.cc
@@ -106,7 +106,9 @@
   EXPECT_FALSE(state.UsableLinearEstimate());
 
   // Verify that the active render detection works as intended.
-  std::fill(x[0][0].begin(), x[0][0].end(), 101.f);
+  for (size_t ch = 0; ch < num_render_channels; ++ch) {
+    std::fill(x[0][ch].begin(), x[0][ch].end(), 101.f);
+  }
   render_delay_buffer->Insert(x);
   for (size_t ch = 0; ch < num_capture_channels; ++ch) {
     subtractor_output[ch].ComputeMetrics(y[ch]);
@@ -136,7 +138,9 @@
     }
   }
 
-  x[0][0][0] = 5000.f;
+  for (size_t ch = 0; ch < num_render_channels; ++ch) {
+    x[0][ch][0] = 5000.f;
+  }
   for (size_t k = 0;
        k < render_delay_buffer->GetRenderBuffer()->GetFftBuffer().size(); ++k) {
     render_delay_buffer->Insert(x);
diff --git a/modules/audio_processing/aec3/erl_estimator.cc b/modules/audio_processing/aec3/erl_estimator.cc
index 4a0c441..01cc33c 100644
--- a/modules/audio_processing/aec3/erl_estimator.cc
+++ b/modules/audio_processing/aec3/erl_estimator.cc
@@ -39,20 +39,69 @@
 }
 
 void ErlEstimator::Update(
-    bool converged_filter,
-    rtc::ArrayView<const float, kFftLengthBy2Plus1> render_spectrum,
-    rtc::ArrayView<const float, kFftLengthBy2Plus1> capture_spectrum) {
-  const auto& X2 = render_spectrum;
-  const auto& Y2 = capture_spectrum;
+    const std::vector<bool>& converged_filters,
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> render_spectra,
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
+        capture_spectra) {
+  const size_t num_capture_channels = converged_filters.size();
+  RTC_DCHECK_EQ(capture_spectra.size(), num_capture_channels);
 
   // Corresponds to WGN of power -46 dBFS.
   constexpr float kX2Min = 44015068.0f;
 
+  const auto first_converged_iter =
+      std::find(converged_filters.begin(), converged_filters.end(), true);
+  const bool any_filter_converged =
+      first_converged_iter != converged_filters.end();
+
   if (++blocks_since_reset_ < startup_phase_length_blocks__ ||
-      !converged_filter) {
+      !any_filter_converged) {
     return;
   }
 
+  // Use the maximum spectrum across capture and the maximum across render.
+  std::array<float, kFftLengthBy2Plus1> max_capture_spectrum_data;
+  std::array<float, kFftLengthBy2Plus1> max_capture_spectrum =
+      capture_spectra[/*channel=*/0];
+  if (num_capture_channels > 1) {
+    // Initialize using the first channel with a converged filter.
+    const size_t first_converged =
+        std::distance(converged_filters.begin(), first_converged_iter);
+    RTC_DCHECK_GE(first_converged, 0);
+    RTC_DCHECK_LT(first_converged, num_capture_channels);
+    max_capture_spectrum_data = capture_spectra[first_converged];
+
+    for (size_t ch = first_converged + 1; ch < num_capture_channels; ++ch) {
+      if (!converged_filters[ch]) {
+        continue;
+      }
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        max_capture_spectrum_data[k] =
+            std::max(max_capture_spectrum_data[k], capture_spectra[ch][k]);
+      }
+    }
+    max_capture_spectrum = max_capture_spectrum_data;
+  }
+
+  const size_t num_render_channels = render_spectra.size();
+  std::array<float, kFftLengthBy2Plus1> max_render_spectrum_data;
+  rtc::ArrayView<const float, kFftLengthBy2Plus1> max_render_spectrum =
+      render_spectra[/*channel=*/0];
+  if (num_render_channels > 1) {
+    std::copy(render_spectra[0].begin(), render_spectra[0].end(),
+              max_render_spectrum_data.begin());
+    for (size_t ch = 1; ch < num_render_channels; ++ch) {
+      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
+        max_render_spectrum_data[k] =
+            std::max(max_render_spectrum_data[k], render_spectra[ch][k]);
+      }
+    }
+    max_render_spectrum = max_render_spectrum_data;
+  }
+
+  const auto& X2 = max_render_spectrum;
+  const auto& Y2 = max_capture_spectrum;
+
   // Update the estimates in a maximum statistics manner.
   for (size_t k = 1; k < kFftLengthBy2; ++k) {
     if (X2[k] > kX2Min) {
diff --git a/modules/audio_processing/aec3/erl_estimator.h b/modules/audio_processing/aec3/erl_estimator.h
index 25dc39c..89bf6ac 100644
--- a/modules/audio_processing/aec3/erl_estimator.h
+++ b/modules/audio_processing/aec3/erl_estimator.h
@@ -14,6 +14,7 @@
 #include <stddef.h>
 
 #include <array>
+#include <vector>
 
 #include "api/array_view.h"
 #include "modules/audio_processing/aec3/aec3_common.h"
@@ -31,9 +32,11 @@
   void Reset();
 
   // Updates the ERL estimate.
-  void Update(bool converged_filter,
-              rtc::ArrayView<const float, kFftLengthBy2Plus1> render_spectrum,
-              rtc::ArrayView<const float, kFftLengthBy2Plus1> capture_spectrum);
+  void Update(const std::vector<bool>& converged_filters,
+              rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
+                  render_spectra,
+              rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
+                  capture_spectra);
 
   // Returns the most recent ERL estimate.
   const std::array<float, kFftLengthBy2Plus1>& Erl() const { return erl_; }
diff --git a/modules/audio_processing/aec3/erl_estimator_unittest.cc b/modules/audio_processing/aec3/erl_estimator_unittest.cc
index 1b965d0..344551d 100644
--- a/modules/audio_processing/aec3/erl_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/erl_estimator_unittest.cc
@@ -10,11 +10,19 @@
 
 #include "modules/audio_processing/aec3/erl_estimator.h"
 
+#include "rtc_base/strings/string_builder.h"
 #include "test/gtest.h"
 
 namespace webrtc {
 
 namespace {
+std::string ProduceDebugText(size_t num_render_channels,
+                             size_t num_capture_channels) {
+  rtc::StringBuilder ss;
+  ss << "Render channels: " << num_render_channels;
+  ss << ", Capture channels: " << num_capture_channels;
+  return ss.Release();
+}
 
 void VerifyErl(const std::array<float, kFftLengthBy2Plus1>& erl,
                float erl_time_domain,
@@ -28,45 +36,65 @@
 
 // Verifies that the correct ERL estimates are achieved.
 TEST(ErlEstimator, Estimates) {
-  std::array<float, kFftLengthBy2Plus1> X2;
-  std::array<float, kFftLengthBy2Plus1> Y2;
+  for (size_t num_render_channels : {1, 2, 8}) {
+    for (size_t num_capture_channels : {1, 2, 8}) {
+      SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels));
+      std::vector<std::array<float, kFftLengthBy2Plus1>> X2(
+          num_render_channels);
+      for (auto& X2_ch : X2) {
+        X2_ch.fill(0.f);
+      }
+      std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
+          num_capture_channels);
+      for (auto& Y2_ch : Y2) {
+        Y2_ch.fill(0.f);
+      }
+      std::vector<bool> converged_filters(num_capture_channels, false);
+      const size_t converged_idx = num_capture_channels - 1;
+      converged_filters[converged_idx] = true;
 
-  ErlEstimator estimator(0);
+      ErlEstimator estimator(0);
 
-  // Verifies that the ERL estimate is properly reduced to lower values.
-  X2.fill(500 * 1000.f * 1000.f);
-  Y2.fill(10 * X2[0]);
-  for (size_t k = 0; k < 200; ++k) {
-    estimator.Update(true, X2, Y2);
+      // Verifies that the ERL estimate is properly reduced to lower values.
+      for (auto& X2_ch : X2) {
+        X2_ch.fill(500 * 1000.f * 1000.f);
+      }
+      Y2[converged_idx].fill(10 * X2[0][0]);
+      for (size_t k = 0; k < 200; ++k) {
+        estimator.Update(converged_filters, X2, Y2);
+      }
+      VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
+
+      // Verifies that the ERL is not immediately increased when the ERL in the
+      // data increases.
+      Y2[converged_idx].fill(10000 * X2[0][0]);
+      for (size_t k = 0; k < 998; ++k) {
+        estimator.Update(converged_filters, X2, Y2);
+      }
+      VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
+
+      // Verifies that the rate of increase is 3 dB.
+      estimator.Update(converged_filters, X2, Y2);
+      VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f);
+
+      // Verifies that the maximum ERL is achieved when there are no low RLE
+      // estimates.
+      for (size_t k = 0; k < 1000; ++k) {
+        estimator.Update(converged_filters, X2, Y2);
+      }
+      VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
+
+      // Verifies that the ERL estimate is is not updated for low-level signals
+      for (auto& X2_ch : X2) {
+        X2_ch.fill(1000.f * 1000.f);
+      }
+      Y2[converged_idx].fill(10 * X2[0][0]);
+      for (size_t k = 0; k < 200; ++k) {
+        estimator.Update(converged_filters, X2, Y2);
+      }
+      VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
+    }
   }
-  VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
-
-  // Verifies that the ERL is not immediately increased when the ERL in the data
-  // increases.
-  Y2.fill(10000 * X2[0]);
-  for (size_t k = 0; k < 998; ++k) {
-    estimator.Update(true, X2, Y2);
-  }
-  VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
-
-  // Verifies that the rate of increase is 3 dB.
-  estimator.Update(true, X2, Y2);
-  VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f);
-
-  // Verifies that the maximum ERL is achieved when there are no low RLE
-  // estimates.
-  for (size_t k = 0; k < 1000; ++k) {
-    estimator.Update(true, X2, Y2);
-  }
-  VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
-
-  // Verifies that the ERL estimate is is not updated for low-level signals
-  X2.fill(1000.f * 1000.f);
-  Y2.fill(10 * X2[0]);
-  for (size_t k = 0; k < 200; ++k) {
-    estimator.Update(true, X2, Y2);
-  }
-  VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
 }
 
 }  // namespace webrtc