AGC2 periodically reset VAD state

Bug: webrtc:7494
Change-Id: I880ef3991ade4e429ccde843571f069ede149c0e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/213342
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Jesus de Vicente Pena <devicentepena@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33604}
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index e72942a..9657aec 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -30,8 +30,8 @@
 }
 
 constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1;
-constexpr float kMaxGainChangePerSecondDb = 3.f;
-constexpr float kMaxOutputNoiseLevelDbfs = -50.f;
+constexpr float kMaxGainChangePerSecondDb = 3.0f;
+constexpr float kMaxOutputNoiseLevelDbfs = -50.0f;
 
 // Detects the available CPU features and applies any kill-switches.
 AvailableCpuFeatures GetAllowedCpuFeatures(
@@ -71,7 +71,8 @@
               .level_estimator_adjacent_speech_frames_threshold,
           config.adaptive_digital.initial_saturation_margin_db,
           config.adaptive_digital.extra_saturation_margin_db),
-      vad_(config.adaptive_digital.vad_probability_attack,
+      vad_(config.adaptive_digital.vad_reset_period_ms,
+           config.adaptive_digital.vad_probability_attack,
            GetAllowedCpuFeatures(config.adaptive_digital)),
       gain_applier_(
           apm_data_dumper,
@@ -95,7 +96,7 @@
   info.input_level_dbfs = speech_level_estimator_.level_dbfs();
   info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
   info.limiter_envelope_dbfs =
-      limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.f;
+      limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.0f;
   info.estimate_is_confident = speech_level_estimator_.IsConfident();
   DumpDebugData(info, *apm_data_dumper_);
   gain_applier_.Process(info, frame);
diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h
index 594a37e..d0df43f 100644
--- a/modules/audio_processing/agc2/agc2_common.h
+++ b/modules/audio_processing/agc2/agc2_common.h
@@ -15,20 +15,20 @@
 
 namespace webrtc {
 
-constexpr float kMinFloatS16Value = -32768.f;
-constexpr float kMaxFloatS16Value = 32767.f;
+constexpr float kMinFloatS16Value = -32768.0f;
+constexpr float kMaxFloatS16Value = 32767.0f;
 constexpr float kMaxAbsFloatS16Value = 32768.0f;
 
 constexpr int kFrameDurationMs = 10;
 constexpr int kSubFramesInFrame = 20;
 constexpr int kMaximalNumberOfSamplesPerChannel = 480;
 
-constexpr float kAttackFilterConstant = 0.f;
+constexpr float kAttackFilterConstant = 0.0f;
 
 // Adaptive digital gain applier settings below.
-constexpr float kHeadroomDbfs = 1.f;
-constexpr float kMaxGainDb = 30.f;
-constexpr float kInitialAdaptiveDigitalGainDb = 8.f;
+constexpr float kHeadroomDbfs = 1.0f;
+constexpr float kMaxGainDb = 30.0f;
+constexpr float kInitialAdaptiveDigitalGainDb = 8.0f;
 // At what limiter levels should we start decreasing the adaptive digital gain.
 constexpr float kLimiterThresholdForAgcGainDbfs = -kHeadroomDbfs;
 
@@ -39,17 +39,18 @@
 
 // The amount of 'memory' of the Level Estimator. Decides leak factors.
 constexpr int kFullBufferSizeMs = 1200;
-constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
+constexpr float kFullBufferLeakFactor = 1.0f - 1.0f / kFullBufferSizeMs;
 
-constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
+constexpr float kInitialSpeechLevelEstimateDbfs = -30.0f;
 
 // Robust VAD probability and speech decisions.
-constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
+constexpr int kDefaultVadRnnResetPeriodMs = 1500;
+constexpr float kDefaultSmoothedVadProbabilityAttack = 1.0f;
 constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 1;
 
 // Saturation Protector settings.
-constexpr float kDefaultInitialSaturationMarginDb = 20.f;
-constexpr float kDefaultExtraSaturationMarginDb = 2.f;
+constexpr float kDefaultInitialSaturationMarginDb = 20.0f;
+constexpr float kDefaultExtraSaturationMarginDb = 2.0f;
 
 constexpr int kPeakEnveloperSuperFrameLengthMs = 400;
 static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc
index b54ae56..597c09c 100644
--- a/modules/audio_processing/agc2/vad_with_level.cc
+++ b/modules/audio_processing/agc2/vad_with_level.cc
@@ -38,6 +38,8 @@
   Vad& operator=(const Vad&) = delete;
   ~Vad() = default;
 
+  void Reset() override { rnn_vad_.Reset(); }
+
   float ComputeProbability(AudioFrameView<const float> frame) override {
     // The source number of channels is 1, because we always use the 1st
     // channel.
@@ -66,41 +68,57 @@
 // Returns an updated version of `p_old` by using instant decay and the given
 // `attack` on a new VAD probability value `p_new`.
 float SmoothedVadProbability(float p_old, float p_new, float attack) {
-  RTC_DCHECK_GT(attack, 0.f);
-  RTC_DCHECK_LE(attack, 1.f);
-  if (p_new < p_old || attack == 1.f) {
+  RTC_DCHECK_GT(attack, 0.0f);
+  RTC_DCHECK_LE(attack, 1.0f);
+  if (p_new < p_old || attack == 1.0f) {
     // Instant decay (or no smoothing).
     return p_new;
   } else {
     // Attack phase.
-    return attack * p_new + (1.f - attack) * p_old;
+    return attack * p_new + (1.0f - attack) * p_old;
   }
 }
 
 }  // namespace
 
 VadLevelAnalyzer::VadLevelAnalyzer()
-    : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
+    : VadLevelAnalyzer(kDefaultVadRnnResetPeriodMs,
+                       kDefaultSmoothedVadProbabilityAttack,
                        GetAvailableCpuFeatures()) {}
 
-VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
+VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
+                                   float vad_probability_attack,
                                    const AvailableCpuFeatures& cpu_features)
-    : VadLevelAnalyzer(vad_probability_attack,
+    : VadLevelAnalyzer(vad_reset_period_ms,
+                       vad_probability_attack,
                        std::make_unique<Vad>(cpu_features)) {}
 
-VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
+VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
+                                   float vad_probability_attack,
                                    std::unique_ptr<VoiceActivityDetector> vad)
-    : vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) {
+    : vad_(std::move(vad)),
+      vad_reset_period_frames_(
+          rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
+      vad_probability_attack_(vad_probability_attack),
+      time_to_vad_reset_(vad_reset_period_frames_),
+      vad_probability_(0.0f) {
   RTC_DCHECK(vad_);
+  RTC_DCHECK_GT(vad_reset_period_frames_, 1);
 }
 
 VadLevelAnalyzer::~VadLevelAnalyzer() = default;
 
 VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
     AudioFrameView<const float> frame) {
+  // Periodically reset the VAD.
+  time_to_vad_reset_--;
+  if (time_to_vad_reset_ <= 0) {
+    vad_->Reset();
+    time_to_vad_reset_ = vad_reset_period_frames_;
+  }
   // Compute levels.
-  float peak = 0.f;
-  float rms = 0.f;
+  float peak = 0.0f;
+  float rms = 0.0f;
   for (const auto& x : frame.channel(0)) {
     peak = std::max(std::fabs(x), peak);
     rms += x * x;
diff --git a/modules/audio_processing/agc2/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h
index 2a67882..386f162 100644
--- a/modules/audio_processing/agc2/vad_with_level.h
+++ b/modules/audio_processing/agc2/vad_with_level.h
@@ -31,17 +31,26 @@
   class VoiceActivityDetector {
    public:
     virtual ~VoiceActivityDetector() = default;
+    // Resets the internal state.
+    virtual void Reset() = 0;
     // Analyzes an audio frame and returns the speech probability.
     virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
   };
 
   // Ctor. Uses the default VAD.
   VadLevelAnalyzer();
-  VadLevelAnalyzer(float vad_probability_attack,
+  // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
+  // `VoiceActivityDetector::Reset()`; it must be equal to or greater than the
+  // duration of two frames. `vad_probability_attack` is a number in (0,1] used
+  // to smooth the speech probability (instant decay, slow attack).
+  VadLevelAnalyzer(int vad_reset_period_ms,
+                   float vad_probability_attack,
                    const AvailableCpuFeatures& cpu_features);
   // Ctor. Uses a custom `vad`.
-  VadLevelAnalyzer(float vad_probability_attack,
+  VadLevelAnalyzer(int vad_reset_period_ms,
+                   float vad_probability_attack,
                    std::unique_ptr<VoiceActivityDetector> vad);
+
   VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
   VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
   ~VadLevelAnalyzer();
@@ -51,8 +60,10 @@
 
  private:
   std::unique_ptr<VoiceActivityDetector> vad_;
+  const int vad_reset_period_frames_;
   const float vad_probability_attack_;
-  float vad_probability_ = 0.f;
+  int time_to_vad_reset_;
+  float vad_probability_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_with_level_unittest.cc b/modules/audio_processing/agc2/vad_with_level_unittest.cc
index fb93c86..fd8265e 100644
--- a/modules/audio_processing/agc2/vad_with_level_unittest.cc
+++ b/modules/audio_processing/agc2/vad_with_level_unittest.cc
@@ -10,6 +10,7 @@
 
 #include "modules/audio_processing/agc2/vad_with_level.h"
 
+#include <limits>
 #include <memory>
 #include <vector>
 
@@ -25,13 +26,17 @@
 using ::testing::AnyNumber;
 using ::testing::ReturnRoundRobin;
 
-constexpr float kInstantAttack = 1.f;
+constexpr int kNoVadPeriodicReset =
+    kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
+
+constexpr float kInstantAttack = 1.0f;
 constexpr float kSlowAttack = 0.1f;
 
 constexpr int kSampleRateHz = 8000;
 
 class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
  public:
+  MOCK_METHOD(void, Reset, (), (override));
   MOCK_METHOD(float,
               ComputeProbability,
               (AudioFrameView<const float> frame),
@@ -42,20 +47,25 @@
 // the next value from `speech_probabilities` until it reaches the end and will
 // restart from the beginning.
 std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
+    int vad_reset_period_ms,
     float vad_probability_attack,
-    const std::vector<float>& speech_probabilities) {
+    const std::vector<float>& speech_probabilities,
+    int expected_vad_reset_calls = 0) {
   auto vad = std::make_unique<MockVad>();
   EXPECT_CALL(*vad, ComputeProbability)
       .Times(AnyNumber())
       .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
-  return std::make_unique<VadLevelAnalyzer>(vad_probability_attack,
-                                            std::move(vad));
+  if (expected_vad_reset_calls >= 0) {
+    EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
+  }
+  return std::make_unique<VadLevelAnalyzer>(
+      vad_reset_period_ms, vad_probability_attack, std::move(vad));
 }
 
 // 10 ms mono frame.
 struct FrameWithView {
   // Ctor. Initializes the frame samples with `value`.
-  FrameWithView(float value = 0.f)
+  FrameWithView(float value = 0.0f)
       : channel0(samples.data()),
         view(&channel0, /*num_channels=*/1, samples.size()) {
     samples.fill(value);
@@ -67,8 +77,8 @@
 
 TEST(AutomaticGainController2VadLevelAnalyzer, PeakLevelGreaterThanRmsLevel) {
   // Handcrafted frame so that the average is lower than the peak value.
-  FrameWithView frame(1000.f);  // Constant frame.
-  frame.samples[10] = 2000.f;   // Except for one peak value.
+  FrameWithView frame(1000.0f);  // Constant frame.
+  frame.samples[10] = 2000.0f;   // Except for one peak value.
 
   // Compute audio frame levels (the VAD result is ignored).
   VadLevelAnalyzer analyzer;
@@ -83,9 +93,9 @@
 TEST(AutomaticGainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
   const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
                                                 0.44f,  0.525f, 0.858f, 0.314f,
-                                                0.653f, 0.965f, 0.413f, 0.f};
-  auto analyzer =
-      CreateVadLevelAnalyzerWithMockVad(kInstantAttack, speech_probabilities);
+                                                0.653f, 0.965f, 0.413f, 0.0f};
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+      kNoVadPeriodicReset, kInstantAttack, speech_probabilities);
   FrameWithView frame;
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
@@ -98,16 +108,17 @@
 // the unprocessed one when slow attack is used.
 TEST(AutomaticGainController2VadLevelAnalyzer,
      SlowAttackSpeechProbabilitySmoothing) {
-  const std::vector<float> speech_probabilities{0.f, 0.f, 1.f, 1.f, 1.f, 1.f};
-  auto analyzer =
-      CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
+  const std::vector<float> speech_probabilities{0.0f, 0.0f, 1.0f,
+                                                1.0f, 1.0f, 1.0f};
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+      kNoVadPeriodicReset, kSlowAttack, speech_probabilities);
   FrameWithView frame;
-  float prev_probability = 0.f;
+  float prev_probability = 0.0f;
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
     const float smoothed_probability =
         analyzer->AnalyzeFrame(frame.view).speech_probability;
-    EXPECT_LT(smoothed_probability, 1.f);  // Not enough time to reach 1.
+    EXPECT_LT(smoothed_probability, 1.0f);  // Not enough time to reach 1.
     EXPECT_LE(prev_probability, smoothed_probability);  // Converge towards 1.
     prev_probability = smoothed_probability;
   }
@@ -116,15 +127,52 @@
 // Checks that the smoothed speech probability instantly decays to the
 // unprocessed one when slow attack is used.
 TEST(AutomaticGainController2VadLevelAnalyzer, SpeechProbabilityInstantDecay) {
-  const std::vector<float> speech_probabilities{1.f, 1.f, 1.f, 1.f, 1.f, 0.f};
-  auto analyzer =
-      CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
+  const std::vector<float> speech_probabilities{1.0f, 1.0f, 1.0f,
+                                                1.0f, 1.0f, 0.0f};
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+      kNoVadPeriodicReset, kSlowAttack, speech_probabilities);
   FrameWithView frame;
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size() - 1); ++i) {
     analyzer->AnalyzeFrame(frame.view);
   }
-  EXPECT_EQ(0.f, analyzer->AnalyzeFrame(frame.view).speech_probability);
+  EXPECT_EQ(0.0f, analyzer->AnalyzeFrame(frame.view).speech_probability);
 }
 
+// Checks that the VAD is not periodically reset.
+TEST(AutomaticGainController2VadLevelAnalyzer, VadNoPeriodicReset) {
+  constexpr int kNumFrames = 19;
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+      kNoVadPeriodicReset, kSlowAttack, /*speech_probabilities=*/{1.0f},
+      /*expected_vad_reset_calls=*/0);
+  FrameWithView frame;
+  for (int i = 0; i < kNumFrames; ++i) {
+    analyzer->AnalyzeFrame(frame.view);
+  }
+}
+
+class VadPeriodResetParametrization
+    : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+  int num_frames() const { return std::get<0>(GetParam()); }
+  int vad_reset_period_frames() const { return std::get<1>(GetParam()); }
+};
+
+// Checks that the VAD is periodically reset with the expected period.
+TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
+  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+      /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
+      kSlowAttack, /*speech_probabilities=*/{1.0f},
+      /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
+  FrameWithView frame;
+  for (int i = 0; i < num_frames(); ++i) {
+    analyzer->AnalyzeFrame(frame.view);
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(AutomaticGainController2VadLevelAnalyzer,
+                         VadPeriodResetParametrization,
+                         ::testing::Combine(::testing::Values(1, 19, 123),
+                                            ::testing::Values(2, 5, 20, 53)));
+
 }  // namespace
 }  // namespace webrtc
diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h
index bb24a48..a5c266a 100644
--- a/modules/audio_processing/include/audio_processing.h
+++ b/modules/audio_processing/include/audio_processing.h
@@ -210,7 +210,7 @@
     // capture_level_adjustment instead.
     struct PreAmplifier {
       bool enabled = false;
-      float fixed_gain_factor = 1.f;
+      float fixed_gain_factor = 1.0f;
     } pre_amplifier;
 
     // Functionality for general level adjustment in the capture pipeline. This
@@ -222,9 +222,9 @@
       }
       bool enabled = false;
       // The `pre_gain_factor` scales the signal before any processing is done.
-      float pre_gain_factor = 1.f;
+      float pre_gain_factor = 1.0f;
       // The `post_gain_factor` scales the signal after all processing is done.
-      float post_gain_factor = 1.f;
+      float post_gain_factor = 1.0f;
       struct AnalogMicGainEmulation {
         bool operator==(const AnalogMicGainEmulation& rhs) const;
         bool operator!=(const AnalogMicGainEmulation& rhs) const {
@@ -352,20 +352,21 @@
       enum LevelEstimator { kRms, kPeak };
       bool enabled = false;
       struct FixedDigital {
-        float gain_db = 0.f;
+        float gain_db = 0.0f;
       } fixed_digital;
       struct AdaptiveDigital {
         bool enabled = false;
+        int vad_reset_period_ms = 1500;
         float vad_probability_attack = 0.3f;
         LevelEstimator level_estimator = kRms;
         int level_estimator_adjacent_speech_frames_threshold = 6;
         // TODO(crbug.com/webrtc/7494): Remove `use_saturation_protector`.
         bool use_saturation_protector = true;
-        float initial_saturation_margin_db = 20.f;
-        float extra_saturation_margin_db = 5.f;
+        float initial_saturation_margin_db = 20.0f;
+        float extra_saturation_margin_db = 5.0f;
         int gain_applier_adjacent_speech_frames_threshold = 6;
-        float max_gain_change_db_per_second = 3.f;
-        float max_output_noise_level_dbfs = -55.f;
+        float max_gain_change_db_per_second = 3.0f;
+        float max_output_noise_level_dbfs = -55.0f;
         bool sse2_allowed = true;
         bool avx2_allowed = true;
         bool neon_allowed = true;
@@ -417,7 +418,7 @@
       int max_volume;  // Maximum play-out volume.
     };
 
-    RuntimeSetting() : type_(Type::kNotSpecified), value_(0.f) {}
+    RuntimeSetting() : type_(Type::kNotSpecified), value_(0.0f) {}
     ~RuntimeSetting() = default;
 
     static RuntimeSetting CreateCapturePreGain(float gain) {
@@ -439,8 +440,8 @@
     // Corresponds to Config::GainController2::fixed_digital::gain_db, but for
     // runtime configuration.
     static RuntimeSetting CreateCaptureFixedPostGain(float gain_db) {
-      RTC_DCHECK_GE(gain_db, 0.f);
-      RTC_DCHECK_LE(gain_db, 90.f);
+      RTC_DCHECK_GE(gain_db, 0.0f);
+      RTC_DCHECK_LE(gain_db, 90.0f);
       return {Type::kCaptureFixedPostGain, gain_db};
     }