Correct and soften the AEC3 handling of saturated mic signals

This CL changes the handling of saturated microphone signals in AEC3.

Some of the changes included are:
- Make the detection of saturated echoes depend on the echo path gain
  estimate.
- Remove redundant code related to echo saturation.
- Correct the computation of residual echoes when the echo is saturated.
- Soften the echo removal during echo saturation.

Bug: webrtc:9119
Change-Id: I5cb11cd449de552ab670beeb24ed8112f8beb734
Reviewed-on: https://webrtc-review.googlesource.com/67220
Commit-Queue: Per Åhgren <peah@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#22809}
diff --git a/api/audio/echo_canceller3_config.h b/api/audio/echo_canceller3_config.h
index f7125b2..db12cf7 100644
--- a/api/audio/echo_canceller3_config.h
+++ b/api/audio/echo_canceller3_config.h
@@ -62,9 +62,9 @@
   } erle;
 
   struct EpStrength {
-    float lf = 2.f;
-    float mf = 2.f;
-    float hf = 2.f;
+    float lf = 10.f;
+    float mf = 10.f;
+    float hf = 10.f;
     float default_len = 0.f;
     bool echo_can_saturate = true;
     bool bounded_erl = false;
diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc
index bc987c4..ed664aa 100644
--- a/modules/audio_processing/aec3/aec_state.cc
+++ b/modules/audio_processing/aec3/aec_state.cc
@@ -59,7 +59,6 @@
     usable_linear_estimate_ = false;
     capture_signal_saturation_ = false;
     echo_saturation_ = false;
-    previous_max_sample_ = 0.f;
     std::fill(max_render_.begin(), max_render_.end(), 0.f);
     blocks_with_proper_filter_adaptation_ = 0;
     blocks_since_reset_ = 0;
@@ -144,7 +143,7 @@
   // TODO(peah): Add the delay in this computation to ensure that the render and
   // capture signals are properly aligned.
   if (config_.ep_strength.echo_can_saturate) {
-    echo_saturation_ = DetectEchoSaturation(x);
+    echo_saturation_ = DetectEchoSaturation(x, EchoPathGain());
   }
 
   bool filter_has_had_time_to_converge =
@@ -458,19 +457,22 @@
                         kFftLengthBy2;
 }
 
-bool AecState::DetectEchoSaturation(rtc::ArrayView<const float> x) {
+bool AecState::DetectEchoSaturation(rtc::ArrayView<const float> x,
+                                    float echo_path_gain) {
   RTC_DCHECK_LT(0, x.size());
   const float max_sample = fabs(*std::max_element(
       x.begin(), x.end(), [](float a, float b) { return a * a < b * b; }));
-  previous_max_sample_ = max_sample;
 
   // Set flag for potential presence of saturated echo
-  blocks_since_last_saturation_ =
-      previous_max_sample_ > 200.f && SaturatedCapture()
-          ? 0
-          : blocks_since_last_saturation_ + 1;
+  const float kMargin = 10.f;
+  float peak_echo_amplitude = max_sample * echo_path_gain * kMargin;
+  if (SaturatedCapture() && peak_echo_amplitude > 32000) {
+    blocks_since_last_saturation_ = 0;
+  } else {
+    ++blocks_since_last_saturation_;
+  }
 
-  return blocks_since_last_saturation_ < 20;
+  return blocks_since_last_saturation_ < 5;
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h
index e1ff492..968704e 100644
--- a/modules/audio_processing/aec3/aec_state.h
+++ b/modules/audio_processing/aec3/aec_state.h
@@ -81,9 +81,6 @@
   // Returns whether the echo signal is saturated.
   bool SaturatedEcho() const { return echo_saturation_; }
 
-  // Returns whether the echo path can saturate.
-  bool SaturatingEchoPath() const { return saturating_echo_path_; }
-
   // Updates the capture signal saturation.
   void UpdateCaptureSaturation(bool capture_signal_saturation) {
     capture_signal_saturation_ = capture_signal_saturation;
@@ -127,7 +124,8 @@
   void UpdateReverb(const std::vector<float>& impulse_response);
   bool DetectActiveRender(rtc::ArrayView<const float> x) const;
   void UpdateSuppressorGainLimit(bool render_activity);
-  bool DetectEchoSaturation(rtc::ArrayView<const float> x);
+  bool DetectEchoSaturation(rtc::ArrayView<const float> x,
+                            float echo_path_gain);
 
   static int instance_count_;
   std::unique_ptr<ApmDataDumper> data_dumper_;
@@ -141,7 +139,6 @@
   bool capture_signal_saturation_ = false;
   bool echo_saturation_ = false;
   bool transparent_mode_ = false;
-  float previous_max_sample_ = 0.f;
   bool render_received_ = false;
   int filter_delay_blocks_ = 0;
   size_t blocks_since_last_saturation_ = 1000;
@@ -158,7 +155,6 @@
   const EchoCanceller3Config config_;
   std::vector<float> max_render_;
   float reverb_decay_ = fabsf(config_.ep_strength.default_len);
-  bool saturating_echo_path_ = false;
   bool filter_has_had_time_to_converge_ = false;
   bool initial_state_ = true;
   const float gain_rampup_increase_;
diff --git a/modules/audio_processing/aec3/residual_echo_estimator.cc b/modules/audio_processing/aec3/residual_echo_estimator.cc
index f534817..7435b6c 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator.cc
+++ b/modules/audio_processing/aec3/residual_echo_estimator.cc
@@ -38,16 +38,10 @@
 
   // Estimate the residual echo power.
   if (aec_state.UsableLinearEstimate()) {
-    LinearEstimate(S2_linear, aec_state.Erle(), aec_state.FilterDelayBlocks(),
-                   R2);
-    AddEchoReverb(S2_linear, aec_state.SaturatedEcho(),
-                  aec_state.FilterDelayBlocks(), aec_state.ReverbDecay(), R2);
-
-    // If the echo is saturated, estimate the echo power as the maximum echo
-    // power with a leakage factor.
-    if (aec_state.SaturatedEcho()) {
-      R2->fill((*std::max_element(R2->begin(), R2->end())) * 100.f);
-    }
+    RTC_DCHECK(!aec_state.SaturatedEcho());
+    LinearEstimate(S2_linear, aec_state.Erle(), R2);
+    AddEchoReverb(S2_linear, aec_state.FilterDelayBlocks(),
+                  aec_state.ReverbDecay(), R2);
   } else {
     // Estimate the echo generating signal power.
     std::array<float, kFftLengthBy2Plus1> X2;
@@ -69,15 +63,16 @@
                          0.f, a - config_.echo_model.stationary_gate_slope * b);
                    });
 
-    NonLinearEstimate(aec_state.SaturatedEcho(), aec_state.EchoPathGain(), X2,
-                      Y2, R2);
+    NonLinearEstimate(aec_state.EchoPathGain(), X2, Y2, R2);
 
+    // If the echo is saturated, estimate the echo power as the maximum echo
+    // power with a leakage factor.
     if (aec_state.SaturatedEcho()) {
-      // TODO(peah): Modify to make sense theoretically.
-      AddEchoReverb(*R2, aec_state.SaturatedEcho(),
-                    config_.filter.main.length_blocks, aec_state.ReverbDecay(),
-                    R2);
+      R2->fill((*std::max_element(R2->begin(), R2->end())) * 100.f);
     }
+
+    AddEchoReverb(*R2, config_.filter.main.length_blocks,
+                  aec_state.ReverbDecay(), R2);
   }
 
   // If the echo is deemed inaudible, set the residual echo to zero.
@@ -104,7 +99,6 @@
 void ResidualEchoEstimator::LinearEstimate(
     const std::array<float, kFftLengthBy2Plus1>& S2_linear,
     const std::array<float, kFftLengthBy2Plus1>& erle,
-    size_t delay,
     std::array<float, kFftLengthBy2Plus1>* R2) {
   std::fill(R2_hold_counter_.begin(), R2_hold_counter_.end(), 10.f);
   std::transform(erle.begin(), erle.end(), S2_linear.begin(), R2->begin(),
@@ -115,17 +109,15 @@
 }
 
 void ResidualEchoEstimator::NonLinearEstimate(
-    bool saturated_echo,
     float echo_path_gain,
     const std::array<float, kFftLengthBy2Plus1>& X2,
     const std::array<float, kFftLengthBy2Plus1>& Y2,
     std::array<float, kFftLengthBy2Plus1>* R2) {
-  float echo_path_gain_use = saturated_echo ? 10000.f : echo_path_gain;
 
   // Compute preliminary residual echo.
-  std::transform(
-      X2.begin(), X2.end(), R2->begin(),
-      [echo_path_gain_use](float a) { return a * echo_path_gain_use; });
+  std::transform(X2.begin(), X2.end(), R2->begin(), [echo_path_gain](float a) {
+    return a * echo_path_gain * echo_path_gain;
+  });
 
   for (size_t k = 0; k < R2->size(); ++k) {
     // Update hold counter.
@@ -144,7 +136,6 @@
 
 void ResidualEchoEstimator::AddEchoReverb(
     const std::array<float, kFftLengthBy2Plus1>& S2,
-    bool saturated_echo,
     size_t delay,
     float reverb_decay_factor,
     std::array<float, kFftLengthBy2Plus1>* R2) {
@@ -171,12 +162,7 @@
       });
 
   // Update the buffer of old echo powers.
-  if (saturated_echo) {
-    S2_old_[S2_old_index_].fill((*std::max_element(S2.begin(), S2.end())) *
-                                100.f);
-  } else {
-    std::copy(S2.begin(), S2.end(), S2_old_[S2_old_index_].begin());
-  }
+  std::copy(S2.begin(), S2.end(), S2_old_[S2_old_index_].begin());
 
   // Add the power of the echo reverb to the residual echo power.
   std::transform(R2->begin(), R2->end(), R2_reverb_.begin(), R2->begin(),
diff --git a/modules/audio_processing/aec3/residual_echo_estimator.h b/modules/audio_processing/aec3/residual_echo_estimator.h
index 1222d54..7b8a9b1 100644
--- a/modules/audio_processing/aec3/residual_echo_estimator.h
+++ b/modules/audio_processing/aec3/residual_echo_estimator.h
@@ -43,13 +43,11 @@
   // (ERLE) and the linear power estimate.
   void LinearEstimate(const std::array<float, kFftLengthBy2Plus1>& S2_linear,
                       const std::array<float, kFftLengthBy2Plus1>& erle,
-                      size_t delay,
                       std::array<float, kFftLengthBy2Plus1>* R2);
 
   // Estimates the residual echo power based on the estimate of the echo path
   // gain.
-  void NonLinearEstimate(bool saturated_echo,
-                         float echo_path_gain,
+  void NonLinearEstimate(float echo_path_gain,
                          const std::array<float, kFftLengthBy2Plus1>& X2,
                          const std::array<float, kFftLengthBy2Plus1>& Y2,
                          std::array<float, kFftLengthBy2Plus1>* R2);
@@ -57,7 +55,6 @@
   // Adds the estimated unmodelled echo power to the residual echo power
   // estimate.
   void AddEchoReverb(const std::array<float, kFftLengthBy2Plus1>& S2,
-                     bool saturated_echo,
                      size_t delay,
                      float reverb_decay_factor,
                      std::array<float, kFftLengthBy2Plus1>* R2);
diff --git a/modules/audio_processing/aec3/suppression_gain.cc b/modules/audio_processing/aec3/suppression_gain.cc
index b73e87e..49197f6 100644
--- a/modules/audio_processing/aec3/suppression_gain.cc
+++ b/modules/audio_processing/aec3/suppression_gain.cc
@@ -1,3 +1,4 @@
+
 /*
  *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
  *
@@ -117,7 +118,6 @@
     const EchoCanceller3Config& config,
     bool low_noise_render,
     bool saturated_echo,
-    bool saturating_echo_path,
     bool linear_echo_estimate,
     const std::array<float, kFftLengthBy2Plus1>& nearend,
     const std::array<float, kFftLengthBy2Plus1>& echo,
@@ -224,12 +224,8 @@
     const std::array<float, kFftLengthBy2Plus1>& comfort_noise,
     std::array<float, kFftLengthBy2Plus1>* gain) {
   const bool saturated_echo = aec_state.SaturatedEcho();
-  const bool saturating_echo_path = aec_state.SaturatingEchoPath();
   const bool linear_echo_estimate = aec_state.UsableLinearEstimate();
 
-  // Count the number of blocks since saturation.
-  no_saturation_counter_ = saturated_echo ? 0 : no_saturation_counter_ + 1;
-
   // Precompute 1/echo (note that when the echo is zero, the precomputed value
   // is never used).
   std::array<float, kFftLengthBy2Plus1> one_by_echo;
@@ -242,7 +238,7 @@
   const float min_echo_power =
       low_noise_render ? config_.echo_audibility.low_render_limit
                        : config_.echo_audibility.normal_render_limit;
-  if (no_saturation_counter_ > 10) {
+  if (!saturated_echo) {
     for (size_t k = 0; k < nearend.size(); ++k) {
       const float denom = std::min(nearend[k], echo[k]);
       min_gain[k] = denom > 0.f ? min_echo_power / denom : 1.f;
@@ -268,8 +264,8 @@
     std::array<float, kFftLengthBy2Plus1> masker;
     MaskingPower(config_, nearend, comfort_noise, last_masker_, *gain, &masker);
     GainToNoAudibleEcho(config_, low_noise_render, saturated_echo,
-                        saturating_echo_path, linear_echo_estimate, nearend,
-                        echo, masker, min_gain, max_gain, one_by_echo, gain);
+                        linear_echo_estimate, nearend, echo, masker, min_gain,
+                        max_gain, one_by_echo, gain);
     AdjustForExternalFilters(gain);
     if (narrow_peak_band) {
       NarrowBandAttenuation(*narrow_peak_band, nearend, echo, gain);
@@ -280,7 +276,8 @@
   AdjustNonConvergedFrequencies(gain);
 
   // Update the allowed maximum gain increase.
-  UpdateGainIncrease(low_noise_render, linear_echo_estimate, echo, *gain);
+  UpdateGainIncrease(low_noise_render, linear_echo_estimate, saturated_echo,
+                     echo, *gain);
 
   // Adjust gain dynamics.
   const float gain_bound =
@@ -353,6 +350,7 @@
 void SuppressionGain::UpdateGainIncrease(
     bool low_noise_render,
     bool linear_echo_estimate,
+    bool saturated_echo,
     const std::array<float, kFftLengthBy2Plus1>& echo,
     const std::array<float, kFftLengthBy2Plus1>& new_gain) {
   float max_inc;
@@ -379,7 +377,7 @@
     rate_dec = p.nonlinear.rate_dec;
     min_inc = p.nonlinear.min_inc;
     min_dec = p.nonlinear.min_dec;
-  } else if (initial_state_ && no_saturation_counter_ > 10) {
+  } else if (initial_state_ && !saturated_echo) {
     if (initial_state_change_counter_ > 0) {
       float change_factor =
           initial_state_change_counter_ * one_by_state_change_duration_blocks_;
@@ -409,7 +407,7 @@
     rate_dec = p.low_noise.rate_dec;
     min_inc = p.low_noise.min_inc;
     min_dec = p.low_noise.min_dec;
-  } else if (no_saturation_counter_ > 10) {
+  } else if (!saturated_echo) {
     max_inc = p.normal.max_inc;
     max_dec = p.normal.max_dec;
     rate_inc = p.normal.rate_inc;
diff --git a/modules/audio_processing/aec3/suppression_gain.h b/modules/audio_processing/aec3/suppression_gain.h
index 6624c1c..a519894 100644
--- a/modules/audio_processing/aec3/suppression_gain.h
+++ b/modules/audio_processing/aec3/suppression_gain.h
@@ -51,6 +51,7 @@
   void UpdateGainIncrease(
       bool low_noise_render,
       bool linear_echo_estimate,
+      bool saturated_echo,
       const std::array<float, kFftLengthBy2Plus1>& echo,
       const std::array<float, kFftLengthBy2Plus1>& new_gain);
 
@@ -72,7 +73,6 @@
   std::array<float, kFftLengthBy2Plus1> last_echo_;
 
   LowNoiseRenderDetector low_render_detector_;
-  size_t no_saturation_counter_ = 0;
   bool initial_state_ = true;
   int initial_state_change_counter_ = 0;
   RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionGain);