AGC2 periodically reset VAD state
Bug: webrtc:7494
Change-Id: I880ef3991ade4e429ccde843571f069ede149c0e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/213342
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Jesus de Vicente Pena <devicentepena@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33604}
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index e72942a..9657aec 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -30,8 +30,8 @@
}
constexpr int kGainApplierAdjacentSpeechFramesThreshold = 1;
-constexpr float kMaxGainChangePerSecondDb = 3.f;
-constexpr float kMaxOutputNoiseLevelDbfs = -50.f;
+constexpr float kMaxGainChangePerSecondDb = 3.0f;
+constexpr float kMaxOutputNoiseLevelDbfs = -50.0f;
// Detects the available CPU features and applies any kill-switches.
AvailableCpuFeatures GetAllowedCpuFeatures(
@@ -71,7 +71,8 @@
.level_estimator_adjacent_speech_frames_threshold,
config.adaptive_digital.initial_saturation_margin_db,
config.adaptive_digital.extra_saturation_margin_db),
- vad_(config.adaptive_digital.vad_probability_attack,
+ vad_(config.adaptive_digital.vad_reset_period_ms,
+ config.adaptive_digital.vad_probability_attack,
GetAllowedCpuFeatures(config.adaptive_digital)),
gain_applier_(
apm_data_dumper,
@@ -95,7 +96,7 @@
info.input_level_dbfs = speech_level_estimator_.level_dbfs();
info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
info.limiter_envelope_dbfs =
- limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.f;
+ limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.0f;
info.estimate_is_confident = speech_level_estimator_.IsConfident();
DumpDebugData(info, *apm_data_dumper_);
gain_applier_.Process(info, frame);
diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h
index 594a37e..d0df43f 100644
--- a/modules/audio_processing/agc2/agc2_common.h
+++ b/modules/audio_processing/agc2/agc2_common.h
@@ -15,20 +15,20 @@
namespace webrtc {
-constexpr float kMinFloatS16Value = -32768.f;
-constexpr float kMaxFloatS16Value = 32767.f;
+constexpr float kMinFloatS16Value = -32768.0f;
+constexpr float kMaxFloatS16Value = 32767.0f;
constexpr float kMaxAbsFloatS16Value = 32768.0f;
constexpr int kFrameDurationMs = 10;
constexpr int kSubFramesInFrame = 20;
constexpr int kMaximalNumberOfSamplesPerChannel = 480;
-constexpr float kAttackFilterConstant = 0.f;
+constexpr float kAttackFilterConstant = 0.0f;
// Adaptive digital gain applier settings below.
-constexpr float kHeadroomDbfs = 1.f;
-constexpr float kMaxGainDb = 30.f;
-constexpr float kInitialAdaptiveDigitalGainDb = 8.f;
+constexpr float kHeadroomDbfs = 1.0f;
+constexpr float kMaxGainDb = 30.0f;
+constexpr float kInitialAdaptiveDigitalGainDb = 8.0f;
// At what limiter levels should we start decreasing the adaptive digital gain.
constexpr float kLimiterThresholdForAgcGainDbfs = -kHeadroomDbfs;
@@ -39,17 +39,18 @@
// The amount of 'memory' of the Level Estimator. Decides leak factors.
constexpr int kFullBufferSizeMs = 1200;
-constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
+constexpr float kFullBufferLeakFactor = 1.0f - 1.0f / kFullBufferSizeMs;
-constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
+constexpr float kInitialSpeechLevelEstimateDbfs = -30.0f;
// Robust VAD probability and speech decisions.
-constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
+constexpr int kDefaultVadRnnResetPeriodMs = 1500;
+constexpr float kDefaultSmoothedVadProbabilityAttack = 1.0f;
constexpr int kDefaultLevelEstimatorAdjacentSpeechFramesThreshold = 1;
// Saturation Protector settings.
-constexpr float kDefaultInitialSaturationMarginDb = 20.f;
-constexpr float kDefaultExtraSaturationMarginDb = 2.f;
+constexpr float kDefaultInitialSaturationMarginDb = 20.0f;
+constexpr float kDefaultExtraSaturationMarginDb = 2.0f;
constexpr int kPeakEnveloperSuperFrameLengthMs = 400;
static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc
index b54ae56..597c09c 100644
--- a/modules/audio_processing/agc2/vad_with_level.cc
+++ b/modules/audio_processing/agc2/vad_with_level.cc
@@ -38,6 +38,8 @@
Vad& operator=(const Vad&) = delete;
~Vad() = default;
+ void Reset() override { rnn_vad_.Reset(); }
+
float ComputeProbability(AudioFrameView<const float> frame) override {
// The source number of channels is 1, because we always use the 1st
// channel.
@@ -66,41 +68,57 @@
// Returns an updated version of `p_old` by using instant decay and the given
// `attack` on a new VAD probability value `p_new`.
float SmoothedVadProbability(float p_old, float p_new, float attack) {
- RTC_DCHECK_GT(attack, 0.f);
- RTC_DCHECK_LE(attack, 1.f);
- if (p_new < p_old || attack == 1.f) {
+ RTC_DCHECK_GT(attack, 0.0f);
+ RTC_DCHECK_LE(attack, 1.0f);
+ if (p_new < p_old || attack == 1.0f) {
// Instant decay (or no smoothing).
return p_new;
} else {
// Attack phase.
- return attack * p_new + (1.f - attack) * p_old;
+ return attack * p_new + (1.0f - attack) * p_old;
}
}
} // namespace
VadLevelAnalyzer::VadLevelAnalyzer()
- : VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
+ : VadLevelAnalyzer(kDefaultVadRnnResetPeriodMs,
+ kDefaultSmoothedVadProbabilityAttack,
GetAvailableCpuFeatures()) {}
-VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
+VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
+ float vad_probability_attack,
const AvailableCpuFeatures& cpu_features)
- : VadLevelAnalyzer(vad_probability_attack,
+ : VadLevelAnalyzer(vad_reset_period_ms,
+ vad_probability_attack,
std::make_unique<Vad>(cpu_features)) {}
-VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
+VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
+ float vad_probability_attack,
std::unique_ptr<VoiceActivityDetector> vad)
- : vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) {
+ : vad_(std::move(vad)),
+ vad_reset_period_frames_(
+ rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
+ vad_probability_attack_(vad_probability_attack),
+ time_to_vad_reset_(vad_reset_period_frames_),
+ vad_probability_(0.0f) {
RTC_DCHECK(vad_);
+ RTC_DCHECK_GT(vad_reset_period_frames_, 1);
}
VadLevelAnalyzer::~VadLevelAnalyzer() = default;
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
AudioFrameView<const float> frame) {
+ // Periodically reset the VAD.
+ time_to_vad_reset_--;
+ if (time_to_vad_reset_ <= 0) {
+ vad_->Reset();
+ time_to_vad_reset_ = vad_reset_period_frames_;
+ }
// Compute levels.
- float peak = 0.f;
- float rms = 0.f;
+ float peak = 0.0f;
+ float rms = 0.0f;
for (const auto& x : frame.channel(0)) {
peak = std::max(std::fabs(x), peak);
rms += x * x;
diff --git a/modules/audio_processing/agc2/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h
index 2a67882..386f162 100644
--- a/modules/audio_processing/agc2/vad_with_level.h
+++ b/modules/audio_processing/agc2/vad_with_level.h
@@ -31,17 +31,26 @@
class VoiceActivityDetector {
public:
virtual ~VoiceActivityDetector() = default;
+ // Resets the internal state.
+ virtual void Reset() = 0;
// Analyzes an audio frame and returns the speech probability.
virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
};
// Ctor. Uses the default VAD.
VadLevelAnalyzer();
- VadLevelAnalyzer(float vad_probability_attack,
+ // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
+ // `VoiceActivityDetector::Reset()`; it must be equal to or greater than the
+ // duration of two frames. `vad_probability_attack` is a number in (0,1] used
+ // to smooth the speech probability (instant decay, slow attack).
+ VadLevelAnalyzer(int vad_reset_period_ms,
+ float vad_probability_attack,
const AvailableCpuFeatures& cpu_features);
// Ctor. Uses a custom `vad`.
- VadLevelAnalyzer(float vad_probability_attack,
+ VadLevelAnalyzer(int vad_reset_period_ms,
+ float vad_probability_attack,
std::unique_ptr<VoiceActivityDetector> vad);
+
VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
~VadLevelAnalyzer();
@@ -51,8 +60,10 @@
private:
std::unique_ptr<VoiceActivityDetector> vad_;
+ const int vad_reset_period_frames_;
const float vad_probability_attack_;
- float vad_probability_ = 0.f;
+ int time_to_vad_reset_;
+ float vad_probability_;
};
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_with_level_unittest.cc b/modules/audio_processing/agc2/vad_with_level_unittest.cc
index fb93c86..fd8265e 100644
--- a/modules/audio_processing/agc2/vad_with_level_unittest.cc
+++ b/modules/audio_processing/agc2/vad_with_level_unittest.cc
@@ -10,6 +10,7 @@
#include "modules/audio_processing/agc2/vad_with_level.h"
+#include <limits>
#include <memory>
#include <vector>
@@ -25,13 +26,17 @@
using ::testing::AnyNumber;
using ::testing::ReturnRoundRobin;
-constexpr float kInstantAttack = 1.f;
+constexpr int kNoVadPeriodicReset =
+ kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
+
+constexpr float kInstantAttack = 1.0f;
constexpr float kSlowAttack = 0.1f;
constexpr int kSampleRateHz = 8000;
class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
public:
+ MOCK_METHOD(void, Reset, (), (override));
MOCK_METHOD(float,
ComputeProbability,
(AudioFrameView<const float> frame),
@@ -42,20 +47,25 @@
// the next value from `speech_probabilities` until it reaches the end and will
// restart from the beginning.
std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
+ int vad_reset_period_ms,
float vad_probability_attack,
- const std::vector<float>& speech_probabilities) {
+ const std::vector<float>& speech_probabilities,
+ int expected_vad_reset_calls = 0) {
auto vad = std::make_unique<MockVad>();
EXPECT_CALL(*vad, ComputeProbability)
.Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
- return std::make_unique<VadLevelAnalyzer>(vad_probability_attack,
- std::move(vad));
+ if (expected_vad_reset_calls >= 0) {
+ EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
+ }
+ return std::make_unique<VadLevelAnalyzer>(
+ vad_reset_period_ms, vad_probability_attack, std::move(vad));
}
// 10 ms mono frame.
struct FrameWithView {
// Ctor. Initializes the frame samples with `value`.
- FrameWithView(float value = 0.f)
+ FrameWithView(float value = 0.0f)
: channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {
samples.fill(value);
@@ -67,8 +77,8 @@
TEST(AutomaticGainController2VadLevelAnalyzer, PeakLevelGreaterThanRmsLevel) {
// Handcrafted frame so that the average is lower than the peak value.
- FrameWithView frame(1000.f); // Constant frame.
- frame.samples[10] = 2000.f; // Except for one peak value.
+ FrameWithView frame(1000.0f); // Constant frame.
+ frame.samples[10] = 2000.0f; // Except for one peak value.
// Compute audio frame levels (the VAD result is ignored).
VadLevelAnalyzer analyzer;
@@ -83,9 +93,9 @@
TEST(AutomaticGainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
0.44f, 0.525f, 0.858f, 0.314f,
- 0.653f, 0.965f, 0.413f, 0.f};
- auto analyzer =
- CreateVadLevelAnalyzerWithMockVad(kInstantAttack, speech_probabilities);
+ 0.653f, 0.965f, 0.413f, 0.0f};
+ auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+ kNoVadPeriodicReset, kInstantAttack, speech_probabilities);
FrameWithView frame;
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i);
@@ -98,16 +108,17 @@
// the unprocessed one when slow attack is used.
TEST(AutomaticGainController2VadLevelAnalyzer,
SlowAttackSpeechProbabilitySmoothing) {
- const std::vector<float> speech_probabilities{0.f, 0.f, 1.f, 1.f, 1.f, 1.f};
- auto analyzer =
- CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
+ const std::vector<float> speech_probabilities{0.0f, 0.0f, 1.0f,
+ 1.0f, 1.0f, 1.0f};
+ auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+ kNoVadPeriodicReset, kSlowAttack, speech_probabilities);
FrameWithView frame;
- float prev_probability = 0.f;
+ float prev_probability = 0.0f;
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i);
const float smoothed_probability =
analyzer->AnalyzeFrame(frame.view).speech_probability;
- EXPECT_LT(smoothed_probability, 1.f); // Not enough time to reach 1.
+ EXPECT_LT(smoothed_probability, 1.0f); // Not enough time to reach 1.
EXPECT_LE(prev_probability, smoothed_probability); // Converge towards 1.
prev_probability = smoothed_probability;
}
@@ -116,15 +127,52 @@
// Checks that the smoothed speech probability instantly decays to the
// unprocessed one when slow attack is used.
TEST(AutomaticGainController2VadLevelAnalyzer, SpeechProbabilityInstantDecay) {
- const std::vector<float> speech_probabilities{1.f, 1.f, 1.f, 1.f, 1.f, 0.f};
- auto analyzer =
- CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
+ const std::vector<float> speech_probabilities{1.0f, 1.0f, 1.0f,
+ 1.0f, 1.0f, 0.0f};
+ auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+ kNoVadPeriodicReset, kSlowAttack, speech_probabilities);
FrameWithView frame;
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size() - 1); ++i) {
analyzer->AnalyzeFrame(frame.view);
}
- EXPECT_EQ(0.f, analyzer->AnalyzeFrame(frame.view).speech_probability);
+ EXPECT_EQ(0.0f, analyzer->AnalyzeFrame(frame.view).speech_probability);
}
+// Checks that the VAD is not periodically reset.
+TEST(AutomaticGainController2VadLevelAnalyzer, VadNoPeriodicReset) {
+ constexpr int kNumFrames = 19;
+ auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+ kNoVadPeriodicReset, kSlowAttack, /*speech_probabilities=*/{1.0f},
+ /*expected_vad_reset_calls=*/0);
+ FrameWithView frame;
+ for (int i = 0; i < kNumFrames; ++i) {
+ analyzer->AnalyzeFrame(frame.view);
+ }
+}
+
+class VadPeriodResetParametrization
+ : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+ int num_frames() const { return std::get<0>(GetParam()); }
+ int vad_reset_period_frames() const { return std::get<1>(GetParam()); }
+};
+
+// Checks that the VAD is periodically reset with the expected period.
+TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
+ auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+ /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
+ kSlowAttack, /*speech_probabilities=*/{1.0f},
+ /*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
+ FrameWithView frame;
+ for (int i = 0; i < num_frames(); ++i) {
+ analyzer->AnalyzeFrame(frame.view);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(AutomaticGainController2VadLevelAnalyzer,
+ VadPeriodResetParametrization,
+ ::testing::Combine(::testing::Values(1, 19, 123),
+ ::testing::Values(2, 5, 20, 53)));
+
} // namespace
} // namespace webrtc
diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h
index bb24a48..a5c266a 100644
--- a/modules/audio_processing/include/audio_processing.h
+++ b/modules/audio_processing/include/audio_processing.h
@@ -210,7 +210,7 @@
// capture_level_adjustment instead.
struct PreAmplifier {
bool enabled = false;
- float fixed_gain_factor = 1.f;
+ float fixed_gain_factor = 1.0f;
} pre_amplifier;
// Functionality for general level adjustment in the capture pipeline. This
@@ -222,9 +222,9 @@
}
bool enabled = false;
// The `pre_gain_factor` scales the signal before any processing is done.
- float pre_gain_factor = 1.f;
+ float pre_gain_factor = 1.0f;
// The `post_gain_factor` scales the signal after all processing is done.
- float post_gain_factor = 1.f;
+ float post_gain_factor = 1.0f;
struct AnalogMicGainEmulation {
bool operator==(const AnalogMicGainEmulation& rhs) const;
bool operator!=(const AnalogMicGainEmulation& rhs) const {
@@ -352,20 +352,21 @@
enum LevelEstimator { kRms, kPeak };
bool enabled = false;
struct FixedDigital {
- float gain_db = 0.f;
+ float gain_db = 0.0f;
} fixed_digital;
struct AdaptiveDigital {
bool enabled = false;
+ int vad_reset_period_ms = 1500;
float vad_probability_attack = 0.3f;
LevelEstimator level_estimator = kRms;
int level_estimator_adjacent_speech_frames_threshold = 6;
// TODO(crbug.com/webrtc/7494): Remove `use_saturation_protector`.
bool use_saturation_protector = true;
- float initial_saturation_margin_db = 20.f;
- float extra_saturation_margin_db = 5.f;
+ float initial_saturation_margin_db = 20.0f;
+ float extra_saturation_margin_db = 5.0f;
int gain_applier_adjacent_speech_frames_threshold = 6;
- float max_gain_change_db_per_second = 3.f;
- float max_output_noise_level_dbfs = -55.f;
+ float max_gain_change_db_per_second = 3.0f;
+ float max_output_noise_level_dbfs = -55.0f;
bool sse2_allowed = true;
bool avx2_allowed = true;
bool neon_allowed = true;
@@ -417,7 +418,7 @@
int max_volume; // Maximum play-out volume.
};
- RuntimeSetting() : type_(Type::kNotSpecified), value_(0.f) {}
+ RuntimeSetting() : type_(Type::kNotSpecified), value_(0.0f) {}
~RuntimeSetting() = default;
static RuntimeSetting CreateCapturePreGain(float gain) {
@@ -439,8 +440,8 @@
// Corresponds to Config::GainController2::fixed_digital::gain_db, but for
// runtime configuration.
static RuntimeSetting CreateCaptureFixedPostGain(float gain_db) {
- RTC_DCHECK_GE(gain_db, 0.f);
- RTC_DCHECK_LE(gain_db, 90.f);
+ RTC_DCHECK_GE(gain_db, 0.0f);
+ RTC_DCHECK_LE(gain_db, 90.0f);
return {Type::kCaptureFixedPostGain, gain_db};
}