AGC2 add an interface for the noise level estimator

Done in preparation for the child CL which adds an alternative
implementation.

Bug: webrtc:7494
Change-Id: I4963376afc917eae434a0d0ccee18f21880eefe0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/214125
Reviewed-by: Jakob Ivarsson <jakobi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33646}
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index 9657aec..ca9959a 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -58,7 +58,7 @@
                     kMaxGainChangePerSecondDb,
                     kMaxOutputNoiseLevelDbfs),
       apm_data_dumper_(apm_data_dumper),
-      noise_level_estimator_(apm_data_dumper) {
+      noise_level_estimator_(CreateNoiseLevelEstimator(apm_data_dumper)) {
   RTC_DCHECK(apm_data_dumper);
 }
 
@@ -80,7 +80,7 @@
           config.adaptive_digital.max_gain_change_db_per_second,
           config.adaptive_digital.max_output_noise_level_dbfs),
       apm_data_dumper_(apm_data_dumper),
-      noise_level_estimator_(apm_data_dumper) {
+      noise_level_estimator_(CreateNoiseLevelEstimator(apm_data_dumper)) {
   RTC_DCHECK(apm_data_dumper);
   if (!config.adaptive_digital.use_saturation_protector) {
     RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled.";
@@ -94,7 +94,7 @@
   info.vad_result = vad_.AnalyzeFrame(frame);
   speech_level_estimator_.Update(info.vad_result);
   info.input_level_dbfs = speech_level_estimator_.level_dbfs();
-  info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
+  info.input_noise_level_dbfs = noise_level_estimator_->Analyze(frame);
   info.limiter_envelope_dbfs =
       limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.0f;
   info.estimate_is_confident = speech_level_estimator_.IsConfident();
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index f3c7854..b861c48 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -11,6 +11,8 @@
 #ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
 #define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
 
+#include <memory>
+
 #include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
 #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
 #include "modules/audio_processing/agc2/noise_level_estimator.h"
@@ -42,7 +44,7 @@
   VadLevelAnalyzer vad_;
   AdaptiveDigitalGainApplier gain_applier_;
   ApmDataDumper* const apm_data_dumper_;
-  NoiseLevelEstimator noise_level_estimator_;
+  std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/noise_level_estimator.cc b/modules/audio_processing/agc2/noise_level_estimator.cc
index d50ecba..6aa942c 100644
--- a/modules/audio_processing/agc2/noise_level_estimator.cc
+++ b/modules/audio_processing/agc2/noise_level_estimator.cc
@@ -18,11 +18,11 @@
 
 #include "api/array_view.h"
 #include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/signal_classifier.h"
 #include "modules/audio_processing/logging/apm_data_dumper.h"
 #include "rtc_base/checks.h"
 
 namespace webrtc {
-
 namespace {
 constexpr int kFramesPerSecond = 100;
 
@@ -41,86 +41,106 @@
   const float rms = std::sqrt(signal_energy / num_samples);
   return FloatS16ToDbfs(rms);
 }
-}  // namespace
 
-NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper)
-    : data_dumper_(data_dumper), signal_classifier_(data_dumper) {
-  Initialize(48000);
-}
-
-NoiseLevelEstimator::~NoiseLevelEstimator() {}
-
-void NoiseLevelEstimator::Initialize(int sample_rate_hz) {
-  sample_rate_hz_ = sample_rate_hz;
-  noise_energy_ = 1.0f;
-  first_update_ = true;
-  min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
-  noise_energy_hold_counter_ = 0;
-  signal_classifier_.Initialize(sample_rate_hz);
-}
-
-float NoiseLevelEstimator::Analyze(const AudioFrameView<const float>& frame) {
-  data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter",
-                        noise_energy_hold_counter_);
-  const int sample_rate_hz =
-      static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
-  if (sample_rate_hz != sample_rate_hz_) {
-    Initialize(sample_rate_hz);
+class NoiseLevelEstimatorImpl : public NoiseLevelEstimator {
+ public:
+  NoiseLevelEstimatorImpl(ApmDataDumper* data_dumper)
+      : data_dumper_(data_dumper), signal_classifier_(data_dumper) {
+    Initialize(48000);
   }
-  const float frame_energy = FrameEnergy(frame);
-  if (frame_energy <= 0.f) {
-    RTC_DCHECK_GE(frame_energy, 0.f);
-    data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
-    return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
-  }
+  NoiseLevelEstimatorImpl(const NoiseLevelEstimatorImpl&) = delete;
+  NoiseLevelEstimatorImpl& operator=(const NoiseLevelEstimatorImpl&) = delete;
+  ~NoiseLevelEstimatorImpl() = default;
 
-  if (first_update_) {
-    // Initialize the noise energy to the frame energy.
-    first_update_ = false;
-    data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
-    noise_energy_ = std::max(frame_energy, min_noise_energy_);
-    return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
-  }
+  float Analyze(const AudioFrameView<const float>& frame) {
+    data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter",
+                          noise_energy_hold_counter_);
+    const int sample_rate_hz =
+        static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
+    if (sample_rate_hz != sample_rate_hz_) {
+      Initialize(sample_rate_hz);
+    }
+    const float frame_energy = FrameEnergy(frame);
+    if (frame_energy <= 0.f) {
+      RTC_DCHECK_GE(frame_energy, 0.f);
+      data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
+      return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
+    }
 
-  const SignalClassifier::SignalType signal_type =
-      signal_classifier_.Analyze(frame.channel(0));
-  data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type",
-                        static_cast<int>(signal_type));
+    if (first_update_) {
+      // Initialize the noise energy to the frame energy.
+      first_update_ = false;
+      data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
+      noise_energy_ = std::max(frame_energy, min_noise_energy_);
+      return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
+    }
 
-  // Update the noise estimate in a minimum statistics-type manner.
-  if (signal_type == SignalClassifier::SignalType::kStationary) {
-    if (frame_energy > noise_energy_) {
-      // Leak the estimate upwards towards the frame energy if no recent
-      // downward update.
-      noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0);
+    const SignalClassifier::SignalType signal_type =
+        signal_classifier_.Analyze(frame.channel(0));
+    data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type",
+                          static_cast<int>(signal_type));
 
-      if (noise_energy_hold_counter_ == 0) {
-        constexpr float kMaxNoiseEnergyFactor = 1.01f;
+    // Update the noise estimate in a minimum statistics-type manner.
+    if (signal_type == SignalClassifier::SignalType::kStationary) {
+      if (frame_energy > noise_energy_) {
+        // Leak the estimate upwards towards the frame energy if no recent
+        // downward update.
+        noise_energy_hold_counter_ =
+            std::max(noise_energy_hold_counter_ - 1, 0);
+
+        if (noise_energy_hold_counter_ == 0) {
+          constexpr float kMaxNoiseEnergyFactor = 1.01f;
+          noise_energy_ =
+              std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy);
+        }
+      } else {
+        // Update smoothly downwards with a limited maximum update magnitude.
+        constexpr float kMinNoiseEnergyFactor = 0.9f;
+        constexpr float kNoiseEnergyDeltaFactor = 0.05f;
         noise_energy_ =
-            std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy);
+            std::max(noise_energy_ * kMinNoiseEnergyFactor,
+                     noise_energy_ - kNoiseEnergyDeltaFactor *
+                                         (noise_energy_ - frame_energy));
+        // Prevent an energy increase for the next 10 seconds.
+        constexpr int kNumFramesToEnergyIncreaseAllowed = 1000;
+        noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed;
       }
     } else {
-      // Update smoothly downwards with a limited maximum update magnitude.
-      constexpr float kMinNoiseEnergyFactor = 0.9f;
-      constexpr float kNoiseEnergyDeltaFactor = 0.05f;
-      noise_energy_ =
-          std::max(noise_energy_ * kMinNoiseEnergyFactor,
-                   noise_energy_ - kNoiseEnergyDeltaFactor *
-                                       (noise_energy_ - frame_energy));
-      // Prevent an energy increase for the next 10 seconds.
-      constexpr int kNumFramesToEnergyIncreaseAllowed = 1000;
-      noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed;
+      // TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level.
+      // For a non-stationary signal, leak the estimate downwards in order to
+      // avoid estimate locking due to incorrect signal classification.
+      noise_energy_ = noise_energy_ * 0.99f;
     }
-  } else {
-    // TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level.
-    // For a non-stationary signal, leak the estimate downwards in order to
-    // avoid estimate locking due to incorrect signal classification.
-    noise_energy_ = noise_energy_ * 0.99f;
+
+    // Ensure a minimum of the estimate.
+    noise_energy_ = std::max(noise_energy_, min_noise_energy_);
+    return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
   }
 
-  // Ensure a minimum of the estimate.
-  noise_energy_ = std::max(noise_energy_, min_noise_energy_);
-  return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
+ private:
+  void Initialize(int sample_rate_hz) {
+    sample_rate_hz_ = sample_rate_hz;
+    noise_energy_ = 1.0f;
+    first_update_ = true;
+    min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
+    noise_energy_hold_counter_ = 0;
+    signal_classifier_.Initialize(sample_rate_hz);
+  }
+
+  ApmDataDumper* const data_dumper_;
+  int sample_rate_hz_;
+  float min_noise_energy_;
+  bool first_update_;
+  float noise_energy_;
+  int noise_energy_hold_counter_;
+  SignalClassifier signal_classifier_;
+};
+
+}  // namespace
+
+std::unique_ptr<NoiseLevelEstimator> CreateNoiseLevelEstimator(
+    ApmDataDumper* data_dumper) {
+  return std::make_unique<NoiseLevelEstimatorImpl>(data_dumper);
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/noise_level_estimator.h b/modules/audio_processing/agc2/noise_level_estimator.h
index 65d4623..7e57b4c 100644
--- a/modules/audio_processing/agc2/noise_level_estimator.h
+++ b/modules/audio_processing/agc2/noise_level_estimator.h
@@ -11,33 +11,26 @@
 #ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
 #define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
 
-#include "modules/audio_processing/agc2/signal_classifier.h"
+#include <memory>
+
 #include "modules/audio_processing/include/audio_frame_view.h"
 
 namespace webrtc {
 class ApmDataDumper;
 
+// Noise level estimator interface.
 class NoiseLevelEstimator {
  public:
-  NoiseLevelEstimator(ApmDataDumper* data_dumper);
-  NoiseLevelEstimator(const NoiseLevelEstimator&) = delete;
-  NoiseLevelEstimator& operator=(const NoiseLevelEstimator&) = delete;
-  ~NoiseLevelEstimator();
-  // Returns the estimated noise level in dBFS.
-  float Analyze(const AudioFrameView<const float>& frame);
-
- private:
-  void Initialize(int sample_rate_hz);
-
-  ApmDataDumper* const data_dumper_;
-  int sample_rate_hz_;
-  float min_noise_energy_;
-  bool first_update_;
-  float noise_energy_;
-  int noise_energy_hold_counter_;
-  SignalClassifier signal_classifier_;
+  virtual ~NoiseLevelEstimator() = default;
+  // Analyzes a 10 ms `frame`, updates the noise level estimation and returns
+  // the value for the latter in dBFS.
+  virtual float Analyze(const AudioFrameView<const float>& frame) = 0;
 };
 
+// Creates a noise level estimator based on stationarity detection.
+std::unique_ptr<NoiseLevelEstimator> CreateNoiseLevelEstimator(
+    ApmDataDumper* data_dumper);
+
 }  // namespace webrtc
 
 #endif  // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
diff --git a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc
index 327fcee..ccee34a 100644
--- a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc
@@ -31,7 +31,7 @@
 float RunEstimator(rtc::FunctionView<float()> sample_generator,
                    int sample_rate_hz) {
   ApmDataDumper data_dumper(0);
-  NoiseLevelEstimator estimator(&data_dumper);
+  auto estimator = CreateNoiseLevelEstimator(&data_dumper);
   const int samples_per_channel =
       rtc::CheckedDivExact(sample_rate_hz, kFramesPerSecond);
   VectorFloatFrame signal(1, samples_per_channel, 0.0f);
@@ -41,9 +41,9 @@
     for (int j = 0; j < samples_per_channel; ++j) {
       frame_view.channel(0)[j] = sample_generator();
     }
-    estimator.Analyze(frame_view);
+    estimator->Analyze(frame_view);
   }
-  return estimator.Analyze(signal.float_frame_view());
+  return estimator->Analyze(signal.float_frame_view());
 }
 
 class NoiseEstimatorParametrization : public ::testing::TestWithParam<int> {