AGC2: update adaptive digital test

This CL improves `GainController2::CheckGainAdaptiveDigital`, namely:
- correctly initialize AGC2 with the correct number of channels
- attenuate the input signal in order to avoid that the target gain is
  set to zero (which was the case before)
- run AG2 adaptive digital for a longer period to allow time to trigger
  the adaptive behavior (namely, from 2s to 10s)
- minor code style improvements

Bug: webrtc:7494
Change-Id: Ib41de088b341bb30460238b83e306a507b2bc5af
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/233101
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35099}
diff --git a/modules/audio_processing/gain_controller2_unittest.cc b/modules/audio_processing/gain_controller2_unittest.cc
index a4a6462..c8ee113 100644
--- a/modules/audio_processing/gain_controller2_unittest.cc
+++ b/modules/audio_processing/gain_controller2_unittest.cc
@@ -13,6 +13,7 @@
 #include <algorithm>
 #include <cmath>
 #include <memory>
+#include <numeric>
 
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/agc2_testing_common.h"
@@ -26,24 +27,24 @@
 namespace test {
 namespace {
 
-void SetAudioBufferSamples(float value, AudioBuffer* ab) {
-  // Sets all the samples in `ab` to `value`.
-  for (size_t k = 0; k < ab->num_channels(); ++k) {
-    std::fill(ab->channels()[k], ab->channels()[k] + ab->num_frames(), value);
+// Sets all the samples in `ab` to `value`.
+void SetAudioBufferSamples(float value, AudioBuffer& ab) {
+  for (size_t k = 0; k < ab.num_channels(); ++k) {
+    std::fill(ab.channels()[k], ab.channels()[k] + ab.num_frames(), value);
   }
 }
 
-float RunAgc2WithConstantInput(GainController2* agc2,
+float RunAgc2WithConstantInput(GainController2& agc2,
                                float input_level,
-                               size_t num_frames,
-                               int sample_rate) {
-  const int num_samples = rtc::CheckedDivExact(sample_rate, 100);
-  AudioBuffer ab(sample_rate, 1, sample_rate, 1, sample_rate, 1);
+                               int num_frames,
+                               int sample_rate_hz) {
+  const int num_samples = rtc::CheckedDivExact(sample_rate_hz, 100);
+  AudioBuffer ab(sample_rate_hz, 1, sample_rate_hz, 1, sample_rate_hz, 1);
 
   // Give time to the level estimator to converge.
-  for (size_t i = 0; i < num_frames + 1; ++i) {
-    SetAudioBufferSamples(input_level, &ab);
-    agc2->Process(&ab);
+  for (int i = 0; i < num_frames + 1; ++i) {
+    SetAudioBufferSamples(input_level, ab);
+    agc2.Process(&ab);
   }
 
   // Return the last sample from the last processed frame.
@@ -55,60 +56,19 @@
   AudioProcessing::Config::GainController2 config;
   config.adaptive_digital.enabled = false;
   config.fixed_digital.gain_db = fixed_gain_db;
-  // TODO(alessiob): Check why ASSERT_TRUE() below does not compile.
   EXPECT_TRUE(GainController2::Validate(config));
   return config;
 }
 
 std::unique_ptr<GainController2> CreateAgc2FixedDigitalMode(
     float fixed_gain_db,
-    size_t sample_rate_hz) {
+    int sample_rate_hz) {
   auto agc2 = std::make_unique<GainController2>();
   agc2->ApplyConfig(CreateAgc2FixedDigitalModeConfig(fixed_gain_db));
   agc2->Initialize(sample_rate_hz, /*num_channels=*/1);
   return agc2;
 }
 
-float GainDbAfterProcessingFile(GainController2& gain_controller,
-                                int max_duration_ms) {
-  // Set up an AudioBuffer to be filled from the speech file.
-  constexpr size_t kStereo = 2u;
-  const StreamConfig capture_config(AudioProcessing::kSampleRate48kHz, kStereo,
-                                    false);
-  AudioBuffer ab(capture_config.sample_rate_hz(), capture_config.num_channels(),
-                 capture_config.sample_rate_hz(), capture_config.num_channels(),
-                 capture_config.sample_rate_hz(),
-                 capture_config.num_channels());
-  test::InputAudioFile capture_file(
-      test::GetApmCaptureTestVectorFileName(AudioProcessing::kSampleRate48kHz));
-  std::vector<float> capture_input(capture_config.num_frames() *
-                                   capture_config.num_channels());
-
-  // Process the input file which must be long enough to cover
-  // `max_duration_ms`.
-  RTC_DCHECK_GT(max_duration_ms, 0);
-  const int num_frames = rtc::CheckedDivExact(max_duration_ms, 10);
-  for (int i = 0; i < num_frames; ++i) {
-    ReadFloatSamplesFromStereoFile(capture_config.num_frames(),
-                                   capture_config.num_channels(), &capture_file,
-                                   capture_input);
-    test::CopyVectorToAudioBuffer(capture_config, capture_input, &ab);
-    gain_controller.Process(&ab);
-  }
-
-  // Send in a last frame with minimum dBFS level.
-  constexpr float sample_value = 1.f;
-  SetAudioBufferSamples(sample_value, &ab);
-  gain_controller.Process(&ab);
-  // Measure the RMS level after processing.
-  float rms = 0.0f;
-  for (size_t i = 0; i < capture_config.num_frames(); ++i) {
-    rms += ab.channels()[0][i] * ab.channels()[0][i];
-  }
-  // Return the applied gain in dB.
-  return 20.0f * std::log10(std::sqrt(rms / capture_config.num_frames()));
-}
-
 }  // namespace
 
 TEST(GainController2, CheckDefaultConfig) {
@@ -119,33 +79,33 @@
 TEST(GainController2, CheckFixedDigitalConfig) {
   AudioProcessing::Config::GainController2 config;
   // Attenuation is not allowed.
-  config.fixed_digital.gain_db = -5.f;
+  config.fixed_digital.gain_db = -5.0f;
   EXPECT_FALSE(GainController2::Validate(config));
   // No gain is allowed.
-  config.fixed_digital.gain_db = 0.f;
+  config.fixed_digital.gain_db = 0.0f;
   EXPECT_TRUE(GainController2::Validate(config));
   // Positive gain is allowed.
-  config.fixed_digital.gain_db = 15.f;
+  config.fixed_digital.gain_db = 15.0f;
   EXPECT_TRUE(GainController2::Validate(config));
 }
 
 TEST(GainController2, CheckAdaptiveDigitalMaxGainChangeSpeedConfig) {
   AudioProcessing::Config::GainController2 config;
-  config.adaptive_digital.max_gain_change_db_per_second = -1.f;
+  config.adaptive_digital.max_gain_change_db_per_second = -1.0f;
   EXPECT_FALSE(GainController2::Validate(config));
-  config.adaptive_digital.max_gain_change_db_per_second = 0.f;
+  config.adaptive_digital.max_gain_change_db_per_second = 0.0f;
   EXPECT_FALSE(GainController2::Validate(config));
-  config.adaptive_digital.max_gain_change_db_per_second = 5.f;
+  config.adaptive_digital.max_gain_change_db_per_second = 5.0f;
   EXPECT_TRUE(GainController2::Validate(config));
 }
 
 TEST(GainController2, CheckAdaptiveDigitalMaxOutputNoiseLevelConfig) {
   AudioProcessing::Config::GainController2 config;
-  config.adaptive_digital.max_output_noise_level_dbfs = 5.f;
+  config.adaptive_digital.max_output_noise_level_dbfs = 5.0f;
   EXPECT_FALSE(GainController2::Validate(config));
-  config.adaptive_digital.max_output_noise_level_dbfs = 0.f;
+  config.adaptive_digital.max_output_noise_level_dbfs = 0.0f;
   EXPECT_TRUE(GainController2::Validate(config));
-  config.adaptive_digital.max_output_noise_level_dbfs = -5.f;
+  config.adaptive_digital.max_output_noise_level_dbfs = -5.0f;
   EXPECT_TRUE(GainController2::Validate(config));
 }
 
@@ -157,23 +117,23 @@
 }
 
 TEST(GainController2FixedDigital, GainShouldChangeOnSetGain) {
-  constexpr float kInputLevel = 1000.f;
+  constexpr float kInputLevel = 1000.0f;
   constexpr size_t kNumFrames = 5;
   constexpr size_t kSampleRateHz = 8000;
-  constexpr float kGain0Db = 0.f;
-  constexpr float kGain20Db = 20.f;
+  constexpr float kGain0Db = 0.0f;
+  constexpr float kGain20Db = 20.0f;
 
   auto agc2_fixed = CreateAgc2FixedDigitalMode(kGain0Db, kSampleRateHz);
 
   // Signal level is unchanged with 0 db gain.
-  EXPECT_FLOAT_EQ(RunAgc2WithConstantInput(agc2_fixed.get(), kInputLevel,
-                                           kNumFrames, kSampleRateHz),
+  EXPECT_FLOAT_EQ(RunAgc2WithConstantInput(*agc2_fixed, kInputLevel, kNumFrames,
+                                           kSampleRateHz),
                   kInputLevel);
 
   // +20 db should increase signal by a factor of 10.
   agc2_fixed->ApplyConfig(CreateAgc2FixedDigitalModeConfig(kGain20Db));
-  EXPECT_FLOAT_EQ(RunAgc2WithConstantInput(agc2_fixed.get(), kInputLevel,
-                                           kNumFrames, kSampleRateHz),
+  EXPECT_FLOAT_EQ(RunAgc2WithConstantInput(*agc2_fixed, kInputLevel, kNumFrames,
+                                           kSampleRateHz),
                   kInputLevel * 10);
 }
 
@@ -182,27 +142,27 @@
   // input signal when the gain changes.
   constexpr size_t kNumFrames = 5;
 
-  constexpr float kInputLevel = 1000.f;
+  constexpr float kInputLevel = 1000.0f;
   constexpr size_t kSampleRateHz = 8000;
-  constexpr float kGainDbLow = 0.f;
-  constexpr float kGainDbHigh = 25.f;
+  constexpr float kGainDbLow = 0.0f;
+  constexpr float kGainDbHigh = 25.0f;
   static_assert(kGainDbLow < kGainDbHigh, "");
 
   auto agc2_fixed = CreateAgc2FixedDigitalMode(kGainDbLow, kSampleRateHz);
 
   // Start with a lower gain.
   const float output_level_pre = RunAgc2WithConstantInput(
-      agc2_fixed.get(), kInputLevel, kNumFrames, kSampleRateHz);
+      *agc2_fixed, kInputLevel, kNumFrames, kSampleRateHz);
 
   // Increase gain.
   agc2_fixed->ApplyConfig(CreateAgc2FixedDigitalModeConfig(kGainDbHigh));
-  static_cast<void>(RunAgc2WithConstantInput(agc2_fixed.get(), kInputLevel,
+  static_cast<void>(RunAgc2WithConstantInput(*agc2_fixed, kInputLevel,
                                              kNumFrames, kSampleRateHz));
 
   // Back to the lower gain.
   agc2_fixed->ApplyConfig(CreateAgc2FixedDigitalModeConfig(kGainDbLow));
   const float output_level_post = RunAgc2WithConstantInput(
-      agc2_fixed.get(), kInputLevel, kNumFrames, kSampleRateHz);
+      *agc2_fixed, kInputLevel, kNumFrames, kSampleRateHz);
 
   EXPECT_EQ(output_level_pre, output_level_post);
 }
@@ -227,7 +187,7 @@
       public ::testing::WithParamInterface<FixedDigitalTestParams> {};
 
 TEST_P(FixedDigitalTest, CheckSaturationBehaviorWithLimiter) {
-  const float kInputLevel = 32767.f;
+  const float kInputLevel = 32767.0f;
   const size_t kNumFrames = 5;
 
   const auto params = GetParam();
@@ -238,11 +198,11 @@
     SCOPED_TRACE(std::to_string(gain_db));
     auto agc2_fixed = CreateAgc2FixedDigitalMode(gain_db, params.sample_rate);
     const float processed_sample = RunAgc2WithConstantInput(
-        agc2_fixed.get(), kInputLevel, kNumFrames, params.sample_rate);
+        *agc2_fixed, kInputLevel, kNumFrames, params.sample_rate);
     if (params.saturation_expected) {
-      EXPECT_FLOAT_EQ(processed_sample, 32767.f);
+      EXPECT_FLOAT_EQ(processed_sample, 32767.0f);
     } else {
-      EXPECT_LT(processed_sample, 32767.f);
+      EXPECT_LT(processed_sample, 32767.0f);
     }
   }
 }
@@ -265,29 +225,68 @@
         // When gain > `test::kLimiterMaxInputLevelDbFs`, the limiter will
         // saturate the signal (at any sample rate).
         FixedDigitalTestParams(test::kLimiterMaxInputLevelDbFs + 0.01f,
-                               10.f,
+                               10.0f,
                                8000,
                                true),
         FixedDigitalTestParams(test::kLimiterMaxInputLevelDbFs + 0.01f,
-                               10.f,
+                               10.0f,
                                48000,
                                true)));
 
-// Checks that the gain applied at the end of a PCM samples file is close to the
-// expected value.
-TEST(GainController2, CheckGainAdaptiveDigital) {
-  constexpr float kExpectedGainDb = 4.3f;
-  constexpr float kToleranceDb = 0.5f;
-  GainController2 gain_controller2;
-  gain_controller2.Initialize(AudioProcessing::kSampleRate48kHz,
-                              /*num_channels=*/1);
+// Processes a test audio file and checks that the gain applied at the end of
+// the recording is close to the expected value.
+TEST(GainController2, CheckFinalGainWithAdaptiveDigitalController) {
+  // Create AGC2 enabling only the adaptive digital controller.
+  GainController2 agc2;
   AudioProcessing::Config::GainController2 config;
   config.fixed_digital.gain_db = 0.0f;
   config.adaptive_digital.enabled = true;
-  gain_controller2.ApplyConfig(config);
-  EXPECT_NEAR(
-      GainDbAfterProcessingFile(gain_controller2, /*max_duration_ms=*/2000),
-      kExpectedGainDb, kToleranceDb);
+  agc2.ApplyConfig(config);
+
+  // The input audio is a 48k stereo recording.
+  constexpr int kSampleRateHz = AudioProcessing::kSampleRate48kHz;
+  constexpr int kStereo = 2;
+  test::InputAudioFile input_file(
+      test::GetApmCaptureTestVectorFileName(kSampleRateHz),
+      /*loop_at_end=*/true);
+  const StreamConfig stream_config(kSampleRateHz, kStereo,
+                                   /*has_keyboard=*/false);
+
+  // Initialize AGC2 for the used input.
+  agc2.Initialize(kSampleRateHz, kStereo);
+
+  // Init buffers.
+  constexpr int kFrameDurationMs = 10;
+  std::vector<float> frame(kStereo * stream_config.num_frames());
+  AudioBuffer audio_buffer(kSampleRateHz, kStereo, kSampleRateHz, kStereo,
+                           kSampleRateHz, kStereo);
+
+  // Simulate.
+  constexpr float kGainDb = -6.0f;
+  const float gain = std::pow(10.0f, kGainDb / 20.0f);
+  constexpr int kDurationMs = 10000;
+  constexpr int kNumFramesToProcess = kDurationMs / kFrameDurationMs;
+  for (int i = 0; i < kNumFramesToProcess; ++i) {
+    ReadFloatSamplesFromStereoFile(stream_config.num_frames(),
+                                   stream_config.num_channels(), &input_file,
+                                   frame);
+    // Apply a fixed gain to the input audio.
+    for (float& x : frame)
+      x *= gain;
+    test::CopyVectorToAudioBuffer(stream_config, frame, &audio_buffer);
+    // Process.
+    agc2.Process(&audio_buffer);
+  }
+
+  // Estimate the applied gain by processing a probing frame.
+  SetAudioBufferSamples(/*value=*/1.0f, audio_buffer);
+  agc2.Process(&audio_buffer);
+  const float applied_gain_db =
+      20.0f * std::log10(audio_buffer.channels_const()[0][0]);
+
+  constexpr float kExpectedGainDb = 5.6f;
+  constexpr float kToleranceDb = 0.3f;
+  EXPECT_NEAR(applied_gain_db, kExpectedGainDb, kToleranceDb);
 }
 
 }  // namespace test