AGC2: VadWithLevel -> VoiceActivityDetectorWrapper 2/2
Internal refactoring of AGC2 to decouple the VAD, its wrapper and the
peak and RMS level measurements.
Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).
Bug: webrtc:7494
Change-Id: Ib560f1fcaa601557f4f30e47025c69e91b1b62e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234524
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35208}
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index a897c0b1..f767a6d 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -280,6 +280,7 @@
":common",
":vad_wrapper",
"..:audio_frame_view",
+ "../../../rtc_base:checks",
"../../../rtc_base:gunit_helpers",
"../../../rtc_base:safe_compare",
"../../../test:test_support",
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index ab1822d..fb06549 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -36,6 +36,24 @@
return features;
}
+// Peak and RMS audio levels in dBFS.
+struct AudioLevels {
+ float peak_dbfs;
+ float rms_dbfs;
+};
+
+// Computes the audio levels for the first channel in `frame`.
+AudioLevels ComputeAudioLevels(AudioFrameView<float> frame) {
+ float peak = 0.0f;
+ float rms = 0.0f;
+ for (const auto& x : frame.channel(0)) {
+ peak = std::max(std::fabs(x), peak);
+ rms += x * x;
+ }
+ return {FloatS16ToDbfs(peak),
+ FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel()))};
+}
+
} // namespace
AdaptiveAgc::AdaptiveAgc(
@@ -62,16 +80,17 @@
}
void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) {
+ AudioLevels levels = ComputeAudioLevels(frame);
+
AdaptiveDigitalGainApplier::FrameInfo info;
- VadLevelAnalyzer::Result vad_result = vad_.AnalyzeFrame(frame);
- info.speech_probability = vad_result.speech_probability;
- apm_data_dumper_->DumpRaw("agc2_speech_probability",
- vad_result.speech_probability);
- apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", vad_result.rms_dbfs);
- apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", vad_result.peak_dbfs);
+ info.speech_probability = vad_.Analyze(frame);
+ apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
+ apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
+ apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
- speech_level_estimator_.Update(vad_result);
+ speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
+ info.speech_probability);
info.speech_level_dbfs = speech_level_estimator_.level_dbfs();
info.speech_level_reliable = speech_level_estimator_.IsConfident();
apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", info.speech_level_dbfs);
@@ -81,7 +100,7 @@
info.noise_rms_dbfs = noise_level_estimator_->Analyze(frame);
apm_data_dumper_->DumpRaw("agc2_noise_rms_dbfs", info.noise_rms_dbfs);
- saturation_protector_->Analyze(info.speech_probability, vad_result.peak_dbfs,
+ saturation_protector_->Analyze(info.speech_probability, levels.peak_dbfs,
info.speech_level_dbfs);
info.headroom_db = saturation_protector_->HeadroomDb();
apm_data_dumper_->DumpRaw("agc2_headroom_db", info.headroom_db);
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index 8ee8378..32de680 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -47,7 +47,7 @@
private:
AdaptiveModeLevelEstimator speech_level_estimator_;
- VadLevelAnalyzer vad_;
+ VoiceActivityDetectorWrapper vad_;
AdaptiveDigitalGainApplier gain_controller_;
ApmDataDumper* const apm_data_dumper_;
std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
index 81e7d29..fe021fe 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc
@@ -57,15 +57,16 @@
Reset();
}
-void AdaptiveModeLevelEstimator::Update(
- const VadLevelAnalyzer::Result& vad_level) {
- RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
- RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
- RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
- RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
- RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
- RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
- if (vad_level.speech_probability < kVadConfidenceThreshold) {
+void AdaptiveModeLevelEstimator::Update(float rms_dbfs,
+ float peak_dbfs,
+ float speech_probability) {
+ RTC_DCHECK_GT(rms_dbfs, -150.0f);
+ RTC_DCHECK_LT(rms_dbfs, 50.0f);
+ RTC_DCHECK_GT(peak_dbfs, -150.0f);
+ RTC_DCHECK_LT(peak_dbfs, 50.0f);
+ RTC_DCHECK_GE(speech_probability, 0.0f);
+ RTC_DCHECK_LE(speech_probability, 1.0f);
+ if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
@@ -93,14 +94,14 @@
preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
- RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
- const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.f;
+ RTC_DCHECK_GT(speech_probability, 0.0f);
+ const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
- vad_level.rms_dbfs * vad_level.speech_probability;
+ rms_dbfs * speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
- vad_level.speech_probability;
+ speech_probability;
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
index 14da6b7..989c8c3 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
@@ -33,7 +33,7 @@
delete;
// Updates the level estimation.
- void Update(const VadLevelAnalyzer::Result& vad_data);
+ void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
// Returns the estimated speech plus noise level.
float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident on its current estimate.
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
index 1cdd91d..684fca1 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc
@@ -33,10 +33,12 @@
// Provides the `vad_level` value `num_iterations` times to `level_estimator`.
void RunOnConstantLevel(int num_iterations,
- const VadLevelAnalyzer::Result& vad_level,
+ float rms_dbfs,
+ float peak_dbfs,
+ float speech_probability,
AdaptiveModeLevelEstimator& level_estimator) {
for (int i = 0; i < num_iterations; ++i) {
- level_estimator.Update(vad_level);
+ level_estimator.Update(rms_dbfs, peak_dbfs, speech_probability);
}
}
@@ -47,6 +49,10 @@
return config;
}
+constexpr float kNoSpeechProbability = 0.0f;
+constexpr float kLowSpeechProbability = kVadConfidenceThreshold / 2.0f;
+constexpr float kMaxSpeechProbability = 1.0f;
+
// Level estimator with data dumper.
struct TestLevelEstimator {
explicit TestLevelEstimator(int adjacent_speech_frames_threshold)
@@ -55,36 +61,31 @@
&data_dumper,
GetAdaptiveDigitalConfig(adjacent_speech_frames_threshold))),
initial_speech_level_dbfs(estimator->level_dbfs()),
- vad_level_rms(initial_speech_level_dbfs / 2.0f),
- vad_level_peak(initial_speech_level_dbfs / 3.0f),
- vad_data_speech(
- {/*speech_probability=*/1.0f, vad_level_rms, vad_level_peak}),
- vad_data_non_speech(
- {/*speech_probability=*/kVadConfidenceThreshold / 2.0f,
- vad_level_rms, vad_level_peak}) {
- RTC_DCHECK_LT(vad_level_rms, vad_level_peak);
- RTC_DCHECK_LT(initial_speech_level_dbfs, vad_level_rms);
- RTC_DCHECK_GT(vad_level_rms - initial_speech_level_dbfs, 5.0f)
- << "Adjust `vad_level_rms` so that the difference from the initial "
+ level_rms_dbfs(initial_speech_level_dbfs / 2.0f),
+ level_peak_dbfs(initial_speech_level_dbfs / 3.0f) {
+ RTC_DCHECK_LT(level_rms_dbfs, level_peak_dbfs);
+ RTC_DCHECK_LT(initial_speech_level_dbfs, level_rms_dbfs);
+ RTC_DCHECK_GT(level_rms_dbfs - initial_speech_level_dbfs, 5.0f)
+ << "Adjust `level_rms_dbfs` so that the difference from the initial "
"level is wide enough for the tests";
}
ApmDataDumper data_dumper;
std::unique_ptr<AdaptiveModeLevelEstimator> estimator;
const float initial_speech_level_dbfs;
- const float vad_level_rms;
- const float vad_level_peak;
- const VadLevelAnalyzer::Result vad_data_speech;
- const VadLevelAnalyzer::Result vad_data_non_speech;
+ const float level_rms_dbfs;
+ const float level_peak_dbfs;
};
// Checks that the level estimator converges to a constant input speech level.
TEST(GainController2AdaptiveModeLevelEstimator, LevelStabilizes) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
- level_estimator.vad_data_speech,
+ level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
- RunOnConstantLevel(/*num_iterations=*/1, level_estimator.vad_data_speech,
+ RunOnConstantLevel(/*num_iterations=*/1, level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(), estimated_level_dbfs,
0.1f);
@@ -95,7 +96,8 @@
TEST(GainController2AdaptiveModeLevelEstimator, IsNotConfident) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence / 2,
- level_estimator.vad_data_speech,
+ level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_FALSE(level_estimator.estimator->IsConfident());
}
@@ -105,7 +107,8 @@
TEST(GainController2AdaptiveModeLevelEstimator, IsConfident) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
- level_estimator.vad_data_speech,
+ level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_TRUE(level_estimator.estimator->IsConfident());
}
@@ -117,15 +120,14 @@
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
// Simulate speech.
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
- level_estimator.vad_data_speech,
+ level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
const float estimated_level_dbfs = level_estimator.estimator->level_dbfs();
// Simulate full-scale non-speech.
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
- VadLevelAnalyzer::Result{/*speech_probability=*/0.0f,
- /*rms_dbfs=*/0.0f,
- /*peak_dbfs=*/0.0f},
- *level_estimator.estimator);
+ /*rms_dbfs=*/0.0f, /*peak_dbfs=*/0.0f,
+ kNoSpeechProbability, *level_estimator.estimator);
// No estimated level change is expected.
EXPECT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
estimated_level_dbfs);
@@ -136,10 +138,11 @@
ConvergenceSpeedBeforeConfidence) {
TestLevelEstimator level_estimator(/*adjacent_speech_frames_threshold=*/1);
RunOnConstantLevel(/*num_iterations=*/kNumFramesToConfidence,
- level_estimator.vad_data_speech,
+ level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs, kMaxSpeechProbability,
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
- level_estimator.vad_data_speech.rms_dbfs,
+ level_estimator.level_rms_dbfs,
kConvergenceSpeedTestsLevelTolerance);
}
@@ -150,11 +153,9 @@
// Reach confidence using the initial level estimate.
RunOnConstantLevel(
/*num_iterations=*/kNumFramesToConfidence,
- VadLevelAnalyzer::Result{
- /*speech_probability=*/1.0f,
- /*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
- /*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f},
- *level_estimator.estimator);
+ /*rms_dbfs=*/level_estimator.initial_speech_level_dbfs,
+ /*peak_dbfs=*/level_estimator.initial_speech_level_dbfs + 6.0f,
+ kMaxSpeechProbability, *level_estimator.estimator);
// No estimate change should occur, but confidence is achieved.
ASSERT_FLOAT_EQ(level_estimator.estimator->level_dbfs(),
level_estimator.initial_speech_level_dbfs);
@@ -165,9 +166,10 @@
kConvergenceTimeAfterConfidenceNumFrames > kNumFramesToConfidence, "");
RunOnConstantLevel(
/*num_iterations=*/kConvergenceTimeAfterConfidenceNumFrames,
- level_estimator.vad_data_speech, *level_estimator.estimator);
+ level_estimator.level_rms_dbfs, level_estimator.level_peak_dbfs,
+ kMaxSpeechProbability, *level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->level_dbfs(),
- level_estimator.vad_data_speech.rms_dbfs,
+ level_estimator.level_rms_dbfs,
kConvergenceSpeedTestsLevelTolerance);
}
@@ -181,22 +183,28 @@
DoNotAdaptToShortSpeechSegments) {
TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
const float initial_level = level_estimator.estimator->level_dbfs();
- ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
+ ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
for (int i = 0; i < adjacent_speech_frames_threshold() - 1; ++i) {
SCOPED_TRACE(i);
- level_estimator.estimator->Update(level_estimator.vad_data_speech);
+ level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs,
+ kMaxSpeechProbability);
EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
}
- level_estimator.estimator->Update(level_estimator.vad_data_non_speech);
+ level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs,
+ kLowSpeechProbability);
EXPECT_EQ(initial_level, level_estimator.estimator->level_dbfs());
}
TEST_P(AdaptiveModeLevelEstimatorParametrization, AdaptToEnoughSpeechSegments) {
TestLevelEstimator level_estimator(adjacent_speech_frames_threshold());
const float initial_level = level_estimator.estimator->level_dbfs();
- ASSERT_LT(initial_level, level_estimator.vad_data_speech.peak_dbfs);
+ ASSERT_LT(initial_level, level_estimator.level_peak_dbfs);
for (int i = 0; i < adjacent_speech_frames_threshold(); ++i) {
- level_estimator.estimator->Update(level_estimator.vad_data_speech);
+ level_estimator.estimator->Update(level_estimator.level_rms_dbfs,
+ level_estimator.level_peak_dbfs,
+ kMaxSpeechProbability);
}
EXPECT_LT(initial_level, level_estimator.estimator->level_dbfs());
}
diff --git a/modules/audio_processing/agc2/vad_wrapper.cc b/modules/audio_processing/agc2/vad_wrapper.cc
index 94d5f67..7b61aee 100644
--- a/modules/audio_processing/agc2/vad_wrapper.cc
+++ b/modules/audio_processing/agc2/vad_wrapper.cc
@@ -10,13 +10,10 @@
#include "modules/audio_processing/agc2/vad_wrapper.h"
-#include <algorithm>
#include <array>
-#include <cmath>
#include <utility>
#include "api/array_view.h"
-#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -27,82 +24,72 @@
namespace webrtc {
namespace {
-using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
+constexpr int kNumFramesPerSecond = 100;
-// Default VAD that combines a resampler and the RNN VAD.
-// Computes the speech probability on the first channel.
-class Vad : public VoiceActivityDetector {
+class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
public:
- explicit Vad(const AvailableCpuFeatures& cpu_features)
+ explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
- Vad(const Vad&) = delete;
- Vad& operator=(const Vad&) = delete;
- ~Vad() = default;
+ MonoVadImpl(const MonoVadImpl&) = delete;
+ MonoVadImpl& operator=(const MonoVadImpl&) = delete;
+ ~MonoVadImpl() = default;
+ int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
void Reset() override { rnn_vad_.Reset(); }
-
- float ComputeProbability(AudioFrameView<const float> frame) override {
- // The source number of channels is 1, because we always use the 1st
- // channel.
- resampler_.InitializeIfNeeded(
- /*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
- rnn_vad::kSampleRate24kHz,
- /*num_channels=*/1);
-
- std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
- // Feed the 1st channel to the resampler.
- resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
- work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
-
+ float Analyze(rtc::ArrayView<const float> frame) override {
+ RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
- work_frame, feature_vector);
+ /*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
+ feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
private:
- PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnVad rnn_vad_;
};
} // namespace
-VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
- const AvailableCpuFeatures& cpu_features)
- : VadLevelAnalyzer(vad_reset_period_ms,
- std::make_unique<Vad>(cpu_features)) {}
+VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
+ int vad_reset_period_ms,
+ const AvailableCpuFeatures& cpu_features)
+ : VoiceActivityDetectorWrapper(
+ vad_reset_period_ms,
+ std::make_unique<MonoVadImpl>(cpu_features)) {}
-VadLevelAnalyzer::VadLevelAnalyzer(int vad_reset_period_ms,
- std::unique_ptr<VoiceActivityDetector> vad)
- : vad_(std::move(vad)),
- vad_reset_period_frames_(
+VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
+ int vad_reset_period_ms,
+ std::unique_ptr<MonoVad> vad)
+ : vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
- time_to_vad_reset_(vad_reset_period_frames_) {
+ time_to_vad_reset_(vad_reset_period_frames_),
+ vad_(std::move(vad)) {
RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1);
+ resampled_buffer_.resize(
+ rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
}
-VadLevelAnalyzer::~VadLevelAnalyzer() = default;
+VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
-VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
- AudioFrameView<const float> frame) {
+float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
// Periodically reset the VAD.
time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) {
vad_->Reset();
time_to_vad_reset_ = vad_reset_period_frames_;
}
- // Compute levels.
- float peak = 0.0f;
- float rms = 0.0f;
- for (const auto& x : frame.channel(0)) {
- peak = std::max(std::fabs(x), peak);
- rms += x * x;
- }
- return {vad_->ComputeProbability(frame),
- FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
- FloatS16ToDbfs(peak)};
+
+ // Resample the first channel of `frame`.
+ resampler_.InitializeIfNeeded(
+ /*sample_rate_hz=*/frame.samples_per_channel() * kNumFramesPerSecond,
+ vad_->SampleRateHz(), /*num_channels=*/1);
+ resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
+ resampled_buffer_.data(), resampled_buffer_.size());
+
+ return vad_->Analyze(resampled_buffer_);
}
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_wrapper.h b/modules/audio_processing/agc2/vad_wrapper.h
index de73eab..f17fcda 100644
--- a/modules/audio_processing/agc2/vad_wrapper.h
+++ b/modules/audio_processing/agc2/vad_wrapper.h
@@ -12,51 +12,57 @@
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
#include <memory>
+#include <vector>
+#include "api/array_view.h"
+#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
-// Class to analyze voice activity and audio levels.
-class VadLevelAnalyzer {
+// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
+// the first channel of the input audio frames. Takes care of resampling the
+// input frames to match the sample rate of the wrapped VAD and periodically
+// resets the VAD.
+class VoiceActivityDetectorWrapper {
public:
- struct Result {
- float speech_probability; // Range: [0, 1].
- float rms_dbfs; // Root mean square power (dBFS).
- float peak_dbfs; // Peak power (dBFS).
- };
-
- // Voice Activity Detector (VAD) interface.
- class VoiceActivityDetector {
+ // Single channel VAD interface.
+ class MonoVad {
public:
- virtual ~VoiceActivityDetector() = default;
+ virtual ~MonoVad() = default;
+ // Returns the sample rate (Hz) required for the input frames analyzed by
+ // `Analyze`.
+ virtual int SampleRateHz() const = 0;
// Resets the internal state.
virtual void Reset() = 0;
// Analyzes an audio frame and returns the speech probability.
- virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
+ virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
};
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
- // `VadLevelAnalyzer::Reset()`; it must be equal to or greater than the
- // duration of two frames. Uses `cpu_features` to instantiate the default VAD.
- VadLevelAnalyzer(int vad_reset_period_ms,
- const AvailableCpuFeatures& cpu_features);
+ // `MonoVad::Reset()`; it must be equal to or greater than the duration of two
+ // frames. Uses `cpu_features` to instantiate the default VAD.
+ VoiceActivityDetectorWrapper(int vad_reset_period_ms,
+ const AvailableCpuFeatures& cpu_features);
// Ctor. Uses a custom `vad`.
- VadLevelAnalyzer(int vad_reset_period_ms,
- std::unique_ptr<VoiceActivityDetector> vad);
+ VoiceActivityDetectorWrapper(int vad_reset_period_ms,
+ std::unique_ptr<MonoVad> vad);
- VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
- VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
- ~VadLevelAnalyzer();
+ VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
+ VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
+ delete;
+ ~VoiceActivityDetectorWrapper();
- // Computes the speech probability and the level for `frame`.
- Result AnalyzeFrame(AudioFrameView<const float> frame);
+ // Analyzes the first channel of `frame` and returns the speech probability.
+ float Analyze(AudioFrameView<const float> frame);
private:
- std::unique_ptr<VoiceActivityDetector> vad_;
const int vad_reset_period_frames_;
int time_to_vad_reset_;
+ PushResampler<float> resampler_;
+ std::unique_ptr<MonoVad> vad_;
+ std::vector<float> resampled_buffer_;
};
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/vad_wrapper_unittest.cc b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
index a6e776c..c1f7029 100644
--- a/modules/audio_processing/agc2/vad_wrapper_unittest.cc
+++ b/modules/audio_processing/agc2/vad_wrapper_unittest.cc
@@ -18,6 +18,7 @@
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
+#include "rtc_base/checks.h"
#include "rtc_base/gunit.h"
#include "rtc_base/numerics/safe_compare.h"
#include "test/gmock.h"
@@ -26,90 +27,78 @@
namespace {
using ::testing::AnyNumber;
+using ::testing::Return;
using ::testing::ReturnRoundRobin;
+using ::testing::Truly;
constexpr int kNoVadPeriodicReset =
kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
-constexpr int kSampleRateHz = 8000;
+constexpr int kSampleRate8kHz = 8000;
-class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
+class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
public:
+ MOCK_METHOD(int, SampleRateHz, (), (const override));
MOCK_METHOD(void, Reset, (), (override));
- MOCK_METHOD(float,
- ComputeProbability,
- (AudioFrameView<const float> frame),
- (override));
+ MOCK_METHOD(float, Analyze, (rtc::ArrayView<const float> frame), (override));
};
-// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns
-// the next value from `speech_probabilities` until it reaches the end and will
-// restart from the beginning.
-std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
+// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
+// repeatedly returns the next value from `speech_probabilities` and that
+// restarts from the beginning after the last element is returned.
+std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
int vad_reset_period_ms,
const std::vector<float>& speech_probabilities,
int expected_vad_reset_calls = 0) {
auto vad = std::make_unique<MockVad>();
- EXPECT_CALL(*vad, ComputeProbability)
+ EXPECT_CALL(*vad, SampleRateHz)
.Times(AnyNumber())
- .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
+ .WillRepeatedly(Return(kSampleRate8kHz));
if (expected_vad_reset_calls >= 0) {
EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
}
- return std::make_unique<VadLevelAnalyzer>(vad_reset_period_ms,
- std::move(vad));
+ EXPECT_CALL(*vad, Analyze)
+ .Times(AnyNumber())
+ .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
+ return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms,
+ std::move(vad));
}
// 10 ms mono frame.
struct FrameWithView {
// Ctor. Initializes the frame samples with `value`.
- explicit FrameWithView(float value = 0.0f)
- : channel0(samples.data()),
- view(&channel0, /*num_channels=*/1, samples.size()) {
- samples.fill(value);
- }
- std::array<float, kSampleRateHz / 100> samples;
+ explicit FrameWithView(int sample_rate_hz = kSampleRate8kHz)
+ : samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f),
+ channel0(samples.data()),
+ view(&channel0, /*num_channels=*/1, samples.size()) {}
+ std::vector<float> samples;
const float* const channel0;
const AudioFrameView<const float> view;
};
-TEST(GainController2VadLevelAnalyzer, RmsLessThanPeakLevel) {
- auto analyzer = CreateVadLevelAnalyzerWithMockVad(
- /*vad_reset_period_ms=*/1500,
- /*speech_probabilities=*/{1.0f},
- /*expected_vad_reset_calls=*/0);
- // Handcrafted frame so that the average is lower than the peak value.
- FrameWithView frame(1000.0f); // Constant frame.
- frame.samples[10] = 2000.0f; // Except for one peak value.
- // Compute audio frame levels.
- auto levels_and_vad_prob = analyzer->AnalyzeFrame(frame.view);
- EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
-}
-
-// Checks that the expect VAD probabilities are returned.
-TEST(GainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
+// Checks that the expected speech probabilities are returned.
+TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
0.44f, 0.525f, 0.858f, 0.314f,
0.653f, 0.965f, 0.413f, 0.0f};
- auto analyzer = CreateVadLevelAnalyzerWithMockVad(kNoVadPeriodicReset,
- speech_probabilities);
+ auto vad_wrapper =
+ CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities);
FrameWithView frame;
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i);
- EXPECT_EQ(speech_probabilities[i],
- analyzer->AnalyzeFrame(frame.view).speech_probability);
+ EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
}
}
// Checks that the VAD is not periodically reset.
-TEST(GainController2VadLevelAnalyzer, VadNoPeriodicReset) {
+TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
constexpr int kNumFrames = 19;
- auto analyzer = CreateVadLevelAnalyzerWithMockVad(
- kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
- /*expected_vad_reset_calls=*/0);
+ auto vad_wrapper =
+ CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f},
+ /*expected_vad_reset_calls=*/0);
FrameWithView frame;
for (int i = 0; i < kNumFrames; ++i) {
- analyzer->AnalyzeFrame(frame.view);
+ vad_wrapper->Analyze(frame.view);
}
}
@@ -122,20 +111,52 @@
// Checks that the VAD is periodically reset with the expected period.
TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
- auto analyzer = CreateVadLevelAnalyzerWithMockVad(
+ auto vad_wrapper = CreateMockVadWrapper(
/*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
/*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/num_frames() / vad_reset_period_frames());
FrameWithView frame;
for (int i = 0; i < num_frames(); ++i) {
- analyzer->AnalyzeFrame(frame.view);
+ vad_wrapper->Analyze(frame.view);
}
}
-INSTANTIATE_TEST_SUITE_P(GainController2VadLevelAnalyzer,
+INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
VadPeriodResetParametrization,
::testing::Combine(::testing::Values(1, 19, 123),
::testing::Values(2, 5, 20, 53)));
+class VadResamplingParametrization
+ : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+ int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
+ int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
+};
+
+// Checks that regardless of the input audio sample rate, the wrapped VAD
+// analyzes frames having the expected size, that is according to its internal
+// sample rate.
+TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
+ auto vad = std::make_unique<MockVad>();
+ EXPECT_CALL(*vad, SampleRateHz)
+ .Times(AnyNumber())
+ .WillRepeatedly(Return(vad_sample_rate_hz()));
+ EXPECT_CALL(*vad, Reset).Times(0);
+ EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
+ return rtc::SafeEq(frame.size(),
+ rtc::CheckedDivExact(vad_sample_rate_hz(), 100));
+ }))).Times(1);
+ auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
+ kNoVadPeriodicReset, std::move(vad));
+ FrameWithView frame(input_sample_rate_hz());
+ vad_wrapper->Analyze(frame.view);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ GainController2VoiceActivityDetectorWrapper,
+ VadResamplingParametrization,
+ ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
+ ::testing::Values(6000, 8000, 12000, 16000, 24000)));
+
} // namespace
} // namespace webrtc