AGC2: VAD moved into `GainController2`
Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).
Bug: webrtc:7494
Change-Id: Id9849c4463791f5a203afe31efc163efb4d4458e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234583
Reviewed-by: Hanna Silen <silen@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35248}
diff --git a/modules/audio_processing/BUILD.gn b/modules/audio_processing/BUILD.gn
index 4223101..7fbae51 100644
--- a/modules/audio_processing/BUILD.gn
+++ b/modules/audio_processing/BUILD.gn
@@ -132,8 +132,10 @@
"../../rtc_base:stringutils",
"../../system_wrappers:field_trial",
"agc2:adaptive_digital",
+ "agc2:cpu_features",
"agc2:fixed_digital",
"agc2:gain_applier",
+ "agc2:vad_wrapper",
]
}
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index f767a6d..5ec3ce1 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -52,7 +52,6 @@
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_compare",
"../../../rtc_base:safe_minmax",
- "../../../system_wrappers:field_trial",
"../../../system_wrappers:metrics",
]
@@ -150,7 +149,11 @@
"vad_wrapper.cc",
"vad_wrapper.h",
]
- visibility = [ "./*" ]
+
+ visibility = [
+ "..:gain_controller2",
+ "./*",
+ ]
defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index b543365..295a33f 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -11,31 +11,14 @@
#include "modules/audio_processing/agc2/adaptive_agc.h"
#include "common_audio/include/audio_util.h"
-#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
-#include "system_wrappers/include/field_trial.h"
namespace webrtc {
namespace {
-// Detects the available CPU features and applies any kill-switches.
-AvailableCpuFeatures GetAllowedCpuFeatures() {
- AvailableCpuFeatures features = GetAvailableCpuFeatures();
- if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) {
- features.sse2 = false;
- }
- if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) {
- features.avx2 = false;
- }
- if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) {
- features.neon = false;
- }
- return features;
-}
-
// Peak and RMS audio levels in dBFS.
struct AudioLevels {
float peak_dbfs;
@@ -60,7 +43,6 @@
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config)
: speech_level_estimator_(apm_data_dumper, config),
- vad_(config.vad_reset_period_ms, GetAllowedCpuFeatures()),
gain_controller_(apm_data_dumper, config),
apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(CreateNoiseFloorEstimator(apm_data_dumper)),
@@ -77,18 +59,18 @@
void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
gain_controller_.Initialize(sample_rate_hz, num_channels);
- vad_.Initialize(sample_rate_hz);
}
-void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
+void AdaptiveAgc::Process(AudioFrameView<float> frame,
+ float speech_probability,
+ float limiter_envelope) {
AudioLevels levels = ComputeAudioLevels(frame);
+ apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
+ apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
AdaptiveDigitalGainApplier::FrameInfo info;
- info.speech_probability = vad_.Analyze(frame);
- apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
- apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
- apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
+ info.speech_probability = speech_probability;
speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
info.speech_probability);
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index 32de680..a9a6985 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -17,7 +17,6 @@
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
-#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
@@ -38,16 +37,17 @@
// TODO(crbug.com/webrtc/7494): Add `SetLimiterEnvelope()`.
// Analyzes `frame` and applies a digital adaptive gain to it. Takes into
- // account the envelope measured by the limiter.
+ // account the speech probability and the envelope measured by the limiter.
// TODO(crbug.com/webrtc/7494): Remove `limiter_envelope`.
- void Process(AudioFrameView<float> frame, float limiter_envelope);
+ void Process(AudioFrameView<float> frame,
+ float speech_probability,
+ float limiter_envelope);
// Handles a gain change applied to the input signal (e.g., analog gain).
void HandleInputGainChange();
private:
AdaptiveModeLevelEstimator speech_level_estimator_;
- VoiceActivityDetectorWrapper vad_;
AdaptiveDigitalGainApplier gain_controller_;
ApmDataDumper* const apm_data_dumper_;
std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc
index 17d9638..91448f8 100644
--- a/modules/audio_processing/agc2/vad_wrapper.cc
+++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -54,24 +54,25 @@
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
- const AvailableCpuFeatures& cpu_features)
- : VoiceActivityDetectorWrapper(
- vad_reset_period_ms,
- std::make_unique<MonoVadImpl>(cpu_features)) {}
+ const AvailableCpuFeatures& cpu_features,
+ int sample_rate_hz)
+ : VoiceActivityDetectorWrapper(vad_reset_period_ms,
+ std::make_unique<MonoVadImpl>(cpu_features),
+ sample_rate_hz) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
- std::unique_ptr<MonoVad> vad)
+ std::unique_ptr<MonoVad> vad,
+ int sample_rate_hz)
: vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
- initialized_(false),
- frame_size_(0),
time_to_vad_reset_(vad_reset_period_frames_),
vad_(std::move(vad)) {
RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1);
resampled_buffer_.resize(
rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
+ Initialize(sample_rate_hz);
}
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
@@ -85,11 +86,9 @@
constexpr int kStatusOk = 0;
RTC_DCHECK_EQ(status, kStatusOk);
vad_->Reset();
- initialized_ = true;
}
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
- RTC_DCHECK(initialized_);
// Periodically reset the VAD.
time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) {
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h
index 0579ca1..6df0ead 100644
--- a/modules/audio_processing/agc2/vad_wrapper.h
+++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -43,20 +43,20 @@
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two
// frames. Uses `cpu_features` to instantiate the default VAD.
- // TODO(bugs.webrtc.org/7494): Pass sample rate.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
- const AvailableCpuFeatures& cpu_features);
+ const AvailableCpuFeatures& cpu_features,
+ int sample_rate_hz);
// Ctor. Uses a custom `vad`.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
- std::unique_ptr<MonoVad> vad);
+ std::unique_ptr<MonoVad> vad,
+ int sample_rate_hz);
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
delete;
~VoiceActivityDetectorWrapper();
- // TODO(bugs.webrtc.org/7494): Call initialize in the ctor.
- // Initializes the VAD wrapper. Must be called before `Analyze()`.
+ // Initializes the VAD wrapper.
void Initialize(int sample_rate_hz);
// Analyzes the first channel of `frame` and returns the speech probability.
@@ -66,8 +66,6 @@
private:
const int vad_reset_period_frames_;
- // TODO(bugs.webrtc.org/7494): Remove `initialized_`.
- bool initialized_;
int frame_size_;
int time_to_vad_reset_;
PushResampler<float> resampler_;
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
index 27e5af6..b61b015 100644
--- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc
+++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -31,6 +31,8 @@
using ::testing::ReturnRoundRobin;
using ::testing::Truly;
+constexpr int kNumFramesPerSecond = 100;
+
constexpr int kNoVadPeriodicReset =
kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
@@ -52,8 +54,7 @@
.WillRepeatedly(Return(kSampleRate8kHz));
EXPECT_CALL(*vad, Reset).Times(AnyNumber());
auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
- kNoVadPeriodicReset, std::move(vad));
- vad_wrapper->Initialize(kSampleRate8kHz);
+ kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz);
}
// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
@@ -61,27 +62,29 @@
// restarts from the beginning when after the last element is returned.
std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
int vad_reset_period_ms,
+ int sample_rate_hz,
const std::vector<float>& speech_probabilities,
int expected_vad_reset_calls) {
auto vad = std::make_unique<MockVad>();
EXPECT_CALL(*vad, SampleRateHz)
.Times(AnyNumber())
- .WillRepeatedly(Return(kSampleRate8kHz));
+ .WillRepeatedly(Return(sample_rate_hz));
if (expected_vad_reset_calls >= 0) {
EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
}
EXPECT_CALL(*vad, Analyze)
.Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
- return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
- std::move(vad));
+ return std::make_unique<VoiceActivityDetectorWrapper>(
+ vad_reset_period_ms, std::move(vad), kSampleRate8kHz);
}
// 10 ms mono frame.
struct FrameWithView {
// Ctor. Initializes the frame samples with `value`.
explicit FrameWithView(int sample_rate_hz)
- : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
+ : samples(rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond),
+ 0.0f),
channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {}
std::vector<float> samples;
@@ -94,10 +97,9 @@
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
0.44f, 0.525f, 0.858f, 0.314f,
0.653f, 0.965f, 0.413f, 0.0f};
- auto vad_wrapper =
- CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities,
- /*expected_vad_reset_calls=*/1);
- vad_wrapper->Initialize(kSampleRate8kHz);
+ auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
+ speech_probabilities,
+ /*expected_vad_reset_calls=*/1);
FrameWithView frame(kSampleRate8kHz);
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i);
@@ -108,10 +110,9 @@
// Checks that the VAD is not periodically reset.
TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
constexpr int kNumFrames = 19;
- auto vad_wrapper =
- CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
- /*expected_vad_reset_calls=*/1);
- vad_wrapper->Initialize(kSampleRate8kHz);
+ auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
+ /*speech_probabilities=*/{1.0f},
+ /*expected_vad_reset_calls=*/1);
FrameWithView frame(kSampleRate8kHz);
for (int i = 0; i < kNumFrames; ++i) {
vad_wrapper->Analyze(frame.view);
@@ -129,10 +130,10 @@
TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
auto vad_wrapper = CreateMockVadWrapper(
/*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
+ kSampleRate8kHz,
/*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/1 +
num_frames() / vad_reset_period_frames());
- vad_wrapper->Initialize(kSampleRate8kHz);
FrameWithView frame(kSampleRate8kHz);
for (int i = 0; i < num_frames(); ++i) {
vad_wrapper->Analyze(frame.view);
@@ -161,13 +162,12 @@
.WillRepeatedly(Return(vad_sample_rate_hz()));
EXPECT_CALL(*vad, Reset).Times(1);
EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
- return rtc::SafeEq(frame.size(),
- rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
+ return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(),
+ kNumFramesPerSecond));
}))).Times(1);
auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
- kNoVadPeriodicReset, std::move(vad));
+ kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz());
FrameWithView frame(input_sample_rate_hz());
- vad_wrapper->Initialize(input_sample_rate_hz());
vad_wrapper->Analyze(frame.view);
}
diff --git a/modules/audio_processing/gain_controller2.cc b/modules/audio_processing/gain_controller2.cc
index 416bc78..a21ef72 100644
--- a/modules/audio_processing/gain_controller2.cc
+++ b/modules/audio_processing/gain_controller2.cc
@@ -14,6 +14,7 @@
#include <utility>
#include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/audio_buffer.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
@@ -21,6 +22,7 @@
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/strings/string_builder.h"
+#include "system_wrappers/include/field_trial.h"
namespace webrtc {
namespace {
@@ -33,6 +35,21 @@
constexpr int kLogLimiterStatsPeriodNumFrames =
kLogLimiterStatsPeriodMs / kFrameLengthMs;
+// Detects the available CPU features and applies any kill-switches.
+AvailableCpuFeatures GetAllowedCpuFeatures() {
+ AvailableCpuFeatures features = GetAvailableCpuFeatures();
+ if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) {
+ features.sse2 = false;
+ }
+ if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) {
+ features.avx2 = false;
+ }
+ if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) {
+ features.neon = false;
+ }
+ return features;
+}
+
// Creates an adaptive digital gain controller if enabled.
std::unique_ptr<AdaptiveAgc> CreateAdaptiveDigitalController(
const Agc2Config::AdaptiveDigital& config,
@@ -40,7 +57,8 @@
int num_channels,
ApmDataDumper* data_dumper) {
if (config.enabled) {
- // TODO(bugs.webrtc.org/7494): Also init with sample rate and num channels.
+ // TODO(bugs.webrtc.org/7494): Also init with sample rate and num
+ // channels.
auto controller = std::make_unique<AdaptiveAgc>(data_dumper, config);
// TODO(bugs.webrtc.org/7494): Remove once passed to the ctor.
controller->Initialize(sample_rate_hz, num_channels);
@@ -56,7 +74,8 @@
GainController2::GainController2(const Agc2Config& config,
int sample_rate_hz,
int num_channels)
- : data_dumper_(rtc::AtomicOps::Increment(&instance_count_)),
+ : cpu_features_(GetAllowedCpuFeatures()),
+ data_dumper_(rtc::AtomicOps::Increment(&instance_count_)),
fixed_gain_applier_(/*hard_clip_samples=*/false,
/*initial_gain_factor=*/0.0f),
adaptive_digital_controller_(
@@ -71,6 +90,14 @@
data_dumper_.InitiateNewSetOfRecordings();
// TODO(bugs.webrtc.org/7494): Set gain when `fixed_gain_applier_` is init'd.
fixed_gain_applier_.SetGainFactor(DbToRatio(config.fixed_digital.gain_db));
+ const bool use_vad = config.adaptive_digital.enabled;
+ if (use_vad) {
+ // TODO(bugs.webrtc.org/7494): Move `vad_reset_period_ms` from adaptive
+ // digital to gain controller 2 config.
+ vad_ = std::make_unique<VoiceActivityDetectorWrapper>(
+ config.adaptive_digital.vad_reset_period_ms, cpu_features_,
+ sample_rate_hz);
+ }
}
GainController2::~GainController2() = default;
@@ -82,6 +109,9 @@
sample_rate_hz == AudioProcessing::kSampleRate48kHz);
// TODO(bugs.webrtc.org/7494): Initialize `fixed_gain_applier_`.
limiter_.SetSampleRate(sample_rate_hz);
+ if (vad_) {
+ vad_->Initialize(sample_rate_hz);
+ }
if (adaptive_digital_controller_) {
adaptive_digital_controller_->Initialize(sample_rate_hz, num_channels);
}
@@ -104,10 +134,17 @@
data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_);
AudioFrameView<float> float_frame(audio->channels(), audio->num_channels(),
audio->num_frames());
+ absl::optional<float> speech_probability;
+ // TODO(bugs.webrtc.org/7494): Apply fixed digital gain after VAD.
fixed_gain_applier_.ApplyGain(float_frame);
+ if (vad_) {
+ speech_probability = vad_->Analyze(float_frame);
+ data_dumper_.DumpRaw("agc2_speech_probability", speech_probability.value());
+ }
if (adaptive_digital_controller_) {
- adaptive_digital_controller_->Process(float_frame,
- limiter_.LastAudioLevel());
+ RTC_DCHECK(speech_probability.has_value());
+ adaptive_digital_controller_->Process(
+ float_frame, speech_probability.value(), limiter_.LastAudioLevel());
}
limiter_.Process(float_frame);
diff --git a/modules/audio_processing/gain_controller2.h b/modules/audio_processing/gain_controller2.h
index cfe51c2..7e28ee2 100644
--- a/modules/audio_processing/gain_controller2.h
+++ b/modules/audio_processing/gain_controller2.h
@@ -15,8 +15,10 @@
#include <string>
#include "modules/audio_processing/agc2/adaptive_agc.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/agc2/limiter.h"
+#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/include/audio_processing.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
@@ -51,8 +53,10 @@
private:
static int instance_count_;
+ const AvailableCpuFeatures cpu_features_;
ApmDataDumper data_dumper_;
GainApplier fixed_gain_applier_;
+ std::unique_ptr<VoiceActivityDetectorWrapper> vad_;
std::unique_ptr<AdaptiveAgc> adaptive_digital_controller_;
Limiter limiter_;
int calls_since_last_limiter_log_;