AGC2: VAD moved into `GainController2`

Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).

Bug: webrtc:7494
Change-Id: Id9849c4463791f5a203afe31efc163efb4d4458e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234583
Reviewed-by: Hanna Silen <silen@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35248}
diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn
index 4223101..7fbae51 100644
--- a/modules/audio_processing/BUILD.gn
+++ b/modules/audio_processing/BUILD.gn
@@ -132,8 +132,10 @@
     "../../rtc_base:stringutils",
     "../../system_wrappers:field_trial",
     "agc2:adaptive_digital",
+    "agc2:cpu_features",
     "agc2:fixed_digital",
     "agc2:gain_applier",
+    "agc2:vad_wrapper",
   ]
 }
 
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index f767a6d..5ec3ce1 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -52,7 +52,6 @@
     "../../../rtc_base:rtc_base_approved",
     "../../../rtc_base:safe_compare",
     "../../../rtc_base:safe_minmax",
-    "../../../system_wrappers:field_trial",
     "../../../system_wrappers:metrics",
   ]
 
@@ -150,7 +149,11 @@
     "vad_wrapper.cc",
     "vad_wrapper.h",
   ]
-  visibility = [ "./*" ]
+
+  visibility = [
+    "..:gain_controller2",
+    "./*",
+  ]
 
   defines = []
   if (rtc_build_with_neon && current_cpu != "arm64") {
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index b543365..295a33f 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -11,31 +11,14 @@
 #include "modules/audio_processing/agc2/adaptive_agc.h"
 
 #include "common_audio/include/audio_util.h"
-#include "modules/audio_processing/agc2/cpu_features.h"
 #include "modules/audio_processing/agc2/vad_wrapper.h"
 #include "modules/audio_processing/logging/apm_data_dumper.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
-#include "system_wrappers/include/field_trial.h"
 
 namespace webrtc {
 namespace {
 
-// Detects the available CPU features and applies any kill-switches.
-AvailableCpuFeatures GetAllowedCpuFeatures() {
-  AvailableCpuFeatures features = GetAvailableCpuFeatures();
-  if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) {
-    features.sse2 = false;
-  }
-  if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) {
-    features.avx2 = false;
-  }
-  if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) {
-    features.neon = false;
-  }
-  return features;
-}
-
 // Peak and RMS audio levels in dBFS.
 struct AudioLevels {
   float peak_dbfs;
@@ -60,7 +43,6 @@
     ApmDataDumper* apm_data_dumper,
     const AudioProcessing::Config::GainController2::AdaptiveDigital& config)
     : speech_level_estimator_(apm_data_dumper, config),
-      vad_(config.vad_reset_period_ms, GetAllowedCpuFeatures()),
       gain_controller_(apm_data_dumper, config),
       apm_data_dumper_(apm_data_dumper),
       noise_level_estimator_(CreateNoiseFloorEstimator(apm_data_dumper)),
@@ -77,18 +59,18 @@
 
 void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
   gain_controller_.Initialize(sample_rate_hz, num_channels);
-  vad_.Initialize(sample_rate_hz);
 }
 
-void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
+void AdaptiveAgc::Process(AudioFrameView<float> frame,
+                          float speech_probability,
+                          float limiter_envelope) {
   AudioLevels levels = ComputeAudioLevels(frame);
+  apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
+  apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
 
   AdaptiveDigitalGainApplier::FrameInfo info;
 
-  info.speech_probability = vad_.Analyze(frame);
-  apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
-  apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
-  apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
+  info.speech_probability = speech_probability;
 
   speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
                                  info.speech_probability);
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index 32de680..a9a6985 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -17,7 +17,6 @@
 #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
 #include "modules/audio_processing/agc2/noise_level_estimator.h"
 #include "modules/audio_processing/agc2/saturation_protector.h"
-#include "modules/audio_processing/agc2/vad_wrapper.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
 #include "modules/audio_processing/include/audio_processing.h"
 
@@ -38,16 +37,17 @@
   // TODO(crbug.com/webrtc/7494): Add `SetLimiterEnvelope()`.
 
   // Analyzes `frame` and applies a digital adaptive gain to it. Takes into
-  // account the envelope measured by the limiter.
+  // account the speech probability and the envelope measured by the limiter.
   // TODO(crbug.com/webrtc/7494): Remove `limiter_envelope`.
-  void Process(AudioFrameView<float> frame, float limiter_envelope);
+  void Process(AudioFrameView<float> frame,
+               float speech_probability,
+               float limiter_envelope);
 
   // Handles a gain change applied to the input signal (e.g., analog gain).
   void HandleInputGainChange();
 
  private:
   AdaptiveModeLevelEstimator speech_level_estimator_;
-  VoiceActivityDetectorWrapper vad_;
   AdaptiveDigitalGainApplier gain_controller_;
   ApmDataDumper* const apm_data_dumper_;
   std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc
index 17d9638..91448f8 100644
--- a/modules/audio_processing/agc2/vad_wrapper.cc
+++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -54,24 +54,25 @@
 
 VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
     int vad_reset_period_ms,
-    const AvailableCpuFeatures& cpu_features)
-    : VoiceActivityDetectorWrapper(
-          vad_reset_period_ms,
-          std::make_unique<MonoVadImpl>(cpu_features)) {}
+    const AvailableCpuFeatures& cpu_features,
+    int sample_rate_hz)
+    : VoiceActivityDetectorWrapper(vad_reset_period_ms,
+                                   std::make_unique<MonoVadImpl>(cpu_features),
+                                   sample_rate_hz) {}
 
 VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
     int vad_reset_period_ms,
-    std::unique_ptr<MonoVad> vad)
+    std::unique_ptr<MonoVad> vad,
+    int sample_rate_hz)
     : vad_reset_period_frames_(
           rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
-      initialized_(false),
-      frame_size_(0),
       time_to_vad_reset_(vad_reset_period_frames_),
       vad_(std::move(vad)) {
   RTC_DCHECK(vad_);
   RTC_DCHECK_GT(vad_reset_period_frames_, 1);
   resampled_buffer_.resize(
       rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
+  Initialize(sample_rate_hz);
 }
 
 VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
@@ -85,11 +86,9 @@
   constexpr int kStatusOk = 0;
   RTC_DCHECK_EQ(status, kStatusOk);
   vad_->Reset();
-  initialized_ = true;
 }
 
 float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
-  RTC_DCHECK(initialized_);
   // Periodically reset the VAD.
   time_to_vad_reset_--;
   if (time_to_vad_reset_ <= 0) {
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h
index 0579ca1..6df0ead 100644
--- a/modules/audio_processing/agc2/vad_wrapper.h
+++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -43,20 +43,20 @@
   // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
   // `MonoVad::Reset()`; it must be equal to or greater than the duration of two
   // frames. Uses `cpu_features` to instantiate the default VAD.
-  // TODO(bugs.webrtc.org/7494): Pass sample rate.
   VoiceActivityDetectorWrapper(int vad_reset_period_ms,
-                               const AvailableCpuFeatures& cpu_features);
+                               const AvailableCpuFeatures& cpu_features,
+                               int sample_rate_hz);
   // Ctor. Uses a custom `vad`.
   VoiceActivityDetectorWrapper(int vad_reset_period_ms,
-                               std::unique_ptr<MonoVad> vad);
+                               std::unique_ptr<MonoVad> vad,
+                               int sample_rate_hz);
 
   VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
   VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
       delete;
   ~VoiceActivityDetectorWrapper();
 
-  // TODO(bugs.webrtc.org/7494): Call initialize in the ctor.
-  // Initializes the VAD wrapper. Must be called before `Analyze()`.
+  // Initializes the VAD wrapper.
   void Initialize(int sample_rate_hz);
 
   // Analyzes the first channel of `frame` and returns the speech probability.
@@ -66,8 +66,6 @@
 
  private:
   const int vad_reset_period_frames_;
-  // TODO(bugs.webrtc.org/7494): Remove `initialized_`.
-  bool initialized_;
   int frame_size_;
   int time_to_vad_reset_;
   PushResampler<float> resampler_;
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
index 27e5af6..b61b015 100644
--- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc
+++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -31,6 +31,8 @@
 using ::testing::ReturnRoundRobin;
 using ::testing::Truly;
 
+constexpr int kNumFramesPerSecond = 100;
+
 constexpr int kNoVadPeriodicReset =
     kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
 
@@ -52,8 +54,7 @@
       .WillRepeatedly(Return(kSampleRate8kHz));
   EXPECT_CALL(*vad, Reset).Times(AnyNumber());
   auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
-      kNoVadPeriodicReset, std::move(vad));
-  vad_wrapper->Initialize(kSampleRate8kHz);
+      kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz);
 }
 
 // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
@@ -61,27 +62,29 @@
 // restarts from the beginning when after the last element is returned.
 std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
     int vad_reset_period_ms,
+    int sample_rate_hz,
     const std::vector<float>& speech_probabilities,
     int expected_vad_reset_calls) {
   auto vad = std::make_unique<MockVad>();
   EXPECT_CALL(*vad, SampleRateHz)
       .Times(AnyNumber())
-      .WillRepeatedly(Return(kSampleRate8kHz));
+      .WillRepeatedly(Return(sample_rate_hz));
   if (expected_vad_reset_calls >= 0) {
     EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
   }
   EXPECT_CALL(*vad, Analyze)
       .Times(AnyNumber())
       .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
-  return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
-                                                        std::move(vad));
+  return std::make_unique<VoiceActivityDetectorWrapper>(
+      vad_reset_period_ms, std::move(vad), kSampleRate8kHz);
 }
 
 // 10 ms mono frame.
 struct FrameWithView {
   // Ctor. Initializes the frame samples with `value`.
   explicit FrameWithView(int sample_rate_hz)
-      : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
+      : samples(rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond),
+                0.0f),
         channel0(samples.data()),
         view(&channel0, /*num_channels=*/1, samples.size()) {}
   std::vector<float> samples;
@@ -94,10 +97,9 @@
   const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
                                                 0.44f,  0.525f, 0.858f, 0.314f,
                                                 0.653f, 0.965f, 0.413f, 0.0f};
-  auto vad_wrapper =
-      CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities,
-                           /*expected_vad_reset_calls=*/1);
-  vad_wrapper->Initialize(kSampleRate8kHz);
+  auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
+                                          speech_probabilities,
+                                          /*expected_vad_reset_calls=*/1);
   FrameWithView frame(kSampleRate8kHz);
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
@@ -108,10 +110,9 @@
 // Checks that the VAD is not periodically reset.
 TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
   constexpr int kNumFrames = 19;
-  auto vad_wrapper =
-      CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
-                           /*expected_vad_reset_calls=*/1);
-  vad_wrapper->Initialize(kSampleRate8kHz);
+  auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
+                                          /*speech_probabilities=*/{1.0f},
+                                          /*expected_vad_reset_calls=*/1);
   FrameWithView frame(kSampleRate8kHz);
   for (int i = 0; i < kNumFrames; ++i) {
     vad_wrapper->Analyze(frame.view);
@@ -129,10 +130,10 @@
 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
   auto vad_wrapper = CreateMockVadWrapper(
       /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
+      kSampleRate8kHz,
       /*speech_probabilities=*/{1.0f},
       /*expected_vad_reset_calls=*/1 +
           num_frames() / vad_reset_period_frames());
-  vad_wrapper->Initialize(kSampleRate8kHz);
   FrameWithView frame(kSampleRate8kHz);
   for (int i = 0; i < num_frames(); ++i) {
     vad_wrapper->Analyze(frame.view);
@@ -161,13 +162,12 @@
       .WillRepeatedly(Return(vad_sample_rate_hz()));
   EXPECT_CALL(*vad, Reset).Times(1);
   EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
-    return rtc::SafeEq(frame.size(),
-                       rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
+    return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(),
+                                                          kNumFramesPerSecond));
   }))).Times(1);
   auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
-      kNoVadPeriodicReset, std::move(vad));
+      kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz());
   FrameWithView frame(input_sample_rate_hz());
-  vad_wrapper->Initialize(input_sample_rate_hz());
   vad_wrapper->Analyze(frame.view);
 }
 
diff --git a/modules/audio_processing/gain_controller2.cc b/modules/audio_processing/gain_controller2.cc
index 416bc78..a21ef72 100644
--- a/modules/audio_processing/gain_controller2.cc
+++ b/modules/audio_processing/gain_controller2.cc
@@ -14,6 +14,7 @@
 #include <utility>
 
 #include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
 #include "modules/audio_processing/audio_buffer.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
 #include "modules/audio_processing/logging/apm_data_dumper.h"
@@ -21,6 +22,7 @@
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
 #include "rtc_base/strings/string_builder.h"
+#include "system_wrappers/include/field_trial.h"
 
 namespace webrtc {
 namespace {
@@ -33,6 +35,21 @@
 constexpr int kLogLimiterStatsPeriodNumFrames =
     kLogLimiterStatsPeriodMs / kFrameLengthMs;
 
+// Detects the available CPU features and applies any kill-switches.
+AvailableCpuFeatures GetAllowedCpuFeatures() {
+  AvailableCpuFeatures features = GetAvailableCpuFeatures();
+  if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) {
+    features.sse2 = false;
+  }
+  if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) {
+    features.avx2 = false;
+  }
+  if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) {
+    features.neon = false;
+  }
+  return features;
+}
+
 // Creates an adaptive digital gain controller if enabled.
 std::unique_ptr<AdaptiveAgc> CreateAdaptiveDigitalController(
     const Agc2Config::AdaptiveDigital& config,
@@ -40,7 +57,8 @@
     int num_channels,
     ApmDataDumper* data_dumper) {
   if (config.enabled) {
-    // TODO(bugs.webrtc.org/7494): Also init with sample rate and num channels.
+    // TODO(bugs.webrtc.org/7494): Also init with sample rate and num
+    // channels.
     auto controller = std::make_unique<AdaptiveAgc>(data_dumper, config);
     // TODO(bugs.webrtc.org/7494): Remove once passed to the ctor.
     controller->Initialize(sample_rate_hz, num_channels);
@@ -56,7 +74,8 @@
 GainController2::GainController2(const Agc2Config& config,
                                  int sample_rate_hz,
                                  int num_channels)
-    : data_dumper_(rtc::AtomicOps::Increment(&instance_count_)),
+    : cpu_features_(GetAllowedCpuFeatures()),
+      data_dumper_(rtc::AtomicOps::Increment(&instance_count_)),
       fixed_gain_applier_(/*hard_clip_samples=*/false,
                           /*initial_gain_factor=*/0.0f),
       adaptive_digital_controller_(
@@ -71,6 +90,14 @@
   data_dumper_.InitiateNewSetOfRecordings();
   // TODO(bugs.webrtc.org/7494): Set gain when `fixed_gain_applier_` is init'd.
   fixed_gain_applier_.SetGainFactor(DbToRatio(config.fixed_digital.gain_db));
+  const bool use_vad = config.adaptive_digital.enabled;
+  if (use_vad) {
+    // TODO(bugs.webrtc.org/7494): Move `vad_reset_period_ms` from adaptive
+    // digital to gain controller 2 config.
+    vad_ = std::make_unique<VoiceActivityDetectorWrapper>(
+        config.adaptive_digital.vad_reset_period_ms, cpu_features_,
+        sample_rate_hz);
+  }
 }
 
 GainController2::~GainController2() = default;
@@ -82,6 +109,9 @@
              sample_rate_hz == AudioProcessing::kSampleRate48kHz);
   // TODO(bugs.webrtc.org/7494): Initialize `fixed_gain_applier_`.
   limiter_.SetSampleRate(sample_rate_hz);
+  if (vad_) {
+    vad_->Initialize(sample_rate_hz);
+  }
   if (adaptive_digital_controller_) {
     adaptive_digital_controller_->Initialize(sample_rate_hz, num_channels);
   }
@@ -104,10 +134,17 @@
   data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_);
   AudioFrameView<float> float_frame(audio->channels(), audio->num_channels(),
                                     audio->num_frames());
+  absl::optional<float> speech_probability;
+  // TODO(bugs.webrtc.org/7494): Apply fixed digital gain after VAD.
   fixed_gain_applier_.ApplyGain(float_frame);
+  if (vad_) {
+    speech_probability = vad_->Analyze(float_frame);
+    data_dumper_.DumpRaw("agc2_speech_probability", speech_probability.value());
+  }
   if (adaptive_digital_controller_) {
-    adaptive_digital_controller_->Process(float_frame,
-                                          limiter_.LastAudioLevel());
+    RTC_DCHECK(speech_probability.has_value());
+    adaptive_digital_controller_->Process(
+        float_frame, speech_probability.value(), limiter_.LastAudioLevel());
   }
   limiter_.Process(float_frame);
 
diff --git a/modules/audio_processing/gain_controller2.h b/modules/audio_processing/gain_controller2.h
index cfe51c2..7e28ee2 100644
--- a/modules/audio_processing/gain_controller2.h
+++ b/modules/audio_processing/gain_controller2.h
@@ -15,8 +15,10 @@
 #include <string>
 
 #include "modules/audio_processing/agc2/adaptive_agc.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
 #include "modules/audio_processing/agc2/gain_applier.h"
 #include "modules/audio_processing/agc2/limiter.h"
+#include "modules/audio_processing/agc2/vad_wrapper.h"
 #include "modules/audio_processing/include/audio_processing.h"
 #include "modules/audio_processing/logging/apm_data_dumper.h"
 
@@ -51,8 +53,10 @@
 
  private:
   static int instance_count_;
+  const AvailableCpuFeatures cpu_features_;
   ApmDataDumper data_dumper_;
   GainApplier fixed_gain_applier_;
+  std::unique_ptr<VoiceActivityDetectorWrapper> vad_;
   std::unique_ptr<AdaptiveAgc> adaptive_digital_controller_;
   Limiter limiter_;
   int calls_since_last_limiter_log_;