AEC3: Handle multichannel audio in single CNG instance

Instead of having a comfort noise generator (CNG) instance per capture
channel, one instance handles CNG for all capture channels.

Bug: webrtc:10913
Change-Id: I897471be6d203ad750c517c5076d421f2ae3879b
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/158780
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29668}
diff --git a/modules/audio_processing/aec3/comfort_noise_generator.cc b/modules/audio_processing/aec3/comfort_noise_generator.cc
index 005c25c..16c4a2b 100644
--- a/modules/audio_processing/aec3/comfort_noise_generator.cc
+++ b/modules/audio_processing/aec3/comfort_noise_generator.cc
@@ -93,39 +93,49 @@
 }  // namespace
 
 ComfortNoiseGenerator::ComfortNoiseGenerator(Aec3Optimization optimization,
-                                             uint32_t seed)
+                                             size_t num_capture_channels)
     : optimization_(optimization),
-      seed_(seed),
-      N2_initial_(new std::array<float, kFftLengthBy2Plus1>()) {
-  N2_initial_->fill(0.f);
-  Y2_smoothed_.fill(0.f);
-  N2_.fill(1.0e6f);
+      seed_(42),
+      num_capture_channels_(num_capture_channels),
+      N2_initial_(
+          std::make_unique<std::vector<std::array<float, kFftLengthBy2Plus1>>>(
+              num_capture_channels_)),
+      Y2_smoothed_(num_capture_channels_),
+      N2_(num_capture_channels_) {
+  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+    (*N2_initial_)[ch].fill(0.f);
+    Y2_smoothed_[ch].fill(0.f);
+    N2_[ch].fill(1.0e6f);
+  }
 }
 
 ComfortNoiseGenerator::~ComfortNoiseGenerator() = default;
 
 void ComfortNoiseGenerator::Compute(
     bool saturated_capture,
-    const std::array<float, kFftLengthBy2Plus1>& capture_spectrum,
-    FftData* lower_band_noise,
-    FftData* upper_band_noise) {
-  RTC_DCHECK(lower_band_noise);
-  RTC_DCHECK(upper_band_noise);
+    rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
+        capture_spectrum,
+    rtc::ArrayView<FftData> lower_band_noise,
+    rtc::ArrayView<FftData> upper_band_noise) {
   const auto& Y2 = capture_spectrum;
 
   if (!saturated_capture) {
     // Smooth Y2.
-    std::transform(Y2_smoothed_.begin(), Y2_smoothed_.end(), Y2.begin(),
-                   Y2_smoothed_.begin(),
-                   [](float a, float b) { return a + 0.1f * (b - a); });
+    for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+      std::transform(Y2_smoothed_[ch].begin(), Y2_smoothed_[ch].end(),
+                     Y2[ch].begin(), Y2_smoothed_[ch].begin(),
+                     [](float a, float b) { return a + 0.1f * (b - a); });
+    }
 
     if (N2_counter_ > 50) {
       // Update N2 from Y2_smoothed.
-      std::transform(N2_.begin(), N2_.end(), Y2_smoothed_.begin(), N2_.begin(),
-                     [](float a, float b) {
-                       return b < a ? (0.9f * b + 0.1f * a) * 1.0002f
-                                    : a * 1.0002f;
-                     });
+      for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+        std::transform(N2_[ch].begin(), N2_[ch].end(), Y2_smoothed_[ch].begin(),
+                       N2_[ch].begin(), [](float a, float b) {
+                         return b < a ? (0.9f * b + 0.1f * a) * 1.0002f
+                                      : a * 1.0002f;
+                       });
+      }
     }
 
     if (N2_initial_) {
@@ -133,31 +143,38 @@
         N2_initial_.reset();
       } else {
         // Compute the N2_initial from N2.
-        std::transform(
-            N2_.begin(), N2_.end(), N2_initial_->begin(), N2_initial_->begin(),
-            [](float a, float b) { return a > b ? b + 0.001f * (a - b) : a; });
+        for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+          std::transform(N2_[ch].begin(), N2_[ch].end(),
+                         (*N2_initial_)[ch].begin(), (*N2_initial_)[ch].begin(),
+                         [](float a, float b) {
+                           return a > b ? b + 0.001f * (a - b) : a;
+                         });
+        }
+      }
+    }
+
+    // Limit the noise to a floor matching a WGN input of -96 dBFS.
+    constexpr float kNoiseFloor = 17.1267f;
+
+    for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+      for (auto& n : N2_[ch]) {
+        n = std::max(n, kNoiseFloor);
+      }
+      if (N2_initial_) {
+        for (auto& n : (*N2_initial_)[ch]) {
+          n = std::max(n, kNoiseFloor);
+        }
       }
     }
   }
 
-  // Limit the noise to a floor matching a WGN input of -96 dBFS.
-  constexpr float kNoiseFloor = 17.1267f;
-
-  for (auto& n : N2_) {
-    n = std::max(n, kNoiseFloor);
-  }
-  if (N2_initial_) {
-    for (auto& n : *N2_initial_) {
-      n = std::max(n, kNoiseFloor);
-    }
-  }
-
   // Choose N2 estimate to use.
-  const std::array<float, kFftLengthBy2Plus1>& N2 =
-      N2_initial_ ? *N2_initial_ : N2_;
+  const auto& N2 = N2_initial_ ? (*N2_initial_) : N2_;
 
-  GenerateComfortNoise(optimization_, N2, &seed_, lower_band_noise,
-                       upper_band_noise);
+  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+    GenerateComfortNoise(optimization_, N2[ch], &seed_, &lower_band_noise[ch],
+                         &upper_band_noise[ch]);
+  }
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/comfort_noise_generator.h b/modules/audio_processing/aec3/comfort_noise_generator.h
index 31360d2..776ed1b 100644
--- a/modules/audio_processing/aec3/comfort_noise_generator.h
+++ b/modules/audio_processing/aec3/comfort_noise_generator.h
@@ -41,29 +41,34 @@
 // Generates the comfort noise.
 class ComfortNoiseGenerator {
  public:
-  ComfortNoiseGenerator(Aec3Optimization optimization, uint32_t seed);
+  ComfortNoiseGenerator(Aec3Optimization optimization,
+                        size_t num_capture_channels);
+  ComfortNoiseGenerator() = delete;
   ~ComfortNoiseGenerator();
+  ComfortNoiseGenerator(const ComfortNoiseGenerator&) = delete;
 
   // Computes the comfort noise.
   void Compute(bool saturated_capture,
-               const std::array<float, kFftLengthBy2Plus1>& capture_spectrum,
-               FftData* lower_band_noise,
-               FftData* upper_band_noise);
+               rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
+                   capture_spectrum,
+               rtc::ArrayView<FftData> lower_band_noise,
+               rtc::ArrayView<FftData> upper_band_noise);
 
   // Returns the estimate of the background noise spectrum.
-  const std::array<float, kFftLengthBy2Plus1>& NoiseSpectrum() const {
+  rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> NoiseSpectrum()
+      const {
     return N2_;
   }
 
  private:
   const Aec3Optimization optimization_;
   uint32_t seed_;
-  std::unique_ptr<std::array<float, kFftLengthBy2Plus1>> N2_initial_;
-  std::array<float, kFftLengthBy2Plus1> Y2_smoothed_;
-  std::array<float, kFftLengthBy2Plus1> N2_;
+  const size_t num_capture_channels_;
+  std::unique_ptr<std::vector<std::array<float, kFftLengthBy2Plus1>>>
+      N2_initial_;
+  std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_smoothed_;
+  std::vector<std::array<float, kFftLengthBy2Plus1>> N2_;
   int N2_counter_ = 0;
-
-  RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(ComfortNoiseGenerator);
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
index 2d87cd8..02c26cc 100644
--- a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
+++ b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc
@@ -31,50 +31,39 @@
 
 }  // namespace
 
-#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
-
-TEST(ComfortNoiseGenerator, NullLowerBandNoise) {
-  std::array<float, kFftLengthBy2Plus1> N2;
-  FftData noise;
-  EXPECT_DEATH(ComfortNoiseGenerator(DetectOptimization(), 42)
-                   .Compute(false, N2, nullptr, &noise),
-               "");
-}
-
-TEST(ComfortNoiseGenerator, NullUpperBandNoise) {
-  std::array<float, kFftLengthBy2Plus1> N2;
-  FftData noise;
-  EXPECT_DEATH(ComfortNoiseGenerator(DetectOptimization(), 42)
-                   .Compute(false, N2, &noise, nullptr),
-               "");
-}
-
-#endif
-
 TEST(ComfortNoiseGenerator, CorrectLevel) {
-  ComfortNoiseGenerator cng(DetectOptimization(), 42);
-  AecState aec_state(EchoCanceller3Config{}, 1);
+  constexpr size_t kNumChannels = 5;
+  ComfortNoiseGenerator cng(DetectOptimization(), kNumChannels);
+  AecState aec_state(EchoCanceller3Config{}, kNumChannels);
 
-  std::array<float, kFftLengthBy2Plus1> N2;
-  N2.fill(1000.f * 1000.f);
+  std::vector<std::array<float, kFftLengthBy2Plus1>> N2(kNumChannels);
+  std::vector<FftData> n_lower(kNumChannels);
+  std::vector<FftData> n_upper(kNumChannels);
 
-  FftData n_lower;
-  FftData n_upper;
-  n_lower.re.fill(0.f);
-  n_lower.im.fill(0.f);
-  n_upper.re.fill(0.f);
-  n_upper.im.fill(0.f);
+  for (size_t ch = 0; ch < kNumChannels; ++ch) {
+    N2[ch].fill(1000.f * 1000.f / (ch + 1));
+    n_lower[ch].re.fill(0.f);
+    n_lower[ch].im.fill(0.f);
+    n_upper[ch].re.fill(0.f);
+    n_upper[ch].im.fill(0.f);
+  }
 
   // Ensure instantaneous updata to nonzero noise.
-  cng.Compute(false, N2, &n_lower, &n_upper);
-  EXPECT_LT(0.f, Power(n_lower));
-  EXPECT_LT(0.f, Power(n_upper));
+  cng.Compute(false, N2, n_lower, n_upper);
+
+  for (size_t ch = 0; ch < kNumChannels; ++ch) {
+    EXPECT_LT(0.f, Power(n_lower[ch]));
+    EXPECT_LT(0.f, Power(n_upper[ch]));
+  }
 
   for (int k = 0; k < 10000; ++k) {
-    cng.Compute(false, N2, &n_lower, &n_upper);
+    cng.Compute(false, N2, n_lower, n_upper);
   }
-  EXPECT_NEAR(2.f * N2[0], Power(n_lower), N2[0] / 10.f);
-  EXPECT_NEAR(2.f * N2[0], Power(n_upper), N2[0] / 10.f);
+
+  for (size_t ch = 0; ch < kNumChannels; ++ch) {
+    EXPECT_NEAR(2.f * N2[ch][0], Power(n_lower[ch]), N2[ch][0] / 10.f);
+    EXPECT_NEAR(2.f * N2[ch][0], Power(n_upper[ch]), N2[ch][0] / 10.f);
+  }
 }
 
 }  // namespace aec3
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index 602a353..5f48e22 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -149,7 +149,7 @@
   const bool use_shadow_filter_output_;
   Subtractor subtractor_;
   std::vector<std::unique_ptr<SuppressionGain>> suppression_gains_;
-  std::vector<std::unique_ptr<ComfortNoiseGenerator>> cngs_;
+  ComfortNoiseGenerator cng_;
   SuppressionFilter suppression_filter_;
   RenderSignalAnalyzer render_signal_analyzer_;
   ResidualEchoEstimator residual_echo_estimator_;
@@ -196,7 +196,7 @@
                   data_dumper_.get(),
                   optimization_),
       suppression_gains_(num_capture_channels_),
-      cngs_(num_capture_channels_),
+      cng_(optimization_, num_capture_channels_),
       suppression_filter_(optimization_,
                           sample_rate_hz_,
                           num_capture_channels_),
@@ -220,12 +220,9 @@
     e_k.fill(0.f);
   }
 
-  uint32_t cng_seed = 42;
   for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
     suppression_gains_[ch] = std::make_unique<SuppressionGain>(
         config_, optimization_, sample_rate_hz);
-    cngs_[ch] =
-        std::make_unique<ComfortNoiseGenerator>(optimization_, cng_seed++);
     e_old_[ch].fill(0.f);
     y_old_[ch].fill(0.f);
   }
@@ -401,11 +398,11 @@
   residual_echo_estimator_.Estimate(aec_state_, *render_buffer, S2_linear, Y2,
                                     R2);
 
-  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
-    // Estimate the comfort noise.
-    cngs_[ch]->Compute(aec_state_.SaturatedCapture(), Y2[ch],
-                       &comfort_noise[ch], &high_band_comfort_noise[ch]);
+  // Estimate the comfort noise.
+  cng_.Compute(aec_state_.SaturatedCapture(), Y2, comfort_noise,
+               high_band_comfort_noise);
 
+  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
     // Suppressor echo estimate.
     const auto& echo_spectrum =
         aec_state_.UsableLinearEstimate() ? S2_linear[ch] : R2[ch];
@@ -425,7 +422,7 @@
     float high_bands_gain_channel;
     std::array<float, kFftLengthBy2Plus1> G_channel;
     suppression_gains_[ch]->GetGain(nearend_spectrum, echo_spectrum, R2[ch],
-                                    cngs_[ch]->NoiseSpectrum(),
+                                    cng_.NoiseSpectrum()[ch],
                                     render_signal_analyzer_, aec_state_, x,
                                     &high_bands_gain_channel, &G_channel);
 
@@ -438,7 +435,7 @@
                                 high_bands_gain, Y_fft, y);
 
   // Update the metrics.
-  metrics_.Update(aec_state_, cngs_[0]->NoiseSpectrum(), G);
+  metrics_.Update(aec_state_, cng_.NoiseSpectrum()[0], G);
 
   // Debug outputs for the purpose of development and analysis.
   data_dumper_->DumpWav("aec3_echo_estimate", kBlockSize,
@@ -446,7 +443,7 @@
   data_dumper_->DumpRaw("aec3_output", (*y)[0][0]);
   data_dumper_->DumpRaw("aec3_narrow_render",
                         render_signal_analyzer_.NarrowPeakBand() ? 1 : 0);
-  data_dumper_->DumpRaw("aec3_N2", cngs_[0]->NoiseSpectrum());
+  data_dumper_->DumpRaw("aec3_N2", cng_.NoiseSpectrum()[0]);
   data_dumper_->DumpRaw("aec3_suppressor_gain", G);
   data_dumper_->DumpWav("aec3_output",
                         rtc::ArrayView<const float>(&(*y)[0][0][0], kBlockSize),