AGC2: VadWithLevel -> VoiceActivityDetectorWrapper 2/2

Internal refactoring of AGC2 to decouple the VAD, its wrapper and the
peak and RMS level measurements.

Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).

Bug: webrtc:7494
Change-Id: Ib560f1fcaa601557f4f30e47025c69e91b1b62e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234524
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35208}
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index a897c0b1..f767a6d 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -280,6 +280,7 @@
     ":common",
     ":vad_wrapper",
     "..:audio_frame_view",
+    "../../../rtc_base:checks",
     "../../../rtc_base:gunit_helpers",
     "../../../rtc_base:safe_compare",
     "../../../test:test_support",
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index ab1822d..fb06549 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -36,6 +36,24 @@
   return features;
 }
 
+// Peak and RMS audio levels in dBFS.
+struct AudioLevels {
+  float peak_dbfs;
+  float rms_dbfs;
+};
+
+// Computes the audio levels for the first channel in `frame`.
+AudioLevels ComputeAudioLevels(AudioFrameView<float> frame) {
+  float peak = 0.0f;
+  float rms = 0.0f;
+  for (const auto& x : frame.channel(0)) {
+    peak = std::max(std::fabs(x), peak);
+    rms += x * x;
+  }
+  return {FloatS16ToDbfs(peak),
+          FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel()))};
+}
+
 }  // namespace
 
 AdaptiveAgc::AdaptiveAgc(
@@ -62,16 +80,17 @@
 }
 
 void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
+  AudioLevels levels = ComputeAudioLevels(frame);
+
   AdaptiveDigitalGainApplier::FrameInfo info;
 
-  VadLevelAnalyzer::Result vad_result = vad_.AnalyzeFrame(frame);
-  info.speech_probability = vad_result.speech_probability;
-  apm_data_dumper_->DumpRaw("agc2_speech_probability",
-                            vad_result.speech_probability);
-  apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", vad_result.rms_dbfs);
-  apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", vad_result.peak_dbfs);
+  info.speech_probability = vad_.Analyze(frame);
+  apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
+  apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
+  apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
 
-  speech_level_estimator_.Update(vad_result);
+  speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
+                                 info.speech_probability);
   info.speech_level_dbfs = speech_level_estimator_.level_dbfs();
   info.speech_level_reliable = speech_level_estimator_.IsConfident();
   apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", info.speech_level_dbfs);
@@ -81,7 +100,7 @@
   info.noise_rms_dbfs = noise_level_estimator_->Analyze(frame);
   apm_data_dumper_->DumpRaw("agc2_noise_rms_dbfs", info.noise_rms_dbfs);
 
-  saturation_protector_->Analyze(info.speech_probability, vad_result.peak_dbfs,
+  saturation_protector_->Analyze(info.speech_probability, levels.peak_dbfs,
                                  info.speech_level_dbfs);
   info.headroom_db = saturation_protector_->HeadroomDb();
   apm_data_dumper_->DumpRaw("agc2_headroom_db", info.headroom_db);
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index 8ee8378..32de680 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -47,7 +47,7 @@
 
  private:
   AdaptiveModeLevelEstimator speech_level_estimator_;
-  VadLevelAnalyzer vad_;
+  VoiceActivityDetectorWrapper vad_;
   AdaptiveDigitalGainApplier gain_controller_;
   ApmDataDumper* const apm_data_dumper_;
   std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
index 81e7d29..fe021fe 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
@@ -57,15 +57,16 @@
   Reset();
 }
 
-void AdaptiveModeLevelEstimator::Update(
-    const VadLevelAnalyzer::Result& vad_level) {
-  RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
-  RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
-  RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
-  RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
-  RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
-  RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
-  if (vad_level.speech_probability < kVadConfidenceThreshold) {
+void AdaptiveModeLevelEstimator::Update(float rms_dbfs,
+                                        float peak_dbfs,
+                                        float speech_probability) {
+  RTC_DCHECK_GT(rms_dbfs, -150.0f);
+  RTC_DCHECK_LT(rms_dbfs, 50.0f);
+  RTC_DCHECK_GT(peak_dbfs, -150.0f);
+  RTC_DCHECK_LT(peak_dbfs, 50.0f);
+  RTC_DCHECK_GE(speech_probability, 0.0f);
+  RTC_DCHECK_LE(speech_probability, 1.0f);
+  if (speech_probability < kVadConfidenceThreshold) {
     // Not a speech frame.
     if (adjacent_speech_frames_threshold_ > 1) {
       // When two or more adjacent speech frames are required in order to update
@@ -93,14 +94,14 @@
       preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
     }
     // Weighted average of levels with speech probability as weight.
-    RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
-    const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.f;
+    RTC_DCHECK_GT(speech_probability, 0.0f);
+    const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
     preliminary_state_.level_dbfs.numerator =
         preliminary_state_.level_dbfs.numerator * leak_factor +
-        vad_level.rms_dbfs * vad_level.speech_probability;
+        rms_dbfs * speech_probability;
     preliminary_state_.level_dbfs.denominator =
         preliminary_state_.level_dbfs.denominator * leak_factor +
-        vad_level.speech_probability;
+        speech_probability;
 
     const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
 
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
index 14da6b7..989c8c3 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
@@ -33,7 +33,7 @@
       delete;
 
   // Updates the level estimation.
-  void Update(const VadLevelAnalyzer::Result& vad_data);
+  void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
   // Returns the estimated speech plus noise level.
   float level_dbfs() const { return level_dbfs_; }
   // Returns true if the estimator is confident on its current estimate.
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
index 1cdd91d..684fca1 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
@@ -33,10 +33,12 @@
 
 // Provides the `vad_level` value `num_iterations` times to `level_estimator`.
 void RunOnConstantLevel(int num_iterations,
-                        const VadLevelAnalyzer::Result& vad_level,
+                        float rms_dbfs,
+                        float peak_dbfs,
+                        float speech_probability,
                         AdaptiveModeLevelEstimator& level_estimator) {
   for (int i = 0; i < num_iterations; ++i) {
-    level_estimator.Update(vad_level);
+    level_estimator.Update(rms_dbfs, peak_dbfs, speech_probability);
   }
 }
 
@@ -47,6 +49,10 @@
   return config;
 }
 
+constexpr float kNoSpeechProbability = 0.0f;
+constexpr float kLowSpeechProbability = kVadConfidenceThreshold / 2.0f;
+constexpr float kMaxSpeechProbability = 1.0f;
+
 // Level estimator with data dumper.
 struct TestLevelEstimator {
   explicit TestLevelEstimator(int adjacent_speech_frames_threshold)
@@ -55,36 +61,31 @@
             &data_dumper,
             GetAdaptiveDigitalConfig(adjacent_speech_frames_threshold))),
         initial_speech_level_dbfs(estimator->level_dbfs()),
-        vad_level_rms(initial_speech_level_dbfs / 2.0f),
-        vad_level_peak(initial_speech_level_dbfs / 3.0f),
-        vad_data_speech(
-            {/*speech_probability=*/1.0f, vad_level_rms, vad_level_peak}),
-        vad_data_non_speech(
-            {/*speech_probability=*/kVadConfidenceThreshold / 2.0f,
-             vad_level_rms, vad_level_peak}) {
-    RTC_DCHECK_LT(vad_level_rms, vad_level_peak);
-    RTC_DCHECK_LT(initial_speech_level_dbfs, vad_level_rms);
-    RTC_DCHECK_GT(vad_level_rms - initial_speech_level_dbfs, 5.0f)
-        << "Adjust `vad_level_rms` so that the difference from the initial "
+        level_rms_dbfs(initial_speech_level_dbfs / 2.0f),
+        level_peak_dbfs(initial_speech_level_dbfs / 3.0f) {
+    RTC_DCHECK_LT(level_rms_dbfs, level_peak_dbfs);
+    RTC_DCHECK_LT(initial_speech_level_dbfs, level_rms_dbfs);
+    RTC_DCHECK_GT(level_rms_dbfs - initial_speech_level_dbfs, 5.0f)
+        << "Adjust `level_rms_dbfs` so that the difference from the initial "
            "level is wide enough for the tests";
   }
   ApmDataDumper data_dumper;
   std::unique_ptr<AdaptiveModeLevelEstimator> estimator;
   const float initial_speech_level_dbfs;
-  const float vad_level_rms;
-  const float vad_level_peak;
-  const VadLevelAnalyzer::Result vad_data_speech;
-  const VadLevelAnalyzer::Result vad_data_non_speech;
+  const float level_rms_dbfs;
+  const float level_peak_dbfs;
 };
 
 // Checks that the level estimator converges to a constant input speech level.
 TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
-  RunOnConstantLevel(/*num_iterations=*/1, level_estimator.vad_data_speech,
+  RunOnConstantLevel(/*num_iterations=*/1, level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_NEAR(level_estimator.estimator->level_dbfs(), estimated_level_dbfs,
               0.1f);
@@ -95,7 +96,8 @@
 TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence / 2,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_FALSE(level_estimator.estimator->IsConfident());
 }
@@ -105,7 +107,8 @@
 TEST(GainController2AdaptiveModeLevelEstimator, IsConfident) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_TRUE(level_estimator.estimator->IsConfident());
 }
@@ -117,15 +120,14 @@
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   // Simulate speech.
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
   // Simulate full-scale non-speech.
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     VadLevelAnalyzer::Result{/*speech_probability=*/0.0f,
-                                              /*rms_dbfs=*/0.0f,
-                                              /*peak_dbfs=*/0.0f},
-                     *level_estimator.estimator);
+                     /*rms_dbfs=*/0.0f, /*peak_dbfs=*/0.0f,
+                     kNoSpeechProbability, *level_estimator.estimator);
   // No estimated level change is expected.
   EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
                   estimated_level_dbfs);
@@ -136,10 +138,11 @@
      ConvergenceSpeedBeforeConfidence) {
   TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
   RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
-                     level_estimator.vad_data_speech,
+                     level_estimator.level_rms_dbfs,
+                     level_estimator.level_peak_dbfs, kMaxSpeechProbability,
                      *level_estimator.estimator);
   EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
-              level_estimator.vad_data_speech.rms_dbfs,
+              level_estimator.level_rms_dbfs,
               kConvergenceSpeedTestsLevelTolerance);
 }
 
@@ -150,11 +153,9 @@
   // Reach confidence using the initial level estimate.
   RunOnConstantLevel(
       /*num_iterations=*/kNumFramesToConfidence,
-      VadLevelAnalyzer::Result{
-          /*speech_probability=*/1.0f,
-          /*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
-          /*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f},
-      *level_estimator.estimator);
+      /*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
+      /*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f,
+      kMaxSpeechProbability, *level_estimator.estimator);
   // No estimate change should occur, but confidence is achieved.
   ASSERT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
                   level_estimator.initial_speech_level_dbfs);
@@ -165,9 +166,10 @@
       kConvergenceTimeAfterConfidenceNumFrames > kNumFramesToConfidence, "");
   RunOnConstantLevel(
       /*num_iterations=*/kConvergenceTimeAfterConfidenceNumFrames,
-      level_estimator.vad_data_speech, *level_estimator.estimator);
+      level_estimator.level_rms_dbfs, level_estimator.level_peak_dbfs,
+      kMaxSpeechProbability, *level_estimator.estimator);
   EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
-              level_estimator.vad_data_speech.rms_dbfs,
+              level_estimator.level_rms_dbfs,
               kConvergenceSpeedTestsLevelTolerance);
 }
 
@@ -181,22 +183,28 @@
        DoNotAdaptToShortSpeechSegments) {
   TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
   const float initial_level = level_estimator.estimator->level_dbfs();
-  ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
+  ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
   for (int i = 0; i < adjacent_speech_frames_threshold() - 1; ++i) {
     SCOPED_TRACE(i);
-    level_estimator.estimator->Update(level_estimator.vad_data_speech);
+    level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+                                      level_estimator.level_peak_dbfs,
+                                      kMaxSpeechProbability);
     EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
   }
-  level_estimator.estimator->Update(level_estimator.vad_data_non_speech);
+  level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+                                    level_estimator.level_peak_dbfs,
+                                    kLowSpeechProbability);
   EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
 }
 
 TEST_P(AdaptiveModeLevelEstimatorParametrization, AdaptToEnoughSpeechSegments) {
   TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
   const float initial_level = level_estimator.estimator->level_dbfs();
-  ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
+  ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
   for (int i = 0; i < adjacent_speech_frames_threshold(); ++i) {
-    level_estimator.estimator->Update(level_estimator.vad_data_speech);
+    level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+                                      level_estimator.level_peak_dbfs,
+                                      kMaxSpeechProbability);
   }
   EXPECT_LT(initial_level, level_estimator.estimator->level_dbfs());
 }
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc
index 94d5f67..7b61aee 100644
--- a/modules/audio_processing/agc2/vad_wrapper.cc
+++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -10,13 +10,10 @@
 
 #include "modules/audio_processing/agc2/vad_wrapper.h"
 
-#include <algorithm>
 #include <array>
-#include <cmath>
 #include <utility>
 
 #include "api/array_view.h"
-#include "common_audio/include/audio_util.h"
 #include "common_audio/resampler/include/push_resampler.h"
 #include "modules/audio_processing/agc2/agc2_common.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -27,82 +24,72 @@
 namespace webrtc {
 namespace {
 
-using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
+constexpr int kNumFramesPerSecond = 100;
 
-// Default VAD that combines a resampler and the RNN VAD.
-// Computes the speech probability on the first channel.
-class Vad : public VoiceActivityDetector {
+class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
  public:
-  explicit Vad(const AvailableCpuFeatures& cpu_features)
+  explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
       : features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
-  Vad(const Vad&) = delete;
-  Vad& operator=(const Vad&) = delete;
-  ~Vad() = default;
+  MonoVadImpl(const MonoVadImpl&) = delete;
+  MonoVadImpl& operator=(const MonoVadImpl&) = delete;
+  ~MonoVadImpl() = default;
 
+  int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
   void Reset() override { rnn_vad_.Reset(); }
-
-  float ComputeProbability(AudioFrameView<const float> frame) override {
-    // The source number of channels is 1, because we always use the 1st
-    // channel.
-    resampler_.InitializeIfNeeded(
-        /*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
-        rnn_vad::kSampleRate24kHz,
-        /*num_channels=*/1);
-
-    std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
-    // Feed the 1st channel to the resampler.
-    resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
-                        work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
-
+  float Analyze(rtc::ArrayView<const float> frame) override {
+    RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
     std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
     const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
-        work_frame, feature_vector);
+        /*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
+        feature_vector);
     return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
   }
 
  private:
-  PushResampler<float> resampler_;
   rnn_vad::FeaturesExtractor features_extractor_;
   rnn_vad::RnnVad rnn_vad_;
 };
 
 }  // namespace
 
-VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
-                                   const AvailableCpuFeatures& cpu_features)
-    : VadLevelAnalyzer(vad_reset_period_ms,
-                       std::make_unique<Vad>(cpu_features)) {}
+VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
+    int vad_reset_period_ms,
+    const AvailableCpuFeatures& cpu_features)
+    : VoiceActivityDetectorWrapper(
+          vad_reset_period_ms,
+          std::make_unique<MonoVadImpl>(cpu_features)) {}
 
-VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
-                                   std::unique_ptr<VoiceActivityDetector> vad)
-    : vad_(std::move(vad)),
-      vad_reset_period_frames_(
+VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
+    int vad_reset_period_ms,
+    std::unique_ptr<MonoVad> vad)
+    : vad_reset_period_frames_(
           rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
-      time_to_vad_reset_(vad_reset_period_frames_) {
+      time_to_vad_reset_(vad_reset_period_frames_),
+      vad_(std::move(vad)) {
   RTC_DCHECK(vad_);
   RTC_DCHECK_GT(vad_reset_period_frames_, 1);
+  resampled_buffer_.resize(
+      rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
 }
 
-VadLevelAnalyzer::~VadLevelAnalyzer() = default;
+VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
 
-VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
-    AudioFrameView<const float> frame) {
+float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
   // Periodically reset the VAD.
   time_to_vad_reset_--;
   if (time_to_vad_reset_ <= 0) {
     vad_->Reset();
     time_to_vad_reset_ = vad_reset_period_frames_;
   }
-  // Compute levels.
-  float peak = 0.0f;
-  float rms = 0.0f;
-  for (const auto& x : frame.channel(0)) {
-    peak = std::max(std::fabs(x), peak);
-    rms += x * x;
-  }
-  return {vad_->ComputeProbability(frame),
-          FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
-          FloatS16ToDbfs(peak)};
+
+  // Resample the first channel of `frame`.
+  resampler_.InitializeIfNeeded(
+      /*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond,
+      vad_->SampleRateHz(), /*num_channels=*/1);
+  resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
+                      resampled_buffer_.data(), resampled_buffer_.size());
+
+  return vad_->Analyze(resampled_buffer_);
 }
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h
index de73eab..f17fcda 100644
--- a/modules/audio_processing/agc2/vad_wrapper.h
+++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -12,51 +12,57 @@
 #define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
 
 #include <memory>
+#include <vector>
 
+#include "api/array_view.h"
+#include "common_audio/resampler/include/push_resampler.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
 
 namespace webrtc {
 
-// Class to analyze voice activity and audio levels.
-class VadLevelAnalyzer {
+// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
+// the first channel of the input audio frames. Takes care of resampling the
+// input frames to match the sample rate of the wrapped VAD and periodically
+// resets the VAD.
+class VoiceActivityDetectorWrapper {
  public:
-  struct Result {
-    float speech_probability;  // Range: [0, 1].
-    float rms_dbfs;            // Root mean square power (dBFS).
-    float peak_dbfs;           // Peak power (dBFS).
-  };
-
-  // Voice Activity Detector (VAD) interface.
-  class VoiceActivityDetector {
+  // Single channel VAD interface.
+  class MonoVad {
    public:
-    virtual ~VoiceActivityDetector() = default;
+    virtual ~MonoVad() = default;
+    // Returns the sample rate (Hz) required for the input frames analyzed by
+    // `ComputeProbability`.
+    virtual int SampleRateHz() const = 0;
     // Resets the internal state.
     virtual void Reset() = 0;
     // Analyzes an audio frame and returns the speech probability.
-    virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
+    virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
   };
 
   // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
-  // `VadLevelAnalyzer::Reset()`; it must be equal to or greater than the
-  // duration of two frames. Uses `cpu_features` to instantiate the default VAD.
-  VadLevelAnalyzer(int vad_reset_period_ms,
-                   const AvailableCpuFeatures& cpu_features);
+  // `MonoVad::Reset()`; it must be equal to or greater than the duration of two
+  // frames. Uses `cpu_features` to instantiate the default VAD.
+  VoiceActivityDetectorWrapper(int vad_reset_period_ms,
+                               const AvailableCpuFeatures& cpu_features);
   // Ctor. Uses a custom `vad`.
-  VadLevelAnalyzer(int vad_reset_period_ms,
-                   std::unique_ptr<VoiceActivityDetector> vad);
+  VoiceActivityDetectorWrapper(int vad_reset_period_ms,
+                               std::unique_ptr<MonoVad> vad);
 
-  VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
-  VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
-  ~VadLevelAnalyzer();
+  VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
+  VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
+      delete;
+  ~VoiceActivityDetectorWrapper();
 
-  // Computes the speech probability and the level for `frame`.
-  Result AnalyzeFrame(AudioFrameView<const float> frame);
+  // Analyzes the first channel of `frame` and returns the speech probability.
+  float Analyze(AudioFrameView<const float> frame);
 
  private:
-  std::unique_ptr<VoiceActivityDetector> vad_;
   const int vad_reset_period_frames_;
   int time_to_vad_reset_;
+  PushResampler<float> resampler_;
+  std::unique_ptr<MonoVad> vad_;
+  std::vector<float> resampled_buffer_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
index a6e776c..c1f7029 100644
--- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc
+++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -18,6 +18,7 @@
 
 #include "modules/audio_processing/agc2/agc2_common.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
+#include "rtc_base/checks.h"
 #include "rtc_base/gunit.h"
 #include "rtc_base/numerics/safe_compare.h"
 #include "test/gmock.h"
@@ -26,90 +27,78 @@
 namespace {
 
 using ::testing::AnyNumber;
+using ::testing::Return;
 using ::testing::ReturnRoundRobin;
+using ::testing::Truly;
 
 constexpr int kNoVadPeriodicReset =
     kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
 
-constexpr int kSampleRateHz = 8000;
+constexpr int kSampleRate8kHz = 8000;
 
-class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
+class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
  public:
+  MOCK_METHOD(int, SampleRateHz, (), (const override));
   MOCK_METHOD(void, Reset, (), (override));
-  MOCK_METHOD(float,
-              ComputeProbability,
-              (AudioFrameView<const float> frame),
-              (override));
+  MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
 };
 
-// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns
-// the next value from `speech_probabilities` until it reaches the end and will
-// restart from the beginning.
-std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
+// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
+// repeatedly returns the next value from `speech_probabilities` and that
+// restarts from the beginning when after the last element is returned.
+std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
     int vad_reset_period_ms,
     const std::vector<float>& speech_probabilities,
     int expected_vad_reset_calls = 0) {
   auto vad = std::make_unique<MockVad>();
-  EXPECT_CALL(*vad, ComputeProbability)
+  EXPECT_CALL(*vad, SampleRateHz)
       .Times(AnyNumber())
-      .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
+      .WillRepeatedly(Return(kSampleRate8kHz));
   if (expected_vad_reset_calls >= 0) {
     EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
   }
-  return std::make_unique<VadLevelAnalyzer>(vad_reset_period_ms,
-                                            std::move(vad));
+  EXPECT_CALL(*vad, Analyze)
+      .Times(AnyNumber())
+      .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
+  return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
+                                                        std::move(vad));
 }
 
 // 10 ms mono frame.
 struct FrameWithView {
   // Ctor. Initializes the frame samples with `value`.
-  explicit FrameWithView(float value = 0.0f)
-      : channel0(samples.data()),
-        view(&channel0, /*num_channels=*/1, samples.size()) {
-    samples.fill(value);
-  }
-  std::array<float, kSampleRateHz / 100> samples;
+  explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz)
+      : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
+        channel0(samples.data()),
+        view(&channel0, /*num_channels=*/1, samples.size()) {}
+  std::vector<float> samples;
   const float* const channel0;
   const AudioFrameView<const float> view;
 };
 
-TEST(GainController2VadLevelAnalyzer, RmsLessThanPeakLevel) {
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
-      /*vad_reset_period_ms=*/1500,
-      /*speech_probabilities=*/{1.0f},
-      /*expected_vad_reset_calls=*/0);
-  // Handcrafted frame so that the average is lower than the peak value.
-  FrameWithView frame(1000.0f);  // Constant frame.
-  frame.samples[10] = 2000.0f;   // Except for one peak value.
-  // Compute audio frame levels.
-  auto levels_and_vad_prob = analyzer->AnalyzeFrame(frame.view);
-  EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
-}
-
-// Checks that the expect VAD probabilities are returned.
-TEST(GainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
+// Checks that the expected speech probabilities are returned.
+TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
   const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
                                                 0.44f,  0.525f, 0.858f, 0.314f,
                                                 0.653f, 0.965f, 0.413f, 0.0f};
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(kNoVadPeriodicReset,
-                                                    speech_probabilities);
+  auto vad_wrapper =
+      CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities);
   FrameWithView frame;
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
-    EXPECT_EQ(speech_probabilities[i],
-              analyzer->AnalyzeFrame(frame.view).speech_probability);
+    EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
   }
 }
 
 // Checks that the VAD is not periodically reset.
-TEST(GainController2VadLevelAnalyzer, VadNoPeriodicReset) {
+TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
   constexpr int kNumFrames = 19;
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
-      kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
-      /*expected_vad_reset_calls=*/0);
+  auto vad_wrapper =
+      CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
+                           /*expected_vad_reset_calls=*/0);
   FrameWithView frame;
   for (int i = 0; i < kNumFrames; ++i) {
-    analyzer->AnalyzeFrame(frame.view);
+    vad_wrapper->Analyze(frame.view);
   }
 }
 
@@ -122,20 +111,52 @@
 
 // Checks that the VAD is periodically reset with the expected period.
 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
-  auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+  auto vad_wrapper = CreateMockVadWrapper(
       /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
       /*speech_probabilities=*/{1.0f},
       /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
   FrameWithView frame;
   for (int i = 0; i < num_frames(); ++i) {
-    analyzer->AnalyzeFrame(frame.view);
+    vad_wrapper->Analyze(frame.view);
   }
 }
 
-INSTANTIATE_TEST_SUITE_P(GainController2VadLevelAnalyzer,
+INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
                          VadPeriodResetParametrization,
                          ::testing::Combine(::testing::Values(1, 19, 123),
                                             ::testing::Values(2, 5, 20, 53)));
 
+class VadResamplingParametrization
+    : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+  int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
+  int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
+};
+
+// Checks that regardless of the input audio sample rate, the wrapped VAD
+// analyzes frames having the expected size, that is according to its internal
+// sample rate.
+TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
+  auto vad = std::make_unique<MockVad>();
+  EXPECT_CALL(*vad, SampleRateHz)
+      .Times(AnyNumber())
+      .WillRepeatedly(Return(vad_sample_rate_hz()));
+  EXPECT_CALL(*vad, Reset).Times(0);
+  EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
+    return rtc::SafeEq(frame.size(),
+                       rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
+  }))).Times(1);
+  auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
+      kNoVadPeriodicReset, std::move(vad));
+  FrameWithView frame(input_sample_rate_hz());
+  vad_wrapper->Analyze(frame.view);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    GainController2VoiceActivityDetectorWrapper,
+    VadResamplingParametrization,
+    ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
+                       ::testing::Values(6000, 8000, 12000, 16000, 24000)));
+
 }  // namespace
 }  // namespace webrtc