Add RNN-VAD to AGC2.
* Move 'VadWithLevel' to AGC2 where it belongs.
* Remove the vectors from VadWithLevel. They were there to make it work
with modules/audio_processing/vad, which we don't need any longer.
* Remove the vector handling from AGC2. It was spread out across
AdaptiveDigitalGainApplier, AdaptiveAGC and their unit tests.
* Hack the RNN VAD into VadWithLevel. The main issue is the resampling.
Bug: webrtc:9076
Change-Id: I13056c985d0ec41269735150caf4aaeb6ff9281e
Reviewed-on: https://webrtc-review.googlesource.com/77364
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23688}
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index e0ed2bb..8501dd9 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -33,6 +33,7 @@
":common",
":gain_applier",
":noise_level_estimator",
+ ":rnn_vad_with_level",
"..:aec_core",
"..:apm_logging",
"..:audio_frame_view",
@@ -41,9 +42,6 @@
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_minmax",
- "../vad",
- "../vad:vad_with_level",
- "rnn_vad",
]
}
@@ -133,6 +131,20 @@
configs += [ "..:apm_debug_dump" ]
}
+rtc_source_set("rnn_vad_with_level") {
+ sources = [
+ "vad_with_level.cc",
+ "vad_with_level.h",
+ ]
+ deps = [
+ "..:audio_frame_view",
+ "../../../api:array_view",
+ "../../../common_audio",
+ "../../../rtc_base:checks",
+ "rnn_vad:lib",
+ ]
+}
+
rtc_source_set("adaptive_digital_unittests") {
testonly = true
configs += [ "..:apm_debug_dump" ]
@@ -155,7 +167,6 @@
"../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved",
"../../../rtc_base:rtc_base_tests_utils",
- "../vad:vad_with_level",
]
}
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index 45e8853..7b24244 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -14,8 +14,8 @@
#include <numeric>
#include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
-#include "modules/audio_processing/vad/voice_activity_detector.h"
namespace webrtc {
@@ -38,17 +38,14 @@
// frames, and no estimates for other frames. We want to feed all to
// the level estimator, but only care about the last level it
// produces.
- rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results =
+ const VadWithLevel::LevelAndProbability vad_result =
vad_.AnalyzeFrame(float_frame);
- for (const auto& vad_result : vad_results) {
- apm_data_dumper_->DumpRaw("agc2_vad_probability",
- vad_result.speech_probability);
- apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs", vad_result.speech_rms_dbfs);
+ apm_data_dumper_->DumpRaw("agc2_vad_probability",
+ vad_result.speech_probability);
+ apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs", vad_result.speech_rms_dbfs);
- apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs",
- vad_result.speech_peak_dbfs);
- speech_level_estimator_.UpdateEstimation(vad_result);
- }
+ apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs", vad_result.speech_peak_dbfs);
+ speech_level_estimator_.UpdateEstimation(vad_result);
const float speech_level_dbfs = speech_level_estimator_.LatestLevelEstimate();
@@ -57,7 +54,7 @@
apm_data_dumper_->DumpRaw("agc2_noise_estimate_dbfs", noise_level_dbfs);
// The gain applier applies the gain.
- gain_applier_.Process(speech_level_dbfs, noise_level_dbfs, vad_results,
+ gain_applier_.Process(speech_level_dbfs, noise_level_dbfs, vad_result,
float_frame);
}
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index a91aa2a..dabe783 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -16,8 +16,8 @@
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
namespace webrtc {
class ApmDataDumper;
diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc b/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc
index 20b5a27..f5b6b91 100644
--- a/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc
+++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc
@@ -74,7 +74,7 @@
void AdaptiveDigitalGainApplier::Process(
float input_level_dbfs,
float input_noise_level_dbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results,
+ const VadWithLevel::LevelAndProbability vad_result,
AudioFrameView<float> float_frame) {
RTC_DCHECK_GE(input_level_dbfs, -150.f);
RTC_DCHECK_LE(input_level_dbfs, 0.f);
@@ -85,21 +85,9 @@
LimitGainByNoise(ComputeGainDb(input_level_dbfs), input_noise_level_dbfs,
apm_data_dumper_);
- // TODO(webrtc:7494): Remove this construct. Remove the vectors from
- // VadWithData after we move to a VAD that outputs an estimate every
- // kFrameDurationMs ms.
- //
- // Forbid increasing the gain when there is no speech. For some
- // VADs, 'vad_results' has either many or 0 results. If there are 0
- // results, keep the old flag. If there are many results, and at
- // least one is confident speech, we allow attenuation.
- if (!vad_results.empty()) {
- gain_increase_allowed_ = std::all_of(
- vad_results.begin(), vad_results.end(),
- [](const VadWithLevel::LevelAndProbability& vad_result) {
- return vad_result.speech_probability > kVadConfidenceThreshold;
- });
- }
+ // Forbid increasing the gain when there is no speech.
+ gain_increase_allowed_ =
+ vad_result.speech_probability > kVadConfidenceThreshold;
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
target_gain_db, last_gain_db_, gain_increase_allowed_);
diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier.h b/modules/audio_processing/agc2/adaptive_digital_gain_applier.h
index b06c65b..31f87f1 100644
--- a/modules/audio_processing/agc2/adaptive_digital_gain_applier.h
+++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier.h
@@ -13,8 +13,8 @@
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/gain_applier.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
namespace webrtc {
@@ -24,11 +24,10 @@
public:
explicit AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper);
// Decide what gain to apply.
- void Process(
- float input_level_dbfs,
- float input_noise_level_dbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results,
- AudioFrameView<float> float_frame);
+ void Process(float input_level_dbfs,
+ float input_noise_level_dbfs,
+ const VadWithLevel::LevelAndProbability vad_result,
+ AudioFrameView<float> float_frame);
private:
float last_gain_db_ = kInitialAdaptiveDigitalGainDb;
diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc b/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc
index ebb040e..860da00 100644
--- a/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc
@@ -33,10 +33,8 @@
for (int i = 0; i < num_iterations; ++i) {
VectorFloatFrame fake_audio(1, 1, 1.f);
- gain_applier->Process(
- input_level_dbfs, kNoNoiseDbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&vad_data, 1),
- fake_audio.float_frame_view());
+ gain_applier->Process(input_level_dbfs, kNoNoiseDbfs, vad_data,
+ fake_audio.float_frame_view());
gain_linear = fake_audio.float_frame_view().channel(0)[0];
}
return gain_linear;
@@ -54,10 +52,8 @@
// Make one call with reasonable audio level values and settings.
VectorFloatFrame fake_audio(2, 480, 10000.f);
- gain_applier.Process(
- -5.0, kNoNoiseDbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
- fake_audio.float_frame_view());
+ gain_applier.Process(-5.0, kNoNoiseDbfs, kVadSpeech,
+ fake_audio.float_frame_view());
}
// Check that the output is -kHeadroom dBFS.
@@ -107,10 +103,8 @@
for (int i = 0; i < kNumFramesToAdapt; ++i) {
SCOPED_TRACE(i);
VectorFloatFrame fake_audio(1, 1, 1.f);
- gain_applier.Process(
- initial_level_dbfs, kNoNoiseDbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
- fake_audio.float_frame_view());
+ gain_applier.Process(initial_level_dbfs, kNoNoiseDbfs, kVadSpeech,
+ fake_audio.float_frame_view());
float current_gain_linear = fake_audio.float_frame_view().channel(0)[0];
EXPECT_LE(std::abs(current_gain_linear - last_gain_linear),
kMaxChangePerFrameLinear);
@@ -121,10 +115,8 @@
for (int i = 0; i < kNumFramesToAdapt; ++i) {
SCOPED_TRACE(i);
VectorFloatFrame fake_audio(1, 1, 1.f);
- gain_applier.Process(
- 0.f, kNoNoiseDbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
- fake_audio.float_frame_view());
+ gain_applier.Process(0.f, kNoNoiseDbfs, kVadSpeech,
+ fake_audio.float_frame_view());
float current_gain_linear = fake_audio.float_frame_view().channel(0)[0];
EXPECT_LE(std::abs(current_gain_linear - last_gain_linear),
kMaxChangePerFrameLinear);
@@ -140,10 +132,8 @@
constexpr int num_samples = 480;
VectorFloatFrame fake_audio(1, num_samples, 1.f);
- gain_applier.Process(
- initial_level_dbfs, kNoNoiseDbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
- fake_audio.float_frame_view());
+ gain_applier.Process(initial_level_dbfs, kNoNoiseDbfs, kVadSpeech,
+ fake_audio.float_frame_view());
float maximal_difference = 0.f;
float current_value = 1.f * DbToRatio(kInitialAdaptiveDigitalGainDb);
for (const auto& x : fake_audio.float_frame_view().channel(0)) {
@@ -172,10 +162,8 @@
for (int i = 0; i < num_initial_frames + num_frames; ++i) {
VectorFloatFrame fake_audio(1, num_samples, 1.f);
- gain_applier.Process(
- initial_level_dbfs, kWithNoiseDbfs,
- rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
- fake_audio.float_frame_view());
+ gain_applier.Process(initial_level_dbfs, kWithNoiseDbfs, kVadSpeech,
+ fake_audio.float_frame_view());
// Wait so that the adaptive gain applier has time to lower the gain.
if (i > num_initial_frames) {
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
index 9762f1f..186c59b 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
@@ -12,7 +12,7 @@
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#include "modules/audio_processing/agc2/saturation_protector.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
namespace webrtc {
class ApmDataDumper;
diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h
index 3ed88a3..7300653 100644
--- a/modules/audio_processing/agc2/agc2_common.h
+++ b/modules/audio_processing/agc2/agc2_common.h
@@ -41,10 +41,10 @@
// Used in the Level Estimator for deciding when to update the speech
// level estimate. Also used in the adaptive digital gain applier to
// decide when to allow target gain reduction.
-constexpr float kVadConfidenceThreshold = 0.9f;
+constexpr float kVadConfidenceThreshold = 0.4f;
// The amount of 'memory' of the Level Estimator. Decides leak factors.
-constexpr size_t kFullBufferSizeMs = 1000;
+constexpr size_t kFullBufferSizeMs = 1600;
constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
@@ -52,7 +52,10 @@
// Saturation Protector settings.
constexpr float kInitialSaturationMarginDb = 17.f;
-constexpr size_t kPeakEnveloperSuperFrameLengthMs = 500;
+constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400;
+static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
+ "Full buffer size should be a multiple of super frame length for "
+ "optimal Saturation Protector performance.");
constexpr size_t kPeakEnveloperBufferSize =
kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1;
diff --git a/modules/audio_processing/agc2/saturation_protector.h b/modules/audio_processing/agc2/saturation_protector.h
index d330c15..3a796fa 100644
--- a/modules/audio_processing/agc2/saturation_protector.h
+++ b/modules/audio_processing/agc2/saturation_protector.h
@@ -14,7 +14,7 @@
#include <array>
#include "modules/audio_processing/agc2/agc2_common.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
namespace webrtc {
diff --git a/modules/audio_processing/agc2/saturation_protector_unittest.cc b/modules/audio_processing/agc2/saturation_protector_unittest.cc
index 88da2a2..6013e13 100644
--- a/modules/audio_processing/agc2/saturation_protector_unittest.cc
+++ b/modules/audio_processing/agc2/saturation_protector_unittest.cc
@@ -30,6 +30,7 @@
max_difference =
std::max(max_difference, std::abs(new_margin - last_margin));
last_margin = new_margin;
+ saturation_protector->DebugDumpEstimate();
}
return max_difference;
}
@@ -127,6 +128,12 @@
kLaterSpeechLevelDbfs, &saturation_protector),
max_difference);
+ // The saturation protector expects that the RMS changes roughly
+ // 'kFullBufferSizeMs' after peaks change. This is to account for
+ // delay introduced by the level estimator. Therefore, the input
+ // above is 'normal' and 'expected', and shouldn't influence the
+ // margin by much.
+
const float total_difference =
std::abs(saturation_protector.LastMargin() - kInitialSaturationMarginDb);
diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc
new file mode 100644
index 0000000..decfacd
--- /dev/null
+++ b/modules/audio_processing/agc2/vad_with_level.cc
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/vad_with_level.h"
+
+#include <algorithm>
+
+#include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/rnn_vad/common.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+
+namespace {
+float ProcessForPeak(AudioFrameView<const float> frame) {
+ float current_max = 0;
+ for (const auto& x : frame.channel(0)) {
+ current_max = std::max(std::fabs(x), current_max);
+ }
+ return current_max;
+}
+
+float ProcessForRms(AudioFrameView<const float> frame) {
+ float rms = 0;
+ for (const auto& x : frame.channel(0)) {
+ rms += x * x;
+ }
+ return sqrt(rms / frame.samples_per_channel());
+}
+} // namespace
+
+VadWithLevel::VadWithLevel() = default;
+VadWithLevel::~VadWithLevel() = default;
+
+VadWithLevel::LevelAndProbability VadWithLevel::AnalyzeFrame(
+ AudioFrameView<const float> frame) {
+ SetSampleRate(static_cast<int>(frame.samples_per_channel() * 100));
+ std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
+ // Feed the 1st channel to the resampler.
+ resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
+ work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
+
+ std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
+
+ const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
+ work_frame, feature_vector);
+ const float vad_probability =
+ rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
+ return LevelAndProbability(vad_probability,
+ FloatS16ToDbfs(ProcessForRms(frame)),
+ FloatS16ToDbfs(ProcessForPeak(frame)));
+}
+
+void VadWithLevel::SetSampleRate(int sample_rate_hz) {
+ // The source number of channels is 1, because we always use the 1st
+ // channel.
+ resampler_.InitializeIfNeeded(sample_rate_hz, rnn_vad::kSampleRate24kHz,
+ 1 /* num_channels */);
+}
+
+} // namespace webrtc
diff --git a/modules/audio_processing/vad/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h
similarity index 60%
rename from modules/audio_processing/vad/vad_with_level.h
rename to modules/audio_processing/agc2/vad_with_level.h
index 9ad4d17..67a00ce 100644
--- a/modules/audio_processing/vad/vad_with_level.h
+++ b/modules/audio_processing/agc2/vad_with_level.h
@@ -8,10 +8,13 @@
* be found in the AUTHORS file in the root of the source tree.
*/
-#ifndef MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_
-#define MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#include "api/array_view.h"
+#include "common_audio/resampler/include/push_resampler.h"
+#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
+#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
@@ -28,13 +31,19 @@
float speech_peak_dbfs = 0;
};
- // TODO(webrtc:7494): This is a stub. Add implementation.
- rtc::ArrayView<const LevelAndProbability> AnalyzeFrame(
- AudioFrameView<const float> frame) {
- return {nullptr, 0};
- }
+ VadWithLevel();
+ ~VadWithLevel();
+
+ LevelAndProbability AnalyzeFrame(AudioFrameView<const float> frame);
+
+ private:
+ void SetSampleRate(int sample_rate_hz);
+
+ rnn_vad::RnnBasedVad rnn_vad_;
+ rnn_vad::FeaturesExtractor features_extractor_;
+ PushResampler<float> resampler_;
};
} // namespace webrtc
-#endif // MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_
+#endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
diff --git a/modules/audio_processing/vad/BUILD.gn b/modules/audio_processing/vad/BUILD.gn
index 9a57789..ae2a84d 100644
--- a/modules/audio_processing/vad/BUILD.gn
+++ b/modules/audio_processing/vad/BUILD.gn
@@ -44,16 +44,6 @@
]
}
-rtc_source_set("vad_with_level") {
- sources = [
- "vad_with_level.h",
- ]
- deps = [
- "..:audio_frame_view",
- "../../../api:array_view",
- ]
-}
-
if (rtc_include_tests) {
rtc_static_library("vad_unittests") {
testonly = true