AGC2 add an interface for the noise level estimator
Done in preparation for the child CL which adds an alternative
implementation.
Bug: webrtc:7494
Change-Id: I4963376afc917eae434a0d0ccee18f21880eefe0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/214125
Reviewed-by: Jakob Ivarsson <jakobi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33646}
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index 9657aec..ca9959a 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -58,7 +58,7 @@
kMaxGainChangePerSecondDb,
kMaxOutputNoiseLevelDbfs),
apm_data_dumper_(apm_data_dumper),
- noise_level_estimator_(apm_data_dumper) {
+ noise_level_estimator_(CreateNoiseLevelEstimator(apm_data_dumper)) {
RTC_DCHECK(apm_data_dumper);
}
@@ -80,7 +80,7 @@
config.adaptive_digital.max_gain_change_db_per_second,
config.adaptive_digital.max_output_noise_level_dbfs),
apm_data_dumper_(apm_data_dumper),
- noise_level_estimator_(apm_data_dumper) {
+ noise_level_estimator_(CreateNoiseLevelEstimator(apm_data_dumper)) {
RTC_DCHECK(apm_data_dumper);
if (!config.adaptive_digital.use_saturation_protector) {
RTC_LOG(LS_WARNING) << "The saturation protector cannot be disabled.";
@@ -94,7 +94,7 @@
info.vad_result = vad_.AnalyzeFrame(frame);
speech_level_estimator_.Update(info.vad_result);
info.input_level_dbfs = speech_level_estimator_.level_dbfs();
- info.input_noise_level_dbfs = noise_level_estimator_.Analyze(frame);
+ info.input_noise_level_dbfs = noise_level_estimator_->Analyze(frame);
info.limiter_envelope_dbfs =
limiter_envelope > 0 ? FloatS16ToDbfs(limiter_envelope) : -90.0f;
info.estimate_is_confident = speech_level_estimator_.IsConfident();
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index f3c7854..b861c48 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -11,6 +11,8 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_AGC_H_
+#include <memory>
+
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
@@ -42,7 +44,7 @@
VadLevelAnalyzer vad_;
AdaptiveDigitalGainApplier gain_applier_;
ApmDataDumper* const apm_data_dumper_;
- NoiseLevelEstimator noise_level_estimator_;
+ std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;
};
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/noise_level_estimator.cc b/modules/audio_processing/agc2/noise_level_estimator.cc
index d50ecba..6aa942c 100644
--- a/modules/audio_processing/agc2/noise_level_estimator.cc
+++ b/modules/audio_processing/agc2/noise_level_estimator.cc
@@ -18,11 +18,11 @@
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/signal_classifier.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
-
namespace {
constexpr int kFramesPerSecond = 100;
@@ -41,86 +41,106 @@
const float rms = std::sqrt(signal_energy / num_samples);
return FloatS16ToDbfs(rms);
}
-} // namespace
-NoiseLevelEstimator::NoiseLevelEstimator(ApmDataDumper* data_dumper)
- : data_dumper_(data_dumper), signal_classifier_(data_dumper) {
- Initialize(48000);
-}
-
-NoiseLevelEstimator::~NoiseLevelEstimator() {}
-
-void NoiseLevelEstimator::Initialize(int sample_rate_hz) {
- sample_rate_hz_ = sample_rate_hz;
- noise_energy_ = 1.0f;
- first_update_ = true;
- min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
- noise_energy_hold_counter_ = 0;
- signal_classifier_.Initialize(sample_rate_hz);
-}
-
-float NoiseLevelEstimator::Analyze(const AudioFrameView<const float>& frame) {
- data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter",
- noise_energy_hold_counter_);
- const int sample_rate_hz =
- static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
- if (sample_rate_hz != sample_rate_hz_) {
- Initialize(sample_rate_hz);
+class NoiseLevelEstimatorImpl : public NoiseLevelEstimator {
+ public:
+ NoiseLevelEstimatorImpl(ApmDataDumper* data_dumper)
+ : data_dumper_(data_dumper), signal_classifier_(data_dumper) {
+ Initialize(48000);
}
- const float frame_energy = FrameEnergy(frame);
- if (frame_energy <= 0.f) {
- RTC_DCHECK_GE(frame_energy, 0.f);
- data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
- return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
- }
+ NoiseLevelEstimatorImpl(const NoiseLevelEstimatorImpl&) = delete;
+ NoiseLevelEstimatorImpl& operator=(const NoiseLevelEstimatorImpl&) = delete;
+ ~NoiseLevelEstimatorImpl() = default;
- if (first_update_) {
- // Initialize the noise energy to the frame energy.
- first_update_ = false;
- data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
- noise_energy_ = std::max(frame_energy, min_noise_energy_);
- return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
- }
+ float Analyze(const AudioFrameView<const float>& frame) {
+ data_dumper_->DumpRaw("agc2_noise_level_estimator_hold_counter",
+ noise_energy_hold_counter_);
+ const int sample_rate_hz =
+ static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
+ if (sample_rate_hz != sample_rate_hz_) {
+ Initialize(sample_rate_hz);
+ }
+ const float frame_energy = FrameEnergy(frame);
+ if (frame_energy <= 0.f) {
+ RTC_DCHECK_GE(frame_energy, 0.f);
+ data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
+ return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
+ }
- const SignalClassifier::SignalType signal_type =
- signal_classifier_.Analyze(frame.channel(0));
- data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type",
- static_cast<int>(signal_type));
+ if (first_update_) {
+ // Initialize the noise energy to the frame energy.
+ first_update_ = false;
+ data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type", -1);
+ noise_energy_ = std::max(frame_energy, min_noise_energy_);
+ return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
+ }
- // Update the noise estimate in a minimum statistics-type manner.
- if (signal_type == SignalClassifier::SignalType::kStationary) {
- if (frame_energy > noise_energy_) {
- // Leak the estimate upwards towards the frame energy if no recent
- // downward update.
- noise_energy_hold_counter_ = std::max(noise_energy_hold_counter_ - 1, 0);
+ const SignalClassifier::SignalType signal_type =
+ signal_classifier_.Analyze(frame.channel(0));
+ data_dumper_->DumpRaw("agc2_noise_level_estimator_signal_type",
+ static_cast<int>(signal_type));
- if (noise_energy_hold_counter_ == 0) {
- constexpr float kMaxNoiseEnergyFactor = 1.01f;
+ // Update the noise estimate in a minimum statistics-type manner.
+ if (signal_type == SignalClassifier::SignalType::kStationary) {
+ if (frame_energy > noise_energy_) {
+ // Leak the estimate upwards towards the frame energy if no recent
+ // downward update.
+ noise_energy_hold_counter_ =
+ std::max(noise_energy_hold_counter_ - 1, 0);
+
+ if (noise_energy_hold_counter_ == 0) {
+ constexpr float kMaxNoiseEnergyFactor = 1.01f;
+ noise_energy_ =
+ std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy);
+ }
+ } else {
+ // Update smoothly downwards with a limited maximum update magnitude.
+ constexpr float kMinNoiseEnergyFactor = 0.9f;
+ constexpr float kNoiseEnergyDeltaFactor = 0.05f;
noise_energy_ =
- std::min(noise_energy_ * kMaxNoiseEnergyFactor, frame_energy);
+ std::max(noise_energy_ * kMinNoiseEnergyFactor,
+ noise_energy_ - kNoiseEnergyDeltaFactor *
+ (noise_energy_ - frame_energy));
+ // Prevent an energy increase for the next 10 seconds.
+ constexpr int kNumFramesToEnergyIncreaseAllowed = 1000;
+ noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed;
}
} else {
- // Update smoothly downwards with a limited maximum update magnitude.
- constexpr float kMinNoiseEnergyFactor = 0.9f;
- constexpr float kNoiseEnergyDeltaFactor = 0.05f;
- noise_energy_ =
- std::max(noise_energy_ * kMinNoiseEnergyFactor,
- noise_energy_ - kNoiseEnergyDeltaFactor *
- (noise_energy_ - frame_energy));
- // Prevent an energy increase for the next 10 seconds.
- constexpr int kNumFramesToEnergyIncreaseAllowed = 1000;
- noise_energy_hold_counter_ = kNumFramesToEnergyIncreaseAllowed;
+ // TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level.
+ // For a non-stationary signal, leak the estimate downwards in order to
+ // avoid estimate locking due to incorrect signal classification.
+ noise_energy_ = noise_energy_ * 0.99f;
}
- } else {
- // TODO(bugs.webrtc.org/7494): Remove to not forget the estimated level.
- // For a non-stationary signal, leak the estimate downwards in order to
- // avoid estimate locking due to incorrect signal classification.
- noise_energy_ = noise_energy_ * 0.99f;
+
+ // Ensure a minimum of the estimate.
+ noise_energy_ = std::max(noise_energy_, min_noise_energy_);
+ return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
}
- // Ensure a minimum of the estimate.
- noise_energy_ = std::max(noise_energy_, min_noise_energy_);
- return EnergyToDbfs(noise_energy_, frame.samples_per_channel());
+ private:
+ void Initialize(int sample_rate_hz) {
+ sample_rate_hz_ = sample_rate_hz;
+ noise_energy_ = 1.0f;
+ first_update_ = true;
+ min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
+ noise_energy_hold_counter_ = 0;
+ signal_classifier_.Initialize(sample_rate_hz);
+ }
+
+ ApmDataDumper* const data_dumper_;
+ int sample_rate_hz_;
+ float min_noise_energy_;
+ bool first_update_;
+ float noise_energy_;
+ int noise_energy_hold_counter_;
+ SignalClassifier signal_classifier_;
+};
+
+} // namespace
+
+std::unique_ptr<NoiseLevelEstimator> CreateNoiseLevelEstimator(
+ ApmDataDumper* data_dumper) {
+ return std::make_unique<NoiseLevelEstimatorImpl>(data_dumper);
}
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/noise_level_estimator.h b/modules/audio_processing/agc2/noise_level_estimator.h
index 65d4623..7e57b4c 100644
--- a/modules/audio_processing/agc2/noise_level_estimator.h
+++ b/modules/audio_processing/agc2/noise_level_estimator.h
@@ -11,33 +11,26 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
-#include "modules/audio_processing/agc2/signal_classifier.h"
+#include <memory>
+
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
+// Noise level estimator interface.
class NoiseLevelEstimator {
public:
- NoiseLevelEstimator(ApmDataDumper* data_dumper);
- NoiseLevelEstimator(const NoiseLevelEstimator&) = delete;
- NoiseLevelEstimator& operator=(const NoiseLevelEstimator&) = delete;
- ~NoiseLevelEstimator();
- // Returns the estimated noise level in dBFS.
- float Analyze(const AudioFrameView<const float>& frame);
-
- private:
- void Initialize(int sample_rate_hz);
-
- ApmDataDumper* const data_dumper_;
- int sample_rate_hz_;
- float min_noise_energy_;
- bool first_update_;
- float noise_energy_;
- int noise_energy_hold_counter_;
- SignalClassifier signal_classifier_;
+ virtual ~NoiseLevelEstimator() = default;
+ // Analyzes a 10 ms `frame`, updates the noise level estimation and returns
+ // the value for the latter in dBFS.
+ virtual float Analyze(const AudioFrameView<const float>& frame) = 0;
};
+// Creates a noise level estimator based on stationarity detection.
+std::unique_ptr<NoiseLevelEstimator> CreateNoiseLevelEstimator(
+ ApmDataDumper* data_dumper);
+
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
diff --git a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc
index 327fcee..ccee34a 100644
--- a/modules/audio_processing/agc2/noise_level_estimator_unittest.cc
+++ b/modules/audio_processing/agc2/noise_level_estimator_unittest.cc
@@ -31,7 +31,7 @@
float RunEstimator(rtc::FunctionView<float()> sample_generator,
int sample_rate_hz) {
ApmDataDumper data_dumper(0);
- NoiseLevelEstimator estimator(&data_dumper);
+ auto estimator = CreateNoiseLevelEstimator(&data_dumper);
const int samples_per_channel =
rtc::CheckedDivExact(sample_rate_hz, kFramesPerSecond);
VectorFloatFrame signal(1, samples_per_channel, 0.0f);
@@ -41,9 +41,9 @@
for (int j = 0; j < samples_per_channel; ++j) {
frame_view.channel(0)[j] = sample_generator();
}
- estimator.Analyze(frame_view);
+ estimator->Analyze(frame_view);
}
- return estimator.Analyze(signal.float_frame_view());
+ return estimator->Analyze(signal.float_frame_view());
}
class NoiseEstimatorParametrization : public ::testing::TestWithParam<int> {