AEC3: Suppression filter handles multiple channels
Suppression filter is extended to support the synthesis
of multiple channels. This CL is also a major clean-up of ApplyGain.
The CL has been tested for bit-exactness for single channel output.
Bug: webrtc:10913
Change-Id: I1319f127981552e17dec66701a248d34dcf0e563
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/154341
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29284}
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index a184517..725e33e 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -191,7 +191,9 @@
subtractors_(num_capture_channels_),
suppression_gains_(num_capture_channels_),
cngs_(num_capture_channels_),
- suppression_filter_(optimization_, sample_rate_hz_),
+ suppression_filter_(optimization_,
+ sample_rate_hz_,
+ num_capture_channels_),
render_signal_analyzer_(config_),
residual_echo_estimators_(num_capture_channels_),
aec_state_(config_),
@@ -378,7 +380,7 @@
E2[0], Y2[0], subtractor_output[0], y0);
// Choose the linear output.
- const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E[0] : Y[0];
+ const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y;
#if WEBRTC_APM_DEBUG_DUMP
if (aec_state_.UseLinearFilterOutput()) {
@@ -439,8 +441,7 @@
[](float a, float b) { return std::min(a, b); });
}
- // TODO(bugs.webrtc.org/10913): Make ApplyGain handle multiple channels.
- suppression_filter_.ApplyGain(comfort_noise[0], high_band_comfort_noise[0], G,
+ suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G,
high_bands_gain, Y_fft, y);
// Update the metrics.
diff --git a/modules/audio_processing/aec3/suppression_filter.cc b/modules/audio_processing/aec3/suppression_filter.cc
index 6679a87..8a813d9 100644
--- a/modules/audio_processing/aec3/suppression_filter.cc
+++ b/modules/audio_processing/aec3/suppression_filter.cc
@@ -61,107 +61,117 @@
} // namespace
SuppressionFilter::SuppressionFilter(Aec3Optimization optimization,
- int sample_rate_hz)
+ int sample_rate_hz,
+ size_t num_capture_channels)
: optimization_(optimization),
sample_rate_hz_(sample_rate_hz),
+ num_capture_channels_(num_capture_channels),
fft_(),
- e_output_old_(NumBandsForRate(sample_rate_hz_)) {
+ e_output_old_(NumBandsForRate(sample_rate_hz_),
+ std::vector<std::array<float, kFftLengthBy2>>(
+ num_capture_channels_)) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz_));
- std::for_each(e_output_old_.begin(), e_output_old_.end(),
- [](std::array<float, kFftLengthBy2>& a) { a.fill(0.f); });
+ for (size_t b = 0; b < e_output_old_.size(); ++b) {
+ for (size_t ch = 0; ch < e_output_old_[b].size(); ++ch) {
+ e_output_old_[b][ch].fill(0.f);
+ }
+ }
}
SuppressionFilter::~SuppressionFilter() = default;
void SuppressionFilter::ApplyGain(
- const FftData& comfort_noise,
- const FftData& comfort_noise_high_band,
+ rtc::ArrayView<const FftData> comfort_noise,
+ rtc::ArrayView<const FftData> comfort_noise_high_band,
const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
float high_bands_gain,
- const FftData& E_lowest_band,
+ rtc::ArrayView<const FftData> E_lowest_band,
std::vector<std::vector<std::vector<float>>>* e) {
RTC_DCHECK(e);
RTC_DCHECK_EQ(e->size(), NumBandsForRate(sample_rate_hz_));
- FftData E;
-
- // Analysis filterbank.
- E.Assign(E_lowest_band);
-
- // Apply gain.
- std::transform(suppression_gain.begin(), suppression_gain.end(), E.re.begin(),
- E.re.begin(), std::multiplies<float>());
- std::transform(suppression_gain.begin(), suppression_gain.end(), E.im.begin(),
- E.im.begin(), std::multiplies<float>());
// Comfort noise gain is sqrt(1-g^2), where g is the suppression gain.
std::array<float, kFftLengthBy2Plus1> noise_gain;
- std::transform(suppression_gain.begin(), suppression_gain.end(),
- noise_gain.begin(), [](float g) { return 1.f - g * g; });
+ for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) {
+ noise_gain[i] = 1.f - suppression_gain[i] * suppression_gain[i];
+ }
aec3::VectorMath(optimization_).Sqrt(noise_gain);
- // Scale and add the comfort noise.
- for (size_t k = 0; k < kFftLengthBy2Plus1; k++) {
- E.re[k] += noise_gain[k] * comfort_noise.re[k];
- E.im[k] += noise_gain[k] * comfort_noise.im[k];
- }
+ const float high_bands_noise_scaling =
+ 0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain);
- // Synthesis filterbank.
- std::array<float, kFftLength> e_extended;
- constexpr float kIfftNormalization = 2.f / kFftLength;
+ for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
+ FftData E;
- fft_.Ifft(E, &e_extended);
- std::transform(e_output_old_[0].begin(), e_output_old_[0].end(),
- std::begin(kSqrtHanning) + kFftLengthBy2, (*e)[0][0].begin(),
- [&](float a, float b) { return kIfftNormalization * a * b; });
- std::transform(e_extended.begin(), e_extended.begin() + kFftLengthBy2,
- std::begin(kSqrtHanning), e_extended.begin(),
- [&](float a, float b) { return kIfftNormalization * a * b; });
- std::transform((*e)[0][0].begin(), (*e)[0][0].end(), e_extended.begin(),
- (*e)[0][0].begin(), std::plus<float>());
- std::for_each((*e)[0][0].begin(), (*e)[0][0].end(), [](float& x_k) {
- x_k = rtc::SafeClamp(x_k, -32768.f, 32767.f);
- });
- std::copy(e_extended.begin() + kFftLengthBy2, e_extended.begin() + kFftLength,
- std::begin(e_output_old_[0]));
+ // Analysis filterbank.
+ E.Assign(E_lowest_band[ch]);
- if (e->size() > 1) {
- // Form time-domain high-band noise.
- std::array<float, kFftLength> time_domain_high_band_noise;
- std::transform(comfort_noise_high_band.re.begin(),
- comfort_noise_high_band.re.end(), E.re.begin(),
- [&](float a) { return kIfftNormalization * a; });
- std::transform(comfort_noise_high_band.im.begin(),
- comfort_noise_high_band.im.end(), E.im.begin(),
- [&](float a) { return kIfftNormalization * a; });
- fft_.Ifft(E, &time_domain_high_band_noise);
+ for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) {
+ // Apply suppression gains.
+ E.re[i] *= suppression_gain[i];
+ E.im[i] *= suppression_gain[i];
- // Scale and apply the noise to the signals.
- const float high_bands_noise_scaling =
- 0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain);
-
- std::transform(
- (*e)[1][0].begin(), (*e)[1][0].end(),
- time_domain_high_band_noise.begin(), (*e)[1][0].begin(),
- [&](float a, float b) {
- return std::max(
- std::min(b * high_bands_noise_scaling + high_bands_gain * a,
- 32767.0f),
- -32768.0f);
- });
-
- if (e->size() > 2) {
- RTC_DCHECK_EQ(3, e->size());
- std::for_each((*e)[2][0].begin(), (*e)[2][0].end(), [&](float& a) {
- a = rtc::SafeClamp(a * high_bands_gain, -32768.f, 32767.f);
- });
+ // Scale and add the comfort noise.
+ E.re[i] += noise_gain[i] * comfort_noise[ch].re[i];
+ E.im[i] += noise_gain[i] * comfort_noise[ch].im[i];
}
- std::array<float, kFftLengthBy2> tmp;
- for (size_t k = 1; k < e->size(); ++k) {
- std::copy((*e)[k][0].begin(), (*e)[k][0].end(), tmp.begin());
- std::copy(e_output_old_[k].begin(), e_output_old_[k].end(),
- (*e)[k][0].begin());
- std::copy(tmp.begin(), tmp.end(), e_output_old_[k].begin());
+ // Synthesis filterbank.
+ std::array<float, kFftLength> e_extended;
+ constexpr float kIfftNormalization = 2.f / kFftLength;
+ fft_.Ifft(E, &e_extended);
+
+ auto& e0 = (*e)[0][ch];
+ auto& e0_old = e_output_old_[0][ch];
+
+ // Window and add the first half of e_extended with the second half of
+ // e_extended from the previous block.
+ for (size_t i = 0; i < kFftLengthBy2; ++i) {
+ e0[i] = e0_old[i] * kSqrtHanning[kFftLengthBy2 + i];
+ e0[i] += e_extended[i] * kSqrtHanning[i];
+ e0[i] *= kIfftNormalization;
+ }
+
+ // The second half of e_extended is stored for the succeeding frame.
+ std::copy(e_extended.begin() + kFftLengthBy2,
+ e_extended.begin() + kFftLength, std::begin(e0_old));
+
+ // Apply suppression gain to upper bands.
+ for (size_t b = 1; b < e->size(); ++b) {
+ auto& e_band = (*e)[b][ch];
+ for (size_t i = 0; i < kFftLengthBy2; ++i) {
+ e_band[i] *= high_bands_gain;
+ }
+ }
+
+ // Add comfort noise to band 1.
+ if (e->size() > 1) {
+ E.Assign(comfort_noise_high_band[ch]);
+ std::array<float, kFftLength> time_domain_high_band_noise;
+ fft_.Ifft(E, &time_domain_high_band_noise);
+
+ auto& e1 = (*e)[1][ch];
+ const float gain = high_bands_noise_scaling * kIfftNormalization;
+ for (size_t i = 0; i < kFftLengthBy2; ++i) {
+ e1[i] += time_domain_high_band_noise[i] * gain;
+ }
+ }
+
+ // Delay upper bands to match the delay of the filter bank.
+ for (size_t b = 1; b < e->size(); ++b) {
+ auto& e_band = (*e)[b][ch];
+ auto& e_band_old = e_output_old_[b][ch];
+ for (size_t i = 0; i < kFftLengthBy2; ++i) {
+ std::swap(e_band[i], e_band_old[i]);
+ }
+ }
+
+ // Clamp output of all bands.
+ for (size_t b = 0; b < e->size(); ++b) {
+ auto& e_band = (*e)[b][ch];
+ for (size_t i = 0; i < kFftLengthBy2; ++i) {
+ e_band[i] = rtc::SafeClamp(e_band[i], -32768.f, 32767.f);
+ }
}
}
}
diff --git a/modules/audio_processing/aec3/suppression_filter.h b/modules/audio_processing/aec3/suppression_filter.h
index 03b13c8..a35fb40 100644
--- a/modules/audio_processing/aec3/suppression_filter.h
+++ b/modules/audio_processing/aec3/suppression_filter.h
@@ -24,21 +24,24 @@
class SuppressionFilter {
public:
- SuppressionFilter(Aec3Optimization optimization, int sample_rate_hz);
+ SuppressionFilter(Aec3Optimization optimization,
+ int sample_rate_hz,
+ size_t num_capture_channels_);
~SuppressionFilter();
- void ApplyGain(const FftData& comfort_noise,
- const FftData& comfort_noise_high_bands,
+ void ApplyGain(rtc::ArrayView<const FftData> comfort_noise,
+ rtc::ArrayView<const FftData> comfort_noise_high_bands,
const std::array<float, kFftLengthBy2Plus1>& suppression_gain,
float high_bands_gain,
- const FftData& E_lowest_band,
+ rtc::ArrayView<const FftData> E_lowest_band,
std::vector<std::vector<std::vector<float>>>* e);
private:
const Aec3Optimization optimization_;
const int sample_rate_hz_;
+ const size_t num_capture_channels_;
const OouraFft ooura_fft_;
const Aec3Fft fft_;
- std::vector<std::array<float, kFftLengthBy2>> e_output_old_;
+ std::vector<std::vector<std::array<float, kFftLengthBy2>>> e_output_old_;
RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionFilter);
};
diff --git a/modules/audio_processing/aec3/suppression_filter_unittest.cc b/modules/audio_processing/aec3/suppression_filter_unittest.cc
index 1e05a02..b55c719 100644
--- a/modules/audio_processing/aec3/suppression_filter_unittest.cc
+++ b/modules/audio_processing/aec3/suppression_filter_unittest.cc
@@ -51,46 +51,46 @@
// Verifies the check for null suppressor output.
TEST(SuppressionFilter, NullOutput) {
- FftData cn;
- FftData cn_high_bands;
- FftData E;
+ std::vector<FftData> cn(1);
+ std::vector<FftData> cn_high_bands(1);
+ std::vector<FftData> E(1);
std::array<float, kFftLengthBy2Plus1> gain;
- EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000)
+ EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000, 1)
.ApplyGain(cn, cn_high_bands, gain, 1.0f, E, nullptr),
"");
}
// Verifies the check for allowed sample rate.
TEST(SuppressionFilter, ProperSampleRate) {
- EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001), "");
+ EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001, 1), "");
}
#endif
// Verifies that no comfort noise is added when the gain is 1.
TEST(SuppressionFilter, ComfortNoiseInUnityGain) {
- SuppressionFilter filter(Aec3Optimization::kNone, 48000);
- FftData cn;
- FftData cn_high_bands;
+ SuppressionFilter filter(Aec3Optimization::kNone, 48000, 1);
+ std::vector<FftData> cn(1);
+ std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2Plus1> gain;
std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft;
e_old_.fill(0.f);
gain.fill(1.f);
- cn.re.fill(1.f);
- cn.im.fill(1.f);
- cn_high_bands.re.fill(1.f);
- cn_high_bands.im.fill(1.f);
+ cn[0].re.fill(1.f);
+ cn[0].im.fill(1.f);
+ cn_high_bands[0].re.fill(1.f);
+ cn_high_bands[0].im.fill(1.f);
std::vector<std::vector<std::vector<float>>> e(
3,
std::vector<std::vector<float>>(1, std::vector<float>(kBlockSize, 0.f)));
std::vector<std::vector<std::vector<float>>> e_ref = e;
- FftData E;
- fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+ std::vector<FftData> E(1);
+ fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@@ -110,9 +110,9 @@
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
constexpr size_t kNumChannels = 1;
- SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz);
- FftData cn;
- FftData cn_high_bands;
+ SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
+ std::vector<FftData> cn(1);
+ std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft;
std::array<float, kFftLengthBy2Plus1> gain;
@@ -124,10 +124,10 @@
gain.fill(1.f);
std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; });
- cn.re.fill(0.f);
- cn.im.fill(0.f);
- cn_high_bands.re.fill(0.f);
- cn_high_bands.im.fill(0.f);
+ cn[0].re.fill(0.f);
+ cn[0].im.fill(0.f);
+ cn_high_bands[0].re.fill(0.f);
+ cn_high_bands[0].im.fill(0.f);
size_t sample_counter = 0;
@@ -138,8 +138,8 @@
e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(),
e[0][0].begin(), e0_input);
- FftData E;
- fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+ std::vector<FftData> E(1);
+ fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@@ -157,11 +157,11 @@
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
- SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz);
- FftData cn;
+ SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
+ std::vector<FftData> cn(1);
std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft;
- FftData cn_high_bands;
+ std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2Plus1> gain;
std::vector<std::vector<std::vector<float>>> e(
kNumBands, std::vector<std::vector<float>>(
@@ -170,10 +170,10 @@
gain.fill(1.f);
std::for_each(gain.begin() + 30, gain.end(), [](float& a) { a = 0.f; });
- cn.re.fill(0.f);
- cn.im.fill(0.f);
- cn_high_bands.re.fill(0.f);
- cn_high_bands.im.fill(0.f);
+ cn[0].re.fill(0.f);
+ cn[0].im.fill(0.f);
+ cn_high_bands[0].re.fill(0.f);
+ cn_high_bands[0].im.fill(0.f);
size_t sample_counter = 0;
@@ -184,8 +184,8 @@
e0_input = std::inner_product(e[0][0].begin(), e[0][0].end(),
e[0][0].begin(), e0_input);
- FftData E;
- fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+ std::vector<FftData> E(1);
+ fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);
@@ -202,9 +202,9 @@
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
- SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz);
- FftData cn;
- FftData cn_high_bands;
+ SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1);
+ std::vector<FftData> cn(1);
+ std::vector<FftData> cn_high_bands(1);
std::array<float, kFftLengthBy2> e_old_;
Aec3Fft fft;
std::array<float, kFftLengthBy2Plus1> gain;
@@ -214,10 +214,10 @@
gain.fill(1.f);
- cn.re.fill(0.f);
- cn.im.fill(0.f);
- cn_high_bands.re.fill(0.f);
- cn_high_bands.im.fill(0.f);
+ cn[0].re.fill(0.f);
+ cn[0].im.fill(0.f);
+ cn_high_bands[0].re.fill(0.f);
+ cn_high_bands[0].im.fill(0.f);
for (size_t k = 0; k < 100; ++k) {
for (size_t band = 0; band < kNumBands; ++band) {
@@ -228,8 +228,8 @@
}
}
- FftData E;
- fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E);
+ std::vector<FftData> E(1);
+ fft.PaddedFft(e[0][0], e_old_, Aec3Fft::Window::kSqrtHanning, &E[0]);
std::copy(e[0][0].begin(), e[0][0].end(), e_old_.begin());
filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e);