AEC3: Suppression filter handles multiple channels

Suppression filter is extended to support the synthesis
of multiple channels. This CL is also a major clean-up of ApplyGain.

The CL has been tested for bit-exactness for single channel output.

Bug: webrtc:10913
Change-Id: I1319f127981552e17dec66701a248d34dcf0e563
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/154341
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29284}
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index a184517..725e33e 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -191,7 +191,9 @@
       subtractors_(num_capture_channels_),
       suppression_gains_(num_capture_channels_),
       cngs_(num_capture_channels_),
-      suppression_filter_(optimization_, sample_rate_hz_),
+      suppression_filter_(optimization_,
+                          sample_rate_hz_,
+                          num_capture_channels_),
       render_signal_analyzer_(config_),
       residual_echo_estimators_(num_capture_channels_),
       aec_state_(config_),
@@ -378,7 +380,7 @@
                     E2[0], Y2[0], subtractor_output[0], y0);
 
   // Choose the linear output.
-  const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E[0] : Y[0];
+  const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y;
 
 #if WEBRTC_APM_DEBUG_DUMP
   if (aec_state_.UseLinearFilterOutput()) {
@@ -439,8 +441,7 @@
                    [](float a, float b) { return std::min(a, b); });
   }
 
-  // TODO(bugs.webrtc.org/10913): Make ApplyGain handle multiple channels.
-  suppression_filter_.ApplyGain(comfort_noise[0], high_band_comfort_noise[0], G,
+  suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G,
                                 high_bands_gain, Y_fft, y);
 
   // Update the metrics.
diff --git a/modules/audio_processing/aec3/suppression_filter.cc b/modules/audio_processing/aec3/suppression_filter.cc
index 6679a87..8a813d9 100644
--- a/modules/audio_processing/aec3/suppression_filter.cc
+++ b/modules/audio_processing/aec3/suppression_filter.cc
@@ -61,107 +61,117 @@
 }  // namespace
 
 SuppressionFilter::SuppressionFilter(Aec3Optimization optimization,
-                                     int sample_rate_hz)
+                                     int sample_rate_hz,
+                                     size_t num_capture_channels)
     : optimization_(optimization),
       sample_rate_hz_(sample_rate_hz),
+      num_capture_channels_(num_capture_channels),
       fft_(),
-      e_output_old_(NumBandsForRate(sample_rate_hz_)) {
+      e_output_old_(NumBandsForRate(sample_rate_hz_),
+                    std::vector<std::array<float, kFftLengthBy2>>(
+                        num_capture_channels_)) {
   RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
-  std::for_each(e_output_old_.begin(), e_output_old_.end(),
-                [](std::array<float, kFftLengthBy2>& a) { a.fill(0.f); });
+  for (size_t b = 0; b < e_output_old_.size(); ++b) {
+    for (size_t ch = 0; ch < e_output_old_[b].size(); ++ch) {
+      e_output_old_[b][ch].fill(0.f);
+    }
+  }
 }
 
 SuppressionFilter::~SuppressionFilter() = default;
 
 void SuppressionFilter::ApplyGain(
-    const FftData& comfort_noise,
-    const FftData& comfort_noise_high_band,
+    rtc::ArrayView<const FftData> comfort_noise,
+    rtc::ArrayView<const FftData> comfort_noise_high_band,
     const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
     float high_bands_gain,
-    const FftData& E_lowest_band,
+    rtc::ArrayView<const FftData> E_lowest_band,
     std::vector<std::vector<std::vector<float>>>* e) {
   RTC_DCHECK(e);
   RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_));
-  FftData E;
-
-  // Analysis filterbank.
-  E.Assign(E_lowest_band);
-
-  // Apply gain.
-  std::transform(suppression_gain.begin(), suppression_gain.end(), E.re.begin(),
-                 E.re.begin(), std::multiplies<float>());
-  std::transform(suppression_gain.begin(), suppression_gain.end(), E.im.begin(),
-                 E.im.begin(), std::multiplies<float>());
 
   // Comfort noise gain is sqrt(1-g^2), where g is the suppression gain.
   std::array<float, kFftLengthBy2Plus1> noise_gain;
-  std::transform(suppression_gain.begin(), suppression_gain.end(),
-                 noise_gain.begin(), [](float g) { return 1.f - g * g; });
+  for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) {
+    noise_gain[i] = 1.f - suppression_gain[i] * suppression_gain[i];
+  }
   aec3::VectorMath(optimization_).Sqrt(noise_gain);
 
-  // Scale and add the comfort noise.
-  for (size_t k = 0; k < kFftLengthBy2Plus1; k++) {
-    E.re[k] += noise_gain[k] * comfort_noise.re[k];
-    E.im[k] += noise_gain[k] * comfort_noise.im[k];
-  }
+  const float high_bands_noise_scaling =
+      0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain);
 
-  // Synthesis filterbank.
-  std::array<float, kFftLength> e_extended;
-  constexpr float kIfftNormalization = 2.f / kFftLength;
+  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+    FftData E;
 
-  fft_.Ifft(E, &e_extended);
-  std::transform(e_output_old_[0].begin(), e_output_old_[0].end(),
-                 std::begin(kSqrtHanning) + kFftLengthBy2, (*e)[0][0].begin(),
-                 [&](float a, float b) { return kIfftNormalization * a * b; });
-  std::transform(e_extended.begin(), e_extended.begin() + kFftLengthBy2,
-                 std::begin(kSqrtHanning), e_extended.begin(),
-                 [&](float a, float b) { return kIfftNormalization * a * b; });
-  std::transform((*e)[0][0].begin(), (*e)[0][0].end(), e_extended.begin(),
-                 (*e)[0][0].begin(), std::plus<float>());
-  std::for_each((*e)[0][0].begin(), (*e)[0][0].end(), [](float& x_k) {
-    x_k = rtc::SafeClamp(x_k, -32768.f, 32767.f);
-  });
-  std::copy(e_extended.begin() + kFftLengthBy2, e_extended.begin() + kFftLength,
-            std::begin(e_output_old_[0]));
+    // Analysis filterbank.
+    E.Assign(E_lowest_band[ch]);
 
-  if (e->size() > 1) {
-    // Form time-domain high-band noise.
-    std::array<float, kFftLength> time_domain_high_band_noise;
-    std::transform(comfort_noise_high_band.re.begin(),
-                   comfort_noise_high_band.re.end(), E.re.begin(),
-                   [&](float a) { return kIfftNormalization * a; });
-    std::transform(comfort_noise_high_band.im.begin(),
-                   comfort_noise_high_band.im.end(), E.im.begin(),
-                   [&](float a) { return kIfftNormalization * a; });
-    fft_.Ifft(E, &time_domain_high_band_noise);
+    for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) {
+      // Apply suppression gains.
+      E.re[i] *= suppression_gain[i];
+      E.im[i] *= suppression_gain[i];
 
-    // Scale and apply the noise to the signals.
-    const float high_bands_noise_scaling =
-        0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain);
-
-    std::transform(
-        (*e)[1][0].begin(), (*e)[1][0].end(),
-        time_domain_high_band_noise.begin(), (*e)[1][0].begin(),
-        [&](float a, float b) {
-          return std::max(
-              std::min(b * high_bands_noise_scaling + high_bands_gain * a,
-                       32767.0f),
-              -32768.0f);
-        });
-
-    if (e->size() > 2) {
-      RTC_DCHECK_EQ(3, e->size());
-      std::for_each((*e)[2][0].begin(), (*e)[2][0].end(), [&](float& a) {
-        a = rtc::SafeClamp(a * high_bands_gain, -32768.f, 32767.f);
-      });
+      // Scale and add the comfort noise.
+      E.re[i] += noise_gain[i] * comfort_noise[ch].re[i];
+      E.im[i] += noise_gain[i] * comfort_noise[ch].im[i];
     }
 
-    std::array<float, kFftLengthBy2> tmp;
-    for (size_t k = 1; k < e->size(); ++k) {
-      std::copy((*e)[k][0].begin(), (*e)[k][0].end(), tmp.begin());
-      std::copy(e_output_old_[k].begin(), e_output_old_[k].end(),
-                (*e)[k][0].begin());
-      std::copy(tmp.begin(), tmp.end(), e_output_old_[k].begin());
+    // Synthesis filterbank.
+    std::array<float, kFftLength> e_extended;
+    constexpr float kIfftNormalization = 2.f / kFftLength;
+    fft_.Ifft(E, &e_extended);
+
+    auto& e0 = (*e)[0][ch];
+    auto& e0_old = e_output_old_[0][ch];
+
+    // Window and add the first half of e_extended with the second half of
+    // e_extended from the previous block.
+    for (size_t i = 0; i < kFftLengthBy2; ++i) {
+      e0[i] = e0_old[i] * kSqrtHanning[kFftLengthBy2 + i];
+      e0[i] += e_extended[i] * kSqrtHanning[i];
+      e0[i] *= kIfftNormalization;
+    }
+
+    // The second half of e_extended is stored for the succeeding frame.
+    std::copy(e_extended.begin() + kFftLengthBy2,
+              e_extended.begin() + kFftLength, std::begin(e0_old));
+
+    // Apply suppression gain to upper bands.
+    for (size_t b = 1; b < e->size(); ++b) {
+      auto& e_band = (*e)[b][ch];
+      for (size_t i = 0; i < kFftLengthBy2; ++i) {
+        e_band[i] *= high_bands_gain;
+      }
+    }
+
+    // Add comfort noise to band 1.
+    if (e->size() > 1) {
+      E.Assign(comfort_noise_high_band[ch]);
+      std::array<float, kFftLength> time_domain_high_band_noise;
+      fft_.Ifft(E, &time_domain_high_band_noise);
+
+      auto& e1 = (*e)[1][ch];
+      const float gain = high_bands_noise_scaling * kIfftNormalization;
+      for (size_t i = 0; i < kFftLengthBy2; ++i) {
+        e1[i] += time_domain_high_band_noise[i] * gain;
+      }
+    }
+
+    // Delay upper bands to match the delay of the filter bank.
+    for (size_t b = 1; b < e->size(); ++b) {
+      auto& e_band = (*e)[b][ch];
+      auto& e_band_old = e_output_old_[b][ch];
+      for (size_t i = 0; i < kFftLengthBy2; ++i) {
+        std::swap(e_band[i], e_band_old[i]);
+      }
+    }
+
+    // Clamp output of all bands.
+    for (size_t b = 0; b < e->size(); ++b) {
+      auto& e_band = (*e)[b][ch];
+      for (size_t i = 0; i < kFftLengthBy2; ++i) {
+        e_band[i] = rtc::SafeClamp(e_band[i], -32768.f, 32767.f);
+      }
     }
   }
 }
diff --git a/modules/audio_processing/aec3/suppression_filter.h b/modules/audio_processing/aec3/suppression_filter.h
index 03b13c8..a35fb40 100644
--- a/modules/audio_processing/aec3/suppression_filter.h
+++ b/modules/audio_processing/aec3/suppression_filter.h
@@ -24,21 +24,24 @@
 
 class SuppressionFilter {
  public:
-  SuppressionFilter(Aec3Optimization optimization, int sample_rate_hz);
+  SuppressionFilter(Aec3Optimization optimization,
+                    int sample_rate_hz,
+                    size_t num_capture_channels_);
   ~SuppressionFilter();
-  void ApplyGain(const FftData& comfort_noise,
-                 const FftData& comfort_noise_high_bands,
+  void ApplyGain(rtc::ArrayView<const FftData> comfort_noise,
+                 rtc::ArrayView<const FftData> comfort_noise_high_bands,
                  const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
                  float high_bands_gain,
-                 const FftData& E_lowest_band,
+                 rtc::ArrayView<const FftData> E_lowest_band,
                  std::vector<std::vector<std::vector<float>>>* e);
 
  private:
   const Aec3Optimization optimization_;
   const int sample_rate_hz_;
+  const size_t num_capture_channels_;
   const OouraFft ooura_fft_;
   const Aec3Fft fft_;
-  std::vector<std::array<float, kFftLengthBy2>> e_output_old_;
+  std::vector<std::vector<std::array<float, kFftLengthBy2>>> e_output_old_;
   RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionFilter);
 };
 
diff --git a/modules/audio_processing/aec3/suppression_filter_unittest.cc b/modules/audio_processing/aec3/suppression_filter_unittest.cc
index 1e05a02..b55c719 100644
--- a/modules/audio_processing/aec3/suppression_filter_unittest.cc
+++ b/modules/audio_processing/aec3/suppression_filter_unittest.cc
@@ -51,46 +51,46 @@
 
 // Verifies the check for null suppressor output.
 TEST(SuppressionFilter, NullOutput) {
-  FftData cn;
-  FftData cn_high_bands;
-  FftData E;
+  std::vector<FftData> cn(1);
+  std::vector<FftData> cn_high_bands(1);
+  std::vector<FftData> E(1);
   std::array<float, kFftLengthBy2Plus1> gain;
 
-  EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000)
+  EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000, 1)
                    .ApplyGain(cn, cn_high_bands, gain, 1.0f, E, nullptr),
                "");
 }
 
 // Verifies the check for allowed sample rate.
 TEST(SuppressionFilter, ProperSampleRate) {
-  EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001), "");
+  EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001, 1), "");
 }
 
 #endif
 
 // Verifies that no comfort noise is added when the gain is 1.
 TEST(SuppressionFilter, ComfortNoiseInUnityGain) {
-  SuppressionFilter filter(Aec3Optimization::kNone, 48000);
-  FftData cn;
-  FftData cn_high_bands;
+  SuppressionFilter filter(Aec3Optimization::kNone, 48000, 1);
+  std::vector<FftData> cn(1);
+  std::vector<FftData> cn_high_bands(1);
   std::array<float, kFftLengthBy2Plus1> gain;
   std::array<float, kFftLengthBy2> e_old_;
   Aec3Fft fft;
 
   e_old_.fill(0.f);
   gain.fill(1.f);
-  cn.re.fill(1.f);
-  cn.im.fill(1.f);
-  cn_high_bands.re.fill(1.f);
-  cn_high_bands.im.fill(1.f);
+  cn[0].re.fill(1.f);
+  cn[0].im.fill(1.f);
+  cn_high_bands[0].re.fill(1.f);
+  cn_high_bands[0].im.fill(1.f);
 
   std::vector<std::vector<std::vector<float>>> e(
       3,
       std::vector<std::vector<float>>(1, std::vector<float>(kBlockSize, 0.f)));
   std::vector<std::vector<std::vector<float>>> e_ref = e;
 
-  FftData E;
-  fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+  std::vector<FftData> E(1);
+  fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
   std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
 
   filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@@ -110,9 +110,9 @@
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
   constexpr size_t kNumChannels = 1;
 
-  SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz);
-  FftData cn;
-  FftData cn_high_bands;
+  SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
+  std::vector<FftData> cn(1);
+  std::vector<FftData> cn_high_bands(1);
   std::array<float, kFftLengthBy2> e_old_;
   Aec3Fft fft;
   std::array<float, kFftLengthBy2Plus1> gain;
@@ -124,10 +124,10 @@
   gain.fill(1.f);
   std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; });
 
-  cn.re.fill(0.f);
-  cn.im.fill(0.f);
-  cn_high_bands.re.fill(0.f);
-  cn_high_bands.im.fill(0.f);
+  cn[0].re.fill(0.f);
+  cn[0].im.fill(0.f);
+  cn_high_bands[0].re.fill(0.f);
+  cn_high_bands[0].im.fill(0.f);
 
   size_t sample_counter = 0;
 
@@ -138,8 +138,8 @@
     e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(),
                                   e[0][0].begin(), e0_input);
 
-    FftData E;
-    fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+    std::vector<FftData> E(1);
+    fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
     std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
 
     filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@@ -157,11 +157,11 @@
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
-  SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz);
-  FftData cn;
+  SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
+  std::vector<FftData> cn(1);
   std::array<float, kFftLengthBy2> e_old_;
   Aec3Fft fft;
-  FftData cn_high_bands;
+  std::vector<FftData> cn_high_bands(1);
   std::array<float, kFftLengthBy2Plus1> gain;
   std::vector<std::vector<std::vector<float>>> e(
       kNumBands, std::vector<std::vector<float>>(
@@ -170,10 +170,10 @@
   gain.fill(1.f);
   std::for_each(gain.begin() + 30, gain.end(), [](float& a) { a = 0.f; });
 
-  cn.re.fill(0.f);
-  cn.im.fill(0.f);
-  cn_high_bands.re.fill(0.f);
-  cn_high_bands.im.fill(0.f);
+  cn[0].re.fill(0.f);
+  cn[0].im.fill(0.f);
+  cn_high_bands[0].re.fill(0.f);
+  cn_high_bands[0].im.fill(0.f);
 
   size_t sample_counter = 0;
 
@@ -184,8 +184,8 @@
     e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(),
                                   e[0][0].begin(), e0_input);
 
-    FftData E;
-    fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+    std::vector<FftData> E(1);
+    fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
     std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
 
     filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@@ -202,9 +202,9 @@
   constexpr int kSampleRateHz = 48000;
   constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
 
-  SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz);
-  FftData cn;
-  FftData cn_high_bands;
+  SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
+  std::vector<FftData> cn(1);
+  std::vector<FftData> cn_high_bands(1);
   std::array<float, kFftLengthBy2> e_old_;
   Aec3Fft fft;
   std::array<float, kFftLengthBy2Plus1> gain;
@@ -214,10 +214,10 @@
 
   gain.fill(1.f);
 
-  cn.re.fill(0.f);
-  cn.im.fill(0.f);
-  cn_high_bands.re.fill(0.f);
-  cn_high_bands.im.fill(0.f);
+  cn[0].re.fill(0.f);
+  cn[0].im.fill(0.f);
+  cn_high_bands[0].re.fill(0.f);
+  cn_high_bands[0].im.fill(0.f);
 
   for (size_t k = 0; k < 100; ++k) {
     for (size_t band = 0; band < kNumBands; ++band) {
@@ -228,8 +228,8 @@
       }
     }
 
-    FftData E;
-    fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+    std::vector<FftData> E(1);
+    fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
     std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
 
     filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);