AGC2: VAD moved into `GainController2` Bit exactness verified with audioproc_f on a collection of AEC dumps and Wav files (42 recordings in total). Bug: webrtc:7494 Change-Id: Id9849c4463791f5a203afe31efc163efb4d4458e Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234583 Reviewed-by: Hanna Silen <silen@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/main@{#35248}
diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn index 4223101..7fbae51 100644 --- a/modules/audio_processing/BUILD.gn +++ b/modules/audio_processing/BUILD.gn
@@ -132,8 +132,10 @@ "../../rtc_base:stringutils", "../../system_wrappers:field_trial", "agc2:adaptive_digital", + "agc2:cpu_features", "agc2:fixed_digital", "agc2:gain_applier", + "agc2:vad_wrapper", ] }
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn index f767a6d..5ec3ce1 100644 --- a/modules/audio_processing/agc2/BUILD.gn +++ b/modules/audio_processing/agc2/BUILD.gn
@@ -52,7 +52,6 @@ "../../../rtc_base:rtc_base_approved", "../../../rtc_base:safe_compare", "../../../rtc_base:safe_minmax", - "../../../system_wrappers:field_trial", "../../../system_wrappers:metrics", ] @@ -150,7 +149,11 @@ "vad_wrapper.cc", "vad_wrapper.h", ] - visibility = [ "./*" ] + + visibility = [ + "..:gain_controller2", + "./*", + ] defines = [] if (rtc_build_with_neon && current_cpu != "arm64") {
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc index b543365..295a33f 100644 --- a/modules/audio_processing/agc2/adaptive_agc.cc +++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -11,31 +11,14 @@ #include "modules/audio_processing/agc2/adaptive_agc.h" #include "common_audio/include/audio_util.h" -#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/vad_wrapper.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" -#include "system_wrappers/include/field_trial.h" namespace webrtc { namespace { -// Detects the available CPU features and applies any kill-switches. -AvailableCpuFeatures GetAllowedCpuFeatures() { - AvailableCpuFeatures features = GetAvailableCpuFeatures(); - if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) { - features.sse2 = false; - } - if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) { - features.avx2 = false; - } - if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) { - features.neon = false; - } - return features; -} - // Peak and RMS audio levels in dBFS. struct AudioLevels { float peak_dbfs; @@ -60,7 +43,6 @@ ApmDataDumper* apm_data_dumper, const AudioProcessing::Config::GainController2::AdaptiveDigital& config) : speech_level_estimator_(apm_data_dumper, config), - vad_(config.vad_reset_period_ms, GetAllowedCpuFeatures()), gain_controller_(apm_data_dumper, config), apm_data_dumper_(apm_data_dumper), noise_level_estimator_(CreateNoiseFloorEstimator(apm_data_dumper)), @@ -77,18 +59,18 @@ void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) { gain_controller_.Initialize(sample_rate_hz, num_channels); - vad_.Initialize(sample_rate_hz); } -void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) { +void AdaptiveAgc::Process(AudioFrameView<float> frame, + float speech_probability, + float limiter_envelope) { AudioLevels levels = ComputeAudioLevels(frame); + apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs); + apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs); AdaptiveDigitalGainApplier::FrameInfo info; - info.speech_probability = vad_.Analyze(frame); - apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability); - apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs); - apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs); + info.speech_probability = speech_probability; speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs, info.speech_probability);
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h index 32de680..a9a6985 100644 --- a/modules/audio_processing/agc2/adaptive_agc.h +++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -17,7 +17,6 @@ #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h" #include "modules/audio_processing/agc2/noise_level_estimator.h" #include "modules/audio_processing/agc2/saturation_protector.h" -#include "modules/audio_processing/agc2/vad_wrapper.h" #include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_processing.h" @@ -38,16 +37,17 @@ // TODO(crbug.com/webrtc/7494): Add `SetLimiterEnvelope()`. // Analyzes `frame` and applies a digital adaptive gain to it. Takes into - // account the envelope measured by the limiter. + // account the speech probability and the envelope measured by the limiter. // TODO(crbug.com/webrtc/7494): Remove `limiter_envelope`. - void Process(AudioFrameView<float> frame, float limiter_envelope); + void Process(AudioFrameView<float> frame, + float speech_probability, + float limiter_envelope); // Handles a gain change applied to the input signal (e.g., analog gain). void HandleInputGainChange(); private: AdaptiveModeLevelEstimator speech_level_estimator_; - VoiceActivityDetectorWrapper vad_; AdaptiveDigitalGainApplier gain_controller_; ApmDataDumper* const apm_data_dumper_; std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc index 17d9638..91448f8 100644 --- a/modules/audio_processing/agc2/vad_wrapper.cc +++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -54,24 +54,25 @@ VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( int vad_reset_period_ms, - const AvailableCpuFeatures& cpu_features) - : VoiceActivityDetectorWrapper( - vad_reset_period_ms, - std::make_unique<MonoVadImpl>(cpu_features)) {} + const AvailableCpuFeatures& cpu_features, + int sample_rate_hz) + : VoiceActivityDetectorWrapper(vad_reset_period_ms, + std::make_unique<MonoVadImpl>(cpu_features), + sample_rate_hz) {} VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( int vad_reset_period_ms, - std::unique_ptr<MonoVad> vad) + std::unique_ptr<MonoVad> vad, + int sample_rate_hz) : vad_reset_period_frames_( rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)), - initialized_(false), - frame_size_(0), time_to_vad_reset_(vad_reset_period_frames_), vad_(std::move(vad)) { RTC_DCHECK(vad_); RTC_DCHECK_GT(vad_reset_period_frames_, 1); resampled_buffer_.resize( rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond)); + Initialize(sample_rate_hz); } VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default; @@ -85,11 +86,9 @@ constexpr int kStatusOk = 0; RTC_DCHECK_EQ(status, kStatusOk); vad_->Reset(); - initialized_ = true; } float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) { - RTC_DCHECK(initialized_); // Periodically reset the VAD. time_to_vad_reset_--; if (time_to_vad_reset_ <= 0) {
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h index 0579ca1..6df0ead 100644 --- a/modules/audio_processing/agc2/vad_wrapper.h +++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -43,20 +43,20 @@ // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call // `MonoVad::Reset()`; it must be equal to or greater than the duration of two // frames. Uses `cpu_features` to instantiate the default VAD. - // TODO(bugs.webrtc.org/7494): Pass sample rate. VoiceActivityDetectorWrapper(int vad_reset_period_ms, - const AvailableCpuFeatures& cpu_features); + const AvailableCpuFeatures& cpu_features, + int sample_rate_hz); // Ctor. Uses a custom `vad`. VoiceActivityDetectorWrapper(int vad_reset_period_ms, - std::unique_ptr<MonoVad> vad); + std::unique_ptr<MonoVad> vad, + int sample_rate_hz); VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete; VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) = delete; ~VoiceActivityDetectorWrapper(); - // TODO(bugs.webrtc.org/7494): Call initialize in the ctor. - // Initializes the VAD wrapper. Must be called before `Analyze()`. + // Initializes the VAD wrapper. void Initialize(int sample_rate_hz); // Analyzes the first channel of `frame` and returns the speech probability. @@ -66,8 +66,6 @@ private: const int vad_reset_period_frames_; - // TODO(bugs.webrtc.org/7494): Remove `initialized_`. - bool initialized_; int frame_size_; int time_to_vad_reset_; PushResampler<float> resampler_;
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc index 27e5af6..b61b015 100644 --- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc +++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -31,6 +31,8 @@ using ::testing::ReturnRoundRobin; using ::testing::Truly; +constexpr int kNumFramesPerSecond = 100; + constexpr int kNoVadPeriodicReset = kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs); @@ -52,8 +54,7 @@ .WillRepeatedly(Return(kSampleRate8kHz)); EXPECT_CALL(*vad, Reset).Times(AnyNumber()); auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>( - kNoVadPeriodicReset, std::move(vad)); - vad_wrapper->Initialize(kSampleRate8kHz); + kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz); } // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that @@ -61,27 +62,29 @@ // restarts from the beginning when after the last element is returned. std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper( int vad_reset_period_ms, + int sample_rate_hz, const std::vector<float>& speech_probabilities, int expected_vad_reset_calls) { auto vad = std::make_unique<MockVad>(); EXPECT_CALL(*vad, SampleRateHz) .Times(AnyNumber()) - .WillRepeatedly(Return(kSampleRate8kHz)); + .WillRepeatedly(Return(sample_rate_hz)); if (expected_vad_reset_calls >= 0) { EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls); } EXPECT_CALL(*vad, Analyze) .Times(AnyNumber()) .WillRepeatedly(ReturnRoundRobin(speech_probabilities)); - return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms, - std::move(vad)); + return std::make_unique<VoiceActivityDetectorWrapper>( + vad_reset_period_ms, std::move(vad), kSampleRate8kHz); } // 10 ms mono frame. struct FrameWithView { // Ctor. Initializes the frame samples with `value`. explicit FrameWithView(int sample_rate_hz) - : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f), + : samples(rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond), + 0.0f), channel0(samples.data()), view(&channel0, /*num_channels=*/1, samples.size()) {} std::vector<float> samples; @@ -94,10 +97,9 @@ const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f, 0.44f, 0.525f, 0.858f, 0.314f, 0.653f, 0.965f, 0.413f, 0.0f}; - auto vad_wrapper = - CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities, - /*expected_vad_reset_calls=*/1); - vad_wrapper->Initialize(kSampleRate8kHz); + auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz, + speech_probabilities, + /*expected_vad_reset_calls=*/1); FrameWithView frame(kSampleRate8kHz); for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) { SCOPED_TRACE(i); @@ -108,10 +110,9 @@ // Checks that the VAD is not periodically reset. TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) { constexpr int kNumFrames = 19; - auto vad_wrapper = - CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f}, - /*expected_vad_reset_calls=*/1); - vad_wrapper->Initialize(kSampleRate8kHz); + auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz, + /*speech_probabilities=*/{1.0f}, + /*expected_vad_reset_calls=*/1); FrameWithView frame(kSampleRate8kHz); for (int i = 0; i < kNumFrames; ++i) { vad_wrapper->Analyze(frame.view); @@ -129,10 +130,10 @@ TEST_P(VadPeriodResetParametrization, VadPeriodicReset) { auto vad_wrapper = CreateMockVadWrapper( /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs, + kSampleRate8kHz, /*speech_probabilities=*/{1.0f}, /*expected_vad_reset_calls=*/1 + num_frames() / vad_reset_period_frames()); - vad_wrapper->Initialize(kSampleRate8kHz); FrameWithView frame(kSampleRate8kHz); for (int i = 0; i < num_frames(); ++i) { vad_wrapper->Analyze(frame.view); @@ -161,13 +162,12 @@ .WillRepeatedly(Return(vad_sample_rate_hz())); EXPECT_CALL(*vad, Reset).Times(1); EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) { - return rtc::SafeEq(frame.size(), - rtc::CheckedDivExact(vad_sample_rate_hz(), 100)); + return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(), + kNumFramesPerSecond)); }))).Times(1); auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>( - kNoVadPeriodicReset, std::move(vad)); + kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz()); FrameWithView frame(input_sample_rate_hz()); - vad_wrapper->Initialize(input_sample_rate_hz()); vad_wrapper->Analyze(frame.view); }
diff --git a/modules/audio_processing/gain_controller2.cc b/modules/audio_processing/gain_controller2.cc index 416bc78..a21ef72 100644 --- a/modules/audio_processing/gain_controller2.cc +++ b/modules/audio_processing/gain_controller2.cc
@@ -14,6 +14,7 @@ #include <utility> #include "common_audio/include/audio_util.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/audio_buffer.h" #include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/logging/apm_data_dumper.h" @@ -21,6 +22,7 @@ #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/strings/string_builder.h" +#include "system_wrappers/include/field_trial.h" namespace webrtc { namespace { @@ -33,6 +35,21 @@ constexpr int kLogLimiterStatsPeriodNumFrames = kLogLimiterStatsPeriodMs / kFrameLengthMs; +// Detects the available CPU features and applies any kill-switches. +AvailableCpuFeatures GetAllowedCpuFeatures() { + AvailableCpuFeatures features = GetAvailableCpuFeatures(); + if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) { + features.sse2 = false; + } + if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) { + features.avx2 = false; + } + if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) { + features.neon = false; + } + return features; +} + // Creates an adaptive digital gain controller if enabled. std::unique_ptr<AdaptiveAgc> CreateAdaptiveDigitalController( const Agc2Config::AdaptiveDigital& config, @@ -40,7 +57,8 @@ int num_channels, ApmDataDumper* data_dumper) { if (config.enabled) { - // TODO(bugs.webrtc.org/7494): Also init with sample rate and num channels. + // TODO(bugs.webrtc.org/7494): Also init with sample rate and num + // channels. auto controller = std::make_unique<AdaptiveAgc>(data_dumper, config); // TODO(bugs.webrtc.org/7494): Remove once passed to the ctor. controller->Initialize(sample_rate_hz, num_channels); @@ -56,7 +74,8 @@ GainController2::GainController2(const Agc2Config& config, int sample_rate_hz, int num_channels) - : data_dumper_(rtc::AtomicOps::Increment(&instance_count_)), + : cpu_features_(GetAllowedCpuFeatures()), + data_dumper_(rtc::AtomicOps::Increment(&instance_count_)), fixed_gain_applier_(/*hard_clip_samples=*/false, /*initial_gain_factor=*/0.0f), adaptive_digital_controller_( @@ -71,6 +90,14 @@ data_dumper_.InitiateNewSetOfRecordings(); // TODO(bugs.webrtc.org/7494): Set gain when `fixed_gain_applier_` is init'd. fixed_gain_applier_.SetGainFactor(DbToRatio(config.fixed_digital.gain_db)); + const bool use_vad = config.adaptive_digital.enabled; + if (use_vad) { + // TODO(bugs.webrtc.org/7494): Move `vad_reset_period_ms` from adaptive + // digital to gain controller 2 config. + vad_ = std::make_unique<VoiceActivityDetectorWrapper>( + config.adaptive_digital.vad_reset_period_ms, cpu_features_, + sample_rate_hz); + } } GainController2::~GainController2() = default; @@ -82,6 +109,9 @@ sample_rate_hz == AudioProcessing::kSampleRate48kHz); // TODO(bugs.webrtc.org/7494): Initialize `fixed_gain_applier_`. limiter_.SetSampleRate(sample_rate_hz); + if (vad_) { + vad_->Initialize(sample_rate_hz); + } if (adaptive_digital_controller_) { adaptive_digital_controller_->Initialize(sample_rate_hz, num_channels); } @@ -104,10 +134,17 @@ data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_); AudioFrameView<float> float_frame(audio->channels(), audio->num_channels(), audio->num_frames()); + absl::optional<float> speech_probability; + // TODO(bugs.webrtc.org/7494): Apply fixed digital gain after VAD. fixed_gain_applier_.ApplyGain(float_frame); + if (vad_) { + speech_probability = vad_->Analyze(float_frame); + data_dumper_.DumpRaw("agc2_speech_probability", speech_probability.value()); + } if (adaptive_digital_controller_) { - adaptive_digital_controller_->Process(float_frame, - limiter_.LastAudioLevel()); + RTC_DCHECK(speech_probability.has_value()); + adaptive_digital_controller_->Process( + float_frame, speech_probability.value(), limiter_.LastAudioLevel()); } limiter_.Process(float_frame);
diff --git a/modules/audio_processing/gain_controller2.h b/modules/audio_processing/gain_controller2.h index cfe51c2..7e28ee2 100644 --- a/modules/audio_processing/gain_controller2.h +++ b/modules/audio_processing/gain_controller2.h
@@ -15,8 +15,10 @@ #include <string> #include "modules/audio_processing/agc2/adaptive_agc.h" +#include "modules/audio_processing/agc2/cpu_features.h" #include "modules/audio_processing/agc2/gain_applier.h" #include "modules/audio_processing/agc2/limiter.h" +#include "modules/audio_processing/agc2/vad_wrapper.h" #include "modules/audio_processing/include/audio_processing.h" #include "modules/audio_processing/logging/apm_data_dumper.h" @@ -51,8 +53,10 @@ private: static int instance_count_; + const AvailableCpuFeatures cpu_features_; ApmDataDumper data_dumper_; GainApplier fixed_gain_applier_; + std::unique_ptr<VoiceActivityDetectorWrapper> vad_; std::unique_ptr<AdaptiveAgc> adaptive_digital_controller_; Limiter limiter_; int calls_since_last_limiter_log_;