Version 3: Various changes to the pre-echo delay estimator:
- Lowering the energy threshold for updating the accumulated error.
- Not using the pre-echo estimate in the initial frames when the matched filters have been recently initialized.
- Slightly speeding up the increases in the accumulated error.
- Not periodically resetting the accumulated error.

Bug: webrtc:14205
Change-Id: Ic337332e263b27d7a3aba0ab4b371517780f9c90
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/291320
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Jesus de Vicente Pena <devicentepena@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39175}
diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/modules/audio_processing/aec3/echo_path_delay_estimator.cc
index fc83ca2..510e4b8 100644
--- a/modules/audio_processing/aec3/echo_path_delay_estimator.cc
+++ b/modules/audio_processing/aec3/echo_path_delay_estimator.cc
@@ -120,7 +120,7 @@
   if (reset_lag_aggregator) {
     matched_filter_lag_aggregator_.Reset(reset_delay_confidence);
   }
-  matched_filter_.Reset();
+  matched_filter_.Reset(/*full_reset=*/reset_lag_aggregator);
   old_aggregated_lag_ = absl::nullopt;
   consistent_estimate_counter_ = 0;
 }
diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc
index a905482..af30ff1 100644
--- a/modules/audio_processing/aec3/matched_filter.cc
+++ b/modules/audio_processing/aec3/matched_filter.cc
@@ -43,14 +43,16 @@
 void UpdateAccumulatedError(
     const rtc::ArrayView<const float> instantaneous_accumulated_error,
     const rtc::ArrayView<float> accumulated_error,
-    float one_over_error_sum_anchor) {
+    float one_over_error_sum_anchor,
+    float smooth_constant_increases) {
   for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
     float error_norm =
         instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
     if (error_norm < accumulated_error[k]) {
       accumulated_error[k] = error_norm;
     } else {
-      accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]);
+      accumulated_error[k] +=
+          smooth_constant_increases * (error_norm - accumulated_error[k]);
     }
   }
 }
@@ -89,7 +91,8 @@
       }
       break;
     case 2:
-      // Mode 2: Pre echo lag is defined as the closest coefficient to the lag
+    case 3:
+      // Mode 2,3: Pre echo lag is defined as the closest coefficient to the lag
       // with an error lower than a certain threshold.
       for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) {
         if (accumulated_error[k] > pre_echo_configuration.threshold) {
@@ -705,17 +708,19 @@
 
 MatchedFilter::~MatchedFilter() = default;
 
-void MatchedFilter::Reset() {
+void MatchedFilter::Reset(bool full_reset) {
   for (auto& f : filters_) {
     std::fill(f.begin(), f.end(), 0.f);
   }
 
-  for (auto& e : accumulated_error_) {
-    std::fill(e.begin(), e.end(), 1.0f);
-  }
-
   winner_lag_ = absl::nullopt;
   reported_lag_estimate_ = absl::nullopt;
+  if (pre_echo_config_.mode != 3 || full_reset) {
+    for (auto& e : accumulated_error_) {
+      std::fill(e.begin(), e.end(), 1.0f);
+    }
+    number_pre_echo_updates_ = 0;
+  }
 }
 
 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
@@ -816,20 +821,34 @@
     reported_lag_estimate_ =
         LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
     if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
-      if (error_sum_anchor > 30.0f * 30.0f * y.size()) {
-        UpdateAccumulatedError(instantaneous_accumulated_error_,
-                               accumulated_error_[winner_index],
-                               1.0f / error_sum_anchor);
+      const float energy_threshold =
+          pre_echo_config_.mode == 3 ? 1.0f : 30.0f * 30.0f * y.size();
+
+      if (error_sum_anchor > energy_threshold) {
+        const float smooth_constant_increases =
+            pre_echo_config_.mode != 3 ? 0.01f : 0.015f;
+
+        UpdateAccumulatedError(
+            instantaneous_accumulated_error_, accumulated_error_[winner_index],
+            1.0f / error_sum_anchor, smooth_constant_increases);
+        number_pre_echo_updates_++;
       }
-      reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
-          pre_echo_config_, accumulated_error_[winner_index],
-          winner_lag_.value(),
-          winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
+      if (pre_echo_config_.mode != 3 || number_pre_echo_updates_ >= 50) {
+        reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
+            pre_echo_config_, accumulated_error_[winner_index],
+            winner_lag_.value(),
+            winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
+      } else {
+        reported_lag_estimate_->pre_echo_lag = winner_lag_.value();
+      }
     }
     last_detected_best_lag_filter_ = winner_index;
   }
   if (ApmDataDumper::IsAvailable()) {
     Dump();
+    data_dumper_->DumpRaw("error_sum_anchor", error_sum_anchor / y.size());
+    data_dumper_->DumpRaw("number_pre_echo_updates", number_pre_echo_updates_);
+    data_dumper_->DumpRaw("filter_smoothing", smoothing);
   }
 }
 
@@ -871,6 +890,9 @@
       std::string dumper_pre_lag =
           "aec3_correlator_pre_echo_lag_" + std::to_string(n);
       data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
+      if (static_cast<int>(n) == last_detected_best_lag_filter_) {
+        data_dumper_->DumpRaw("aec3_pre_echo_delay_winner_inst", pre_echo_lag);
+      }
     }
   }
 }
diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h
index 1560fb0..bb54fba 100644
--- a/modules/audio_processing/aec3/matched_filter.h
+++ b/modules/audio_processing/aec3/matched_filter.h
@@ -135,7 +135,7 @@
               bool use_slow_smoothing);
 
   // Resets the matched filter.
-  void Reset();
+  void Reset(bool full_reset);
 
   // Returns the current lag estimates.
   absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const {
@@ -176,6 +176,7 @@
   absl::optional<size_t> winner_lag_;
   int last_detected_best_lag_filter_ = -1;
   std::vector<size_t> filters_offsets_;
+  int number_pre_echo_updates_ = 0;
   const float excitation_limit_;
   const float smoothing_fast_;
   const float smoothing_slow_;