AGC2: VAD wrapper, add `Initialize()` method

The sample rate is not yet passed to the `VoiceActivityDetectorWrapper`
ctor, since that would require an unnecessary refactoring of `AdaptiveAgc`,
which will soon be removed.
Instead, to ensure correct initialization until the child CL [1] lands,
`VoiceActivityDetectorWrapper::initialized_` is temporarily added.

Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).

[1] https://webrtc-review.googlesource.com/c/src/+/234583

Bug: webrtc:7494
Change-Id: I4b4be7b8106ba36c958d91bf263a7b30271a1ee3
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234587
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35213}
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index fb06549..b543365 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -77,6 +77,7 @@
 
 void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
   gain_controller_.Initialize(sample_rate_hz, num_channels);
+  vad_.Initialize(sample_rate_hz);
 }
 
 void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc
index 7b61aee..17d9638 100644
--- a/modules/audio_processing/agc2/vad_wrapper.cc
+++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -64,6 +64,8 @@
     std::unique_ptr<MonoVad> vad)
     : vad_reset_period_frames_(
           rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
+      initialized_(false),
+      frame_size_(0),
       time_to_vad_reset_(vad_reset_period_frames_),
       vad_(std::move(vad)) {
   RTC_DCHECK(vad_);
@@ -74,19 +76,29 @@
 
 VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
 
+void VoiceActivityDetectorWrapper::Initialize(int sample_rate_hz) {
+  RTC_DCHECK_GT(sample_rate_hz, 0);
+  frame_size_ = rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond);
+  int status =
+      resampler_.InitializeIfNeeded(sample_rate_hz, vad_->SampleRateHz(),
+                                    /*num_channels=*/1);
+  constexpr int kStatusOk = 0;
+  RTC_DCHECK_EQ(status, kStatusOk);
+  vad_->Reset();
+  initialized_ = true;
+}
+
 float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
+  RTC_DCHECK(initialized_);
   // Periodically reset the VAD.
   time_to_vad_reset_--;
   if (time_to_vad_reset_ <= 0) {
     vad_->Reset();
     time_to_vad_reset_ = vad_reset_period_frames_;
   }
-
   // Resample the first channel of `frame`.
-  resampler_.InitializeIfNeeded(
-      /*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond,
-      vad_->SampleRateHz(), /*num_channels=*/1);
-  resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
+  RTC_DCHECK_EQ(frame.samples_per_channel(), frame_size_);
+  resampler_.Resample(frame.channel(0).data(), frame_size_,
                       resampled_buffer_.data(), resampled_buffer_.size());
 
   return vad_->Analyze(resampled_buffer_);
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h
index f17fcda..0579ca1 100644
--- a/modules/audio_processing/agc2/vad_wrapper.h
+++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -43,6 +43,7 @@
   // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
   // `MonoVad::Reset()`; it must be equal to or greater than the duration of two
   // frames. Uses `cpu_features` to instantiate the default VAD.
+  // TODO(bugs.webrtc.org/7494): Pass sample rate.
   VoiceActivityDetectorWrapper(int vad_reset_period_ms,
                                const AvailableCpuFeatures& cpu_features);
   // Ctor. Uses a custom `vad`.
@@ -54,11 +55,20 @@
       delete;
   ~VoiceActivityDetectorWrapper();
 
+  // TODO(bugs.webrtc.org/7494): Call `Initialize()` in the ctor.
+  // Initializes the VAD wrapper. Must be called before `Analyze()`.
+  void Initialize(int sample_rate_hz);
+
   // Analyzes the first channel of `frame` and returns the speech probability.
+  // `frame` must be a 10 ms frame with the sample rate specified in the last
+  // `Initialize()` call.
   float Analyze(AudioFrameView<const float> frame);
 
  private:
   const int vad_reset_period_frames_;
+  // TODO(bugs.webrtc.org/7494): Remove `initialized_`.
+  bool initialized_;
+  int frame_size_;
   int time_to_vad_reset_;
   PushResampler<float> resampler_;
   std::unique_ptr<MonoVad> vad_;
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
index c1f7029..27e5af6 100644
--- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc
+++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -43,13 +43,26 @@
   MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
 };
 
+// Checks that the ctor and `Initialize()` read the sample rate of the wrapped
+// VAD.
+TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) {
+  auto vad = std::make_unique<MockVad>();
+  EXPECT_CALL(*vad, SampleRateHz)
+      .Times(2)
+      .WillRepeatedly(Return(kSampleRate8kHz));
+  EXPECT_CALL(*vad, Reset).Times(AnyNumber());
+  auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
+      kNoVadPeriodicReset, std::move(vad));
+  vad_wrapper->Initialize(kSampleRate8kHz);
+}
+
 // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
 // repeatedly returns the next value from `speech_probabilities` and that
 // restarts from the beginning when after the last element is returned.
 std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
     int vad_reset_period_ms,
     const std::vector<float>& speech_probabilities,
-    int expected_vad_reset_calls = 0) {
+    int expected_vad_reset_calls) {
   auto vad = std::make_unique<MockVad>();
   EXPECT_CALL(*vad, SampleRateHz)
       .Times(AnyNumber())
@@ -67,7 +80,7 @@
 // 10 ms mono frame.
 struct FrameWithView {
   // Ctor. Initializes the frame samples with `value`.
-  explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz)
+  explicit FrameWithView(int sample_rate_hz)
       : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
         channel0(samples.data()),
         view(&channel0, /*num_channels=*/1, samples.size()) {}
@@ -82,8 +95,10 @@
                                                 0.44f,  0.525f, 0.858f, 0.314f,
                                                 0.653f, 0.965f, 0.413f, 0.0f};
   auto vad_wrapper =
-      CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities);
-  FrameWithView frame;
+      CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities,
+                           /*expected_vad_reset_calls=*/1);
+  vad_wrapper->Initialize(kSampleRate8kHz);
+  FrameWithView frame(kSampleRate8kHz);
   for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
     SCOPED_TRACE(i);
     EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
@@ -95,8 +110,9 @@
   constexpr int kNumFrames = 19;
   auto vad_wrapper =
       CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
-                           /*expected_vad_reset_calls=*/0);
-  FrameWithView frame;
+                           /*expected_vad_reset_calls=*/1);
+  vad_wrapper->Initialize(kSampleRate8kHz);
+  FrameWithView frame(kSampleRate8kHz);
   for (int i = 0; i < kNumFrames; ++i) {
     vad_wrapper->Analyze(frame.view);
   }
@@ -114,8 +130,10 @@
   auto vad_wrapper = CreateMockVadWrapper(
       /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
       /*speech_probabilities=*/{1.0f},
-      /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
-  FrameWithView frame;
+      /*expected_vad_reset_calls=*/1 +
+          num_frames() / vad_reset_period_frames());
+  vad_wrapper->Initialize(kSampleRate8kHz);
+  FrameWithView frame(kSampleRate8kHz);
   for (int i = 0; i < num_frames(); ++i) {
     vad_wrapper->Analyze(frame.view);
   }
@@ -141,7 +159,7 @@
   EXPECT_CALL(*vad, SampleRateHz)
       .Times(AnyNumber())
       .WillRepeatedly(Return(vad_sample_rate_hz()));
-  EXPECT_CALL(*vad, Reset).Times(0);
+  EXPECT_CALL(*vad, Reset).Times(1);
   EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
     return rtc::SafeEq(frame.size(),
                        rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
@@ -149,6 +167,7 @@
   auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
       kNoVadPeriodicReset, std::move(vad));
   FrameWithView frame(input_sample_rate_hz());
+  vad_wrapper->Initialize(input_sample_rate_hz());
   vad_wrapper->Analyze(frame.view);
 }