AudioProcessingImpl: Add a VAD submodule

Add a VoiceActivityDetectorWrapper submodule in AudioProcessingImpl
and enable injecting speech probability into GainController2.

Bug: webrtc:13663
Change-Id: I05e13b737d085b45ac8ce76660191867c56834c2
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/265166
Commit-Queue: Hanna Silen <silen@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37275}
diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn
index 649ed21..0bd4331 100644
--- a/modules/audio_processing/BUILD.gn
+++ b/modules/audio_processing/BUILD.gn
@@ -401,6 +401,7 @@
         "../../rtc_base/system:file_wrapper",
         "../../system_wrappers",
         "../../system_wrappers:denormal_disabler",
+        "../../test:field_trial",
         "../../test:fileutils",
         "../../test:rtc_expect_death",
         "../../test:test_support",
diff --git a/modules/audio_processing/audio_processing_impl.cc b/modules/audio_processing/audio_processing_impl.cc
index fa5e929..5714d6b 100644
--- a/modules/audio_processing/audio_processing_impl.cc
+++ b/modules/audio_processing/audio_processing_impl.cc
@@ -162,6 +162,7 @@
     bool noise_suppressor_enabled,
     bool adaptive_gain_controller_enabled,
     bool gain_controller2_enabled,
+    bool voice_activity_detector_enabled,
     bool gain_adjustment_enabled,
     bool echo_controller_enabled,
     bool transient_suppressor_enabled) {
@@ -173,6 +174,8 @@
   changed |=
       (adaptive_gain_controller_enabled != adaptive_gain_controller_enabled_);
   changed |= (gain_controller2_enabled != gain_controller2_enabled_);
+  changed |=
+      (voice_activity_detector_enabled != voice_activity_detector_enabled_);
   changed |= (gain_adjustment_enabled != gain_adjustment_enabled_);
   changed |= (echo_controller_enabled != echo_controller_enabled_);
   changed |= (transient_suppressor_enabled != transient_suppressor_enabled_);
@@ -182,6 +185,7 @@
     noise_suppressor_enabled_ = noise_suppressor_enabled;
     adaptive_gain_controller_enabled_ = adaptive_gain_controller_enabled;
     gain_controller2_enabled_ = gain_controller2_enabled;
+    voice_activity_detector_enabled_ = voice_activity_detector_enabled;
     gain_adjustment_enabled_ = gain_adjustment_enabled;
     echo_controller_enabled_ = echo_controller_enabled;
     transient_suppressor_enabled_ = transient_suppressor_enabled;
@@ -395,6 +399,7 @@
   InitializeResidualEchoDetector();
   InitializeEchoController();
   InitializeGainController2(/*config_has_changed=*/true);
+  InitializeVoiceActivityDetector(/*config_has_changed=*/true);
   InitializeNoiseSuppressor();
   InitializeAnalyzer();
   InitializePostProcessor();
@@ -569,6 +574,7 @@
   }
 
   InitializeGainController2(agc2_config_changed);
+  InitializeVoiceActivityDetector(agc2_config_changed);
 
   if (pre_amplifier_config_changed || gain_adjustment_config_changed) {
     InitializeCaptureLevelsAdjuster();
@@ -1297,10 +1303,19 @@
       submodules_.capture_analyzer->Analyze(capture_buffer);
     }
 
+    absl::optional<float> voice_activity_probability = absl::nullopt;
     if (submodules_.gain_controller2) {
       submodules_.gain_controller2->NotifyAnalogLevel(
           recommended_stream_analog_level_locked());
-      submodules_.gain_controller2->Process(capture_buffer);
+      if (submodules_.voice_activity_detector) {
+        voice_activity_probability =
+            submodules_.voice_activity_detector->Analyze(
+                AudioFrameView<const float>(capture_buffer->channels(),
+                                            capture_buffer->num_channels(),
+                                            capture_buffer->num_frames()));
+      }
+      submodules_.gain_controller2->Process(voice_activity_probability,
+                                            capture_buffer);
     }
 
     if (submodules_.capture_post_processor) {
@@ -1692,7 +1707,7 @@
   return submodule_states_.Update(
       config_.high_pass_filter.enabled, !!submodules_.echo_control_mobile,
       !!submodules_.noise_suppressor, !!submodules_.gain_control,
-      !!submodules_.gain_controller2,
+      !!submodules_.gain_controller2, !!submodules_.voice_activity_detector,
       config_.pre_amplifier.enabled || config_.capture_level_adjustment.enabled,
       capture_nonlocked_.echo_controller_enabled,
       !!submodules_.transient_suppressor);
@@ -1900,9 +1915,35 @@
     return;
   }
   if (!submodules_.gain_controller2 || config_has_changed) {
+    const bool use_internal_vad =
+        transient_suppressor_vad_mode_ != TransientSuppressor::VadMode::kRnnVad;
     submodules_.gain_controller2 = std::make_unique<GainController2>(
         config_.gain_controller2, proc_fullband_sample_rate_hz(),
-        num_input_channels());
+        num_input_channels(), use_internal_vad);
+  }
+}
+
+// (Re)creates or removes the VAD sub-module; no-op if the config is unchanged.
+void AudioProcessingImpl::InitializeVoiceActivityDetector(
+    bool config_has_changed) {
+  if (!config_has_changed) {
+    return;
+  }
+  const bool use_vad =
+      transient_suppressor_vad_mode_ == TransientSuppressor::VadMode::kRnnVad &&
+      config_.gain_controller2.enabled &&
+      config_.gain_controller2.adaptive_digital.enabled;
+  if (use_vad) {
+    // `config_has_changed` is true here, so the sub-module is always rebuilt.
+    RTC_DCHECK(!!submodules_.gain_controller2);
+    // TODO(bugs.webrtc.org/13663): Cache CPU features in APM and use here.
+    submodules_.voice_activity_detector =
+        std::make_unique<VoiceActivityDetectorWrapper>(
+            config_.gain_controller2.adaptive_digital.vad_reset_period_ms,
+            submodules_.gain_controller2->GetCpuFeatures(),
+            proc_fullband_sample_rate_hz());
+  } else {
+    submodules_.voice_activity_detector.reset();
   }
 }
 
diff --git a/modules/audio_processing/audio_processing_impl.h b/modules/audio_processing/audio_processing_impl.h
index 0fc6c9f..974089b 100644
--- a/modules/audio_processing/audio_processing_impl.h
+++ b/modules/audio_processing/audio_processing_impl.h
@@ -207,6 +207,7 @@
                 bool noise_suppressor_enabled,
                 bool adaptive_gain_controller_enabled,
                 bool gain_controller2_enabled,
+                bool voice_activity_detector_enabled,
                 bool gain_adjustment_enabled,
                 bool echo_controller_enabled,
                 bool transient_suppressor_enabled);
@@ -228,6 +229,7 @@
     bool mobile_echo_controller_enabled_ = false;
     bool noise_suppressor_enabled_ = false;
     bool adaptive_gain_controller_enabled_ = false;
+    bool voice_activity_detector_enabled_ = false;
     bool gain_controller2_enabled_ = false;
     bool gain_adjustment_enabled_ = false;
     bool echo_controller_enabled_ = false;
@@ -273,6 +275,11 @@
   // and `config_has_changed` is true, recreates the sub-module.
   void InitializeGainController2(bool config_has_changed)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
+  // Initializes the `VoiceActivityDetectorWrapper` sub-module. If the
+  // sub-module is enabled and `config_has_changed` is true, recreates the
+  // sub-module.
+  void InitializeVoiceActivityDetector(bool config_has_changed)
+      RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
   void InitializeNoiseSuppressor() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
   void InitializeCaptureLevelsAdjuster()
       RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_capture_);
@@ -393,6 +400,7 @@
     std::unique_ptr<AgcManagerDirect> agc_manager;
     std::unique_ptr<GainControlImpl> gain_control;
     std::unique_ptr<GainController2> gain_controller2;
+    std::unique_ptr<VoiceActivityDetectorWrapper> voice_activity_detector;
     std::unique_ptr<HighPassFilter> high_pass_filter;
     std::unique_ptr<EchoControl> echo_controller;
     std::unique_ptr<EchoControlMobileImpl> echo_control_mobile;
diff --git a/modules/audio_processing/audio_processing_impl_unittest.cc b/modules/audio_processing/audio_processing_impl_unittest.cc
index 7d617bf..5e4e355 100644
--- a/modules/audio_processing/audio_processing_impl_unittest.cc
+++ b/modules/audio_processing/audio_processing_impl_unittest.cc
@@ -23,6 +23,7 @@
 #include "modules/audio_processing/test/test_utils.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/random.h"
+#include "test/field_trial.h"
 #include "test/gmock.h"
 #include "test/gtest.h"
 
@@ -481,6 +482,78 @@
   apm->ProcessStream(frame.data(), stream_config, stream_config, frame.data());
 }
 
+TEST(AudioProcessingImplTest,
+     EchoControllerObservesNoDigitalAgc2EchoPathGainChange) {
+  // With AGC1 disabled and AGC2 adaptive digital enabled, the digital gain may
+  // change while the analog gain does not; the echo controller must then see
+  // no echo path gain change on any processed frame.
+  auto factory = std::make_unique<MockEchoControlFactory>();
+  const auto* const factory_ptr = factory.get();
+  rtc::scoped_refptr<AudioProcessing> apm =
+      AudioProcessingBuilderForTesting()
+          .SetEchoControlFactory(std::move(factory))
+          .Create();
+
+  // AGC1 (analog) off, AGC2 adaptive digital on.
+  webrtc::AudioProcessing::Config config;
+  config.gain_controller1.enabled = false;
+  config.gain_controller2.enabled = true;
+  config.gain_controller2.adaptive_digital.enabled = true;
+  apm->ApplyConfig(config);
+
+  // One 10 ms stereo frame at 48 kHz filled with a constant sample value.
+  constexpr int16_t kAudioLevel = 1000;
+  constexpr size_t kSampleRateHz = 48000;
+  constexpr size_t kNumChannels = 2;
+  constexpr int kNumFramesToProcess = 2;
+  const StreamConfig stream_config(kSampleRateHz, kNumChannels);
+  std::array<int16_t, kNumChannels * kSampleRateHz / 100> samples;
+  samples.fill(kAudioLevel);
+
+  MockEchoControl* const echo_control = factory_ptr->GetNext();
+  for (int i = 0; i < kNumFramesToProcess; ++i) {
+    // Fresh expectations for each frame: exactly one analysis call and one
+    // processing call, the latter reporting no echo path change.
+    EXPECT_CALL(*echo_control, AnalyzeCapture(testing::_)).Times(1);
+    EXPECT_CALL(*echo_control, ProcessCapture(NotNull(), testing::_,
+                                              /*echo_path_change=*/false))
+        .Times(1);
+    apm->ProcessStream(samples.data(), stream_config, stream_config,
+                       samples.data());
+  }
+}
+
+TEST(AudioProcessingImplTest, ProcessWithAgc2InjectedSpeechProbability) {
+  // Smoke test: with the field trial
+  // `WebRTC-Audio-TransientSuppressorVadMode/Enabled-RnnVad/` active and AGC2
+  // adaptive digital enabled, every capture frame is processed without error.
+  webrtc::test::ScopedFieldTrials field_trials(
+      "WebRTC-Audio-TransientSuppressorVadMode/Enabled-RnnVad/");
+  rtc::scoped_refptr<AudioProcessing> apm = AudioProcessingBuilder().Create();
+  ASSERT_EQ(apm->Initialize(), AudioProcessing::kNoError);
+
+  // AGC1 (analog) off, AGC2 adaptive digital on.
+  webrtc::AudioProcessing::Config config;
+  config.gain_controller1.enabled = false;
+  config.gain_controller2.enabled = true;
+  config.gain_controller2.adaptive_digital.enabled = true;
+  apm->ApplyConfig(config);
+  // One mono 10 ms frame at 48 kHz, refilled with random samples per call.
+  constexpr int kSampleRateHz = 48000;
+  constexpr int kNumChannels = 1;
+  constexpr int kFramesToProcess = 10;
+  const StreamConfig stream_config(/*sample_rate_hz=*/kSampleRateHz,
+                                   /*num_channels=*/kNumChannels);
+  std::array<float, kSampleRateHz / 100> samples;
+  float* channels[] = {samples.data()};
+  Random rng(2341U);
+  for (int frame = 0; frame < kFramesToProcess; ++frame) {
+    RandomizeSampleVector(&rng, samples);
+    ASSERT_EQ(apm->ProcessStream(channels, stream_config, stream_config,
+                                 channels),
+              AudioProcessing::kNoError);
+  }
+}
+
 TEST(AudioProcessingImplTest, EchoControllerObservesPlayoutVolumeChange) {
   // Tests that the echo controller observes an echo path gain change when a
   // playout volume change is reported.
diff --git a/modules/audio_processing/gain_controller2.cc b/modules/audio_processing/gain_controller2.cc
index 466e4b0..83a595e 100644
--- a/modules/audio_processing/gain_controller2.cc
+++ b/modules/audio_processing/gain_controller2.cc
@@ -69,7 +69,8 @@
 
 GainController2::GainController2(const Agc2Config& config,
                                  int sample_rate_hz,
-                                 int num_channels)
+                                 int num_channels,
+                                 bool use_internal_vad)
     : cpu_features_(GetAllowedCpuFeatures()),
       data_dumper_(rtc::AtomicOps::Increment(&instance_count_)),
       fixed_gain_applier_(
@@ -86,7 +87,7 @@
   RTC_DCHECK(Validate(config));
   data_dumper_.InitiateNewSetOfRecordings();
   const bool use_vad = config.adaptive_digital.enabled;
-  if (use_vad) {
+  if (use_vad && use_internal_vad) {
     // TODO(bugs.webrtc.org/7494): Move `vad_reset_period_ms` from adaptive
     // digital to gain controller 2 config.
     vad_ = std::make_unique<VoiceActivityDetectorWrapper>(
@@ -125,13 +126,18 @@
   fixed_gain_applier_.SetGainFactor(gain_factor);
 }
 
-void GainController2::Process(AudioBuffer* audio) {
+void GainController2::Process(absl::optional<float> speech_probability,
+                              AudioBuffer* audio) {
   data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_);
   AudioFrameView<float> float_frame(audio->channels(), audio->num_channels(),
                                     audio->num_frames());
-  absl::optional<float> speech_probability;
   if (vad_) {
     speech_probability = vad_->Analyze(float_frame);
+  } else if (speech_probability.has_value()) {
+    RTC_DCHECK_GE(speech_probability.value(), 0.0f);
+    RTC_DCHECK_LE(speech_probability.value(), 1.0f);
+  }
+  if (speech_probability.has_value()) {
     data_dumper_.DumpRaw("agc2_speech_probability", speech_probability.value());
   }
   fixed_gain_applier_.ApplyGain(float_frame);
diff --git a/modules/audio_processing/gain_controller2.h b/modules/audio_processing/gain_controller2.h
index 8c82d74..616f88a 100644
--- a/modules/audio_processing/gain_controller2.h
+++ b/modules/audio_processing/gain_controller2.h
@@ -30,9 +30,12 @@
 // microphone gain and/or applying digital gain.
 class GainController2 {
  public:
+  // Ctor. If `use_internal_vad` is true, an internal voice activity
+  // detector is used for digital adaptive gain.
   GainController2(const AudioProcessing::Config::GainController2& config,
                   int sample_rate_hz,
-                  int num_channels);
+                  int num_channels,
+                  bool use_internal_vad);
   GainController2(const GainController2&) = delete;
   GainController2& operator=(const GainController2&) = delete;
   ~GainController2();
@@ -44,13 +47,18 @@
   void SetFixedGainDb(float gain_db);
 
   // Applies fixed and adaptive digital gains to `audio` and runs a limiter.
-  void Process(AudioBuffer* audio);
+  // If the internal VAD is used, `speech_probability` is ignored. Otherwise
+  // `speech_probability` is used for digital adaptive gain if it's available
+  // (limited to values [0.0, 1.0]).
+  void Process(absl::optional<float> speech_probability, AudioBuffer* audio);
 
   // Handles analog level changes.
   void NotifyAnalogLevel(int level);
 
   static bool Validate(const AudioProcessing::Config::GainController2& config);
 
+  AvailableCpuFeatures GetCpuFeatures() const { return cpu_features_; }
+
  private:
   static int instance_count_;
   const AvailableCpuFeatures cpu_features_;
diff --git a/modules/audio_processing/gain_controller2_unittest.cc b/modules/audio_processing/gain_controller2_unittest.cc
index 850562f..88a93b0 100644
--- a/modules/audio_processing/gain_controller2_unittest.cc
+++ b/modules/audio_processing/gain_controller2_unittest.cc
@@ -47,7 +47,7 @@
   // Give time to the level estimator to converge.
   for (int i = 0; i < num_frames + 1; ++i) {
     SetAudioBufferSamples(input_level, ab);
-    agc2.Process(&ab);
+    agc2.Process(/*speech_probability=*/absl::nullopt, &ab);
   }
 
   // Return the last sample from the last processed frame.
@@ -62,7 +62,8 @@
   config.fixed_digital.gain_db = fixed_gain_db;
   EXPECT_TRUE(GainController2::Validate(config));
   return std::make_unique<GainController2>(config, sample_rate_hz,
-                                           /*num_channels=*/1);
+                                           /*num_channels=*/1,
+                                           /*use_internal_vad=*/true);
 }
 
 }  // namespace
@@ -138,7 +139,8 @@
 // Checks that the default config is applied.
 TEST(GainController2, ApplyDefaultConfig) {
   auto gain_controller2 = std::make_unique<GainController2>(
-      Agc2Config{}, /*sample_rate_hz=*/16000, /*num_channels=*/2);
+      Agc2Config{}, /*sample_rate_hz=*/16000, /*num_channels=*/2,
+      /*use_internal_vad=*/true);
   EXPECT_TRUE(gain_controller2.get());
 }
 
@@ -253,7 +255,8 @@
   Agc2Config config;
   config.fixed_digital.gain_db = 0.0f;
   config.adaptive_digital.enabled = true;
-  GainController2 agc2(config, kSampleRateHz, kStereo);
+  GainController2 agc2(config, kSampleRateHz, kStereo,
+                       /*use_internal_vad=*/true);
 
   test::InputAudioFile input_file(
       test::GetApmCaptureTestVectorFileName(kSampleRateHz),
@@ -276,16 +279,16 @@
                                    stream_config.num_channels(), &input_file,
                                    frame);
     // Apply a fixed gain to the input audio.
-    for (float& x : frame)
+    for (float& x : frame) {
       x *= gain;
+    }
     test::CopyVectorToAudioBuffer(stream_config, frame, &audio_buffer);
-    // Process.
-    agc2.Process(&audio_buffer);
+    agc2.Process(/*speech_probability=*/absl::nullopt, &audio_buffer);
   }
 
   // Estimate the applied gain by processing a probing frame.
   SetAudioBufferSamples(/*value=*/1.0f, audio_buffer);
-  agc2.Process(&audio_buffer);
+  agc2.Process(/*speech_probability=*/absl::nullopt, &audio_buffer);
   const float applied_gain_db =
       20.0f * std::log10(audio_buffer.channels_const()[0][0]);
 
@@ -294,5 +297,196 @@
   EXPECT_NEAR(applied_gain_db, kExpectedGainDb, kToleranceDb);
 }
 
+// Processes a test audio file and checks that the injected speech probability
+// is ignored when the internal VAD is used.
+TEST(GainController2,
+     CheckInjectedVadProbabilityNotUsedWithAdaptiveDigitalController) {
+  constexpr int kSampleRateHz = AudioProcessing::kSampleRate48kHz;
+  constexpr int kStereo = 2;
+
+  // Create AGC2 enabling only the adaptive digital controller.
+  Agc2Config config;
+  config.fixed_digital.gain_db = 0.0f;
+  config.adaptive_digital.enabled = true;
+  GainController2 agc2(config, kSampleRateHz, kStereo,
+                       /*use_internal_vad=*/true);
+  GainController2 agc2_reference(config, kSampleRateHz, kStereo,
+                                 /*use_internal_vad=*/true);
+
+  test::InputAudioFile input_file(
+      test::GetApmCaptureTestVectorFileName(kSampleRateHz),
+      /*loop_at_end=*/true);
+  const StreamConfig stream_config(kSampleRateHz, kStereo);
+
+  // Init buffers.
+  constexpr int kFrameDurationMs = 10;
+  std::vector<float> frame(kStereo * stream_config.num_frames());
+  AudioBuffer audio_buffer(kSampleRateHz, kStereo, kSampleRateHz, kStereo,
+                           kSampleRateHz, kStereo);
+  AudioBuffer audio_buffer_reference(kSampleRateHz, kStereo, kSampleRateHz,
+                                     kStereo, kSampleRateHz, kStereo);
+
+  // Simulate: `agc2` is fed alternating probabilities, the reference none.
+  constexpr float kGainDb = -6.0f;
+  const float gain = std::pow(10.0f, kGainDb / 20.0f);
+  constexpr int kDurationMs = 10000;
+  constexpr int kNumFramesToProcess = kDurationMs / kFrameDurationMs;
+  constexpr float kSpeechProbabilities[] = {1.0f, 0.3f};
+  constexpr float kEpsilon = 0.0001f;
+  bool all_samples_zero = true;
+  for (int i = 0, j = 0; i < kNumFramesToProcess; ++i, j = 1 - j) {
+    ReadFloatSamplesFromStereoFile(stream_config.num_frames(),
+                                   stream_config.num_channels(), &input_file,
+                                   frame);
+    // Apply a fixed gain to the input audio.
+    for (float& x : frame) {
+      x *= gain;
+    }
+    test::CopyVectorToAudioBuffer(stream_config, frame, &audio_buffer);
+    agc2.Process(kSpeechProbabilities[j], &audio_buffer);
+    test::CopyVectorToAudioBuffer(stream_config, frame,
+                                  &audio_buffer_reference);
+    agc2_reference.Process(absl::nullopt, &audio_buffer_reference);
+    // Check the output buffers: both instances rely on their internal VAD, so
+    // the injected probability must not change any output sample.
+    for (int ch = 0; ch < kStereo; ++ch) {
+      for (int s = 0; s < static_cast<int>(audio_buffer.num_frames()); ++s) {
+        all_samples_zero &=
+            std::fabs(audio_buffer.channels_const()[ch][s]) < kEpsilon;
+        EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][s],
+                        audio_buffer_reference.channels_const()[ch][s]);
+      }
+    }
+  }
+  EXPECT_FALSE(all_samples_zero);
+}
+
+// Processes a test audio file and checks that the injected speech probability
+// is not ignored when the internal VAD is not used.
+TEST(GainController2,
+     CheckInjectedVadProbabilityUsedWithAdaptiveDigitalController) {
+  constexpr int kSampleRateHz = AudioProcessing::kSampleRate48kHz;
+  constexpr int kStereo = 2;
+
+  // Create AGC2 enabling only the adaptive digital controller.
+  Agc2Config config;
+  config.fixed_digital.gain_db = 0.0f;
+  config.adaptive_digital.enabled = true;
+  GainController2 agc2(config, kSampleRateHz, kStereo,
+                       /*use_internal_vad=*/false);
+  GainController2 agc2_reference(config, kSampleRateHz, kStereo,
+                                 /*use_internal_vad=*/true);
+
+  test::InputAudioFile input_file(
+      test::GetApmCaptureTestVectorFileName(kSampleRateHz),
+      /*loop_at_end=*/true);
+  const StreamConfig stream_config(kSampleRateHz, kStereo);
+
+  // Init buffers.
+  constexpr int kFrameDurationMs = 10;
+  std::vector<float> frame(kStereo * stream_config.num_frames());
+  AudioBuffer audio_buffer(kSampleRateHz, kStereo, kSampleRateHz, kStereo,
+                           kSampleRateHz, kStereo);
+  AudioBuffer audio_buffer_reference(kSampleRateHz, kStereo, kSampleRateHz,
+                                     kStereo, kSampleRateHz, kStereo);
+  // Simulate: `agc2` is fed alternating probabilities, the reference none.
+  constexpr float kGainDb = -6.0f;
+  const float gain = std::pow(10.0f, kGainDb / 20.0f);
+  constexpr int kDurationMs = 10000;
+  constexpr int kNumFramesToProcess = kDurationMs / kFrameDurationMs;
+  constexpr float kSpeechProbabilities[] = {1.0f, 0.3f};
+  constexpr float kEpsilon = 0.0001f;
+  bool all_samples_zero = true;
+  bool all_samples_equal = true;
+  for (int i = 0, j = 0; i < kNumFramesToProcess; ++i, j = 1 - j) {
+    ReadFloatSamplesFromStereoFile(stream_config.num_frames(),
+                                   stream_config.num_channels(), &input_file,
+                                   frame);
+    // Apply a fixed gain to the input audio.
+    for (float& x : frame) {
+      x *= gain;
+    }
+    test::CopyVectorToAudioBuffer(stream_config, frame, &audio_buffer);
+    agc2.Process(kSpeechProbabilities[j], &audio_buffer);
+    test::CopyVectorToAudioBuffer(stream_config, frame,
+                                  &audio_buffer_reference);
+    agc2_reference.Process(absl::nullopt, &audio_buffer_reference);
+    // Check the output buffers: track all-zero output and any difference.
+    for (int ch = 0; ch < kStereo; ++ch) {
+      for (int s = 0; s < static_cast<int>(audio_buffer.num_frames()); ++s) {
+        all_samples_zero &=
+            std::fabs(audio_buffer.channels_const()[ch][s]) < kEpsilon;
+        const float diff = audio_buffer.channels_const()[ch][s] -
+                           audio_buffer_reference.channels_const()[ch][s];
+        all_samples_equal &= std::fabs(diff) < kEpsilon;
+      }
+    }
+  }
+  EXPECT_FALSE(all_samples_zero);
+  EXPECT_FALSE(all_samples_equal);
+}
+
+// Processes a test audio file and checks that the output is equal when
+// an injected speech probability from `VoiceActivityDetectorWrapper` and
+// the speech probability computed by the internal VAD are the same.
+TEST(GainController2,
+     CheckEqualResultFromInjectedVadProbabilityWithAdaptiveDigitalController) {
+  constexpr int kSampleRateHz = AudioProcessing::kSampleRate48kHz;
+  constexpr int kStereo = 2;
+
+  // Create AGC2 enabling only the adaptive digital controller.
+  Agc2Config config;
+  config.fixed_digital.gain_db = 0.0f;
+  config.adaptive_digital.enabled = true;
+  GainController2 agc2(config, kSampleRateHz, kStereo,
+                       /*use_internal_vad=*/false);
+  GainController2 agc2_reference(config, kSampleRateHz, kStereo,
+                                 /*use_internal_vad=*/true);
+  // Use AGC2's CPU features for the VAD, as `AudioProcessingImpl` does.
+  VoiceActivityDetectorWrapper vad(config.adaptive_digital.vad_reset_period_ms,
+                                   agc2.GetCpuFeatures(),
+                                   kSampleRateHz);
+  test::InputAudioFile input_file(
+      test::GetApmCaptureTestVectorFileName(kSampleRateHz),
+      /*loop_at_end=*/true);
+  const StreamConfig stream_config(kSampleRateHz, kStereo);
+  // Init buffers.
+  constexpr int kFrameDurationMs = 10;
+  std::vector<float> frame(kStereo * stream_config.num_frames());
+  AudioBuffer audio_buffer(kSampleRateHz, kStereo, kSampleRateHz, kStereo,
+                           kSampleRateHz, kStereo);
+  AudioBuffer audio_buffer_reference(kSampleRateHz, kStereo, kSampleRateHz,
+                                     kStereo, kSampleRateHz, kStereo);
+  // Simulate: `agc2` is fed the external VAD output, the reference nothing.
+  constexpr float kGainDb = -6.0f;
+  const float gain = std::pow(10.0f, kGainDb / 20.0f);
+  constexpr int kDurationMs = 10000;
+  constexpr int kNumFramesToProcess = kDurationMs / kFrameDurationMs;
+  for (int i = 0; i < kNumFramesToProcess; ++i) {
+    ReadFloatSamplesFromStereoFile(stream_config.num_frames(),
+                                   stream_config.num_channels(), &input_file,
+                                   frame);
+    // Apply a fixed gain to the input audio.
+    for (float& x : frame) {
+      x *= gain;
+    }
+    test::CopyVectorToAudioBuffer(stream_config, frame,
+                                  &audio_buffer_reference);
+    agc2_reference.Process(absl::nullopt, &audio_buffer_reference);
+    test::CopyVectorToAudioBuffer(stream_config, frame, &audio_buffer);
+    agc2.Process(vad.Analyze(AudioFrameView<const float>(
+                     audio_buffer.channels(), audio_buffer.num_channels(),
+                     audio_buffer.num_frames())),
+                 &audio_buffer);
+    // Check the output buffer: every sample must match the reference.
+    for (int ch = 0; ch < kStereo; ++ch) {
+      for (int s = 0; s < static_cast<int>(audio_buffer.num_frames()); ++s) {
+        EXPECT_FLOAT_EQ(audio_buffer.channels_const()[ch][s],
+                        audio_buffer_reference.channels_const()[ch][s]);
+      }
+    }
+  }
+}
+
 }  // namespace test
 }  // namespace webrtc