blob: 1721e9c98358fb02eed0102aeac70f0060a1d5ea [file] [log] [blame]
/*
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/matched_filter.h"
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"
#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
namespace webrtc {
namespace aec3 {
#if defined(WEBRTC_HAS_NEON)
// NEON implementation of the matched filter core.
//
// It computes the same quantities as the generic MatchedFilterCore, up to
// floating-point summation order. For each capture sample y[i]:
//  - h is correlated with the circular render buffer x. The buffer is read
//    from x_start_index onwards and wraps to x[0].
//  - The render energy over the same span is accumulated.
//  - The squared prediction error is added to *error_sum.
//  - If the energy exceeds x2_sum_threshold and y[i] is not saturated, h gets
//    an NLMS update and *filters_updated is set to true.
//  - x_start_index then steps back by one position, circularly.
//
// The two-chunk loops handle at most one wraparound of x, so h must not be
// longer than x.
void MatchedFilterCore_NEON(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];

    // Initialize values for the accumulation.
    float32x4_t s_128 = vdupq_n_f32(0);
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;

    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));

    // Perform the loop in two chunks.
    const int chunk2 = h_size - chunk1;
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_4 = limit >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        const float32x4_t x_k = vld1q_f32(x_p);
        const float32x4_t h_k = vld1q_f32(h_p);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
        s_128 = vmlaq_f32(s_128, h_k, x_k);
      }

      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }

      // The second chunk continues from the start of the circular buffer.
      x_p = &x[0];
    }

    // Combine the accumulated vector and scalar values. The lanes are spilled
    // with vst1q_f32 rather than read through a reinterpreted float pointer,
    // which relies on compiler-specific aliasing behavior. The order of the
    // additions is unchanged.
    float lanes[4];
    vst1q_f32(lanes, x2_sum_128);
    x2_sum += lanes[0] + lanes[1] + lanes[2] + lanes[3];
    vst1q_f32(lanes, s_128);
    s += lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Compute the matched filter error.
    const float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);

      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      // A distinct writable pointer avoids shadowing the read-only h_p above.
      float* h_w = &h[0];
      x_p = &x[x_start_index];

      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_w += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          float32x4_t h_k = vld1q_f32(h_w);
          const float32x4_t x_k = vld1q_f32(x_p);
          // Compute h = h + alpha * x.
          h_k = vmlaq_f32(h_k, alpha_128, x_k);
          // Store the result.
          vst1q_f32(h_w, h_k);
        }

        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_w, ++x_p) {
          *h_w += alpha * *x_p;
        }

        x_p = &x[0];
      }

      *filters_updated = true;
    }

    // Step the read position back by one, wrapping to the end of x.
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
// SSE2 implementation of the matched filter core.
//
// It computes the same quantities as the generic MatchedFilterCore, up to
// floating-point summation order. For each capture sample y[i]:
//  - h is correlated with the circular render buffer x. The buffer is read
//    from x_start_index onwards and wraps to x[0].
//  - The render energy over the same span is accumulated.
//  - The squared prediction error is added to *error_sum.
//  - If the energy exceeds x2_sum_threshold and y[i] is not saturated, h gets
//    an NLMS update and *filters_updated is set to true.
//  - x_start_index then steps back by one position, circularly.
//
// The two-chunk loops handle at most one wraparound of x, so h must not be
// longer than x.
void MatchedFilterCore_SSE2(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];

    // Initialize values for the accumulation.
    __m128 s_128 = _mm_set1_ps(0);
    __m128 x2_sum_128 = _mm_set1_ps(0);
    float x2_sum = 0.f;
    float s = 0;

    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));

    // Perform the loop in two chunks.
    const int chunk2 = h_size - chunk1;
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_4 = limit >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        const __m128 x_k = _mm_loadu_ps(x_p);
        const __m128 h_k = _mm_loadu_ps(h_p);
        const __m128 xx = _mm_mul_ps(x_k, x_k);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
        const __m128 hx = _mm_mul_ps(h_k, x_k);
        s_128 = _mm_add_ps(s_128, hx);
      }

      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }

      // The second chunk continues from the start of the circular buffer.
      x_p = &x[0];
    }

    // Combine the accumulated vector and scalar values. The lanes are spilled
    // with _mm_storeu_ps rather than read through a reinterpreted float
    // pointer, which relies on compiler-specific aliasing behavior. The order
    // of the additions is unchanged.
    float lanes[4];
    _mm_storeu_ps(lanes, x2_sum_128);
    x2_sum += lanes[0] + lanes[1] + lanes[2] + lanes[3];
    _mm_storeu_ps(lanes, s_128);
    s += lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Compute the matched filter error.
    const float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const __m128 alpha_128 = _mm_set1_ps(alpha);

      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      // A distinct writable pointer avoids shadowing the read-only h_p above.
      float* h_w = &h[0];
      x_p = &x[x_start_index];

      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_w += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          __m128 h_k = _mm_loadu_ps(h_w);
          const __m128 x_k = _mm_loadu_ps(x_p);
          // Compute h = h + alpha * x.
          const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
          h_k = _mm_add_ps(h_k, alpha_x);
          // Store the result.
          _mm_storeu_ps(h_w, h_k);
        }

        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_w, ++x_p) {
          *h_w += alpha * *x_p;
        }

        x_p = &x[0];
      }

      *filters_updated = true;
    }

    // Step the read position back by one, wrapping to the end of x.
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}
#endif
void MatchedFilterCore(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum) {
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
float x2_sum = 0.f;
float s = 0;
size_t x_index = x_start_index;
for (size_t k = 0; k < h.size(); ++k) {
x2_sum += x[x_index] * x[x_index];
s += h[k] * x[x_index];
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
}
// Compute the matched filter error.
float e = y[i] - s;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
// filter = filter + smoothing * (y - filter * x) * x / x * x.
size_t x_index = x_start_index;
for (size_t k = 0; k < h.size(); ++k) {
h[k] += alpha * x[x_index];
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
}
*filters_updated = true;
}
x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
}
}
} // namespace aec3
// Creates num_matched_filters zero-initialized filters, each with
// window_size_sub_blocks * sub_block_size taps. Update() evaluates consecutive
// filters at render offsets that grow by
// alignment_shift_sub_blocks * sub_block_size samples.
//
// The remaining parameters are stored as-is for use in Update():
//  - excitation_limit: used to derive the render-energy threshold for
//    adaptation.
//  - smoothing_fast / smoothing_slow: the NLMS step sizes.
//  - matching_filter_threshold: the error ratio that marks a lag estimate
//    reliable.
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
                             Aec3Optimization optimization,
                             size_t sub_block_size,
                             size_t window_size_sub_blocks,
                             int num_matched_filters,
                             size_t alignment_shift_sub_blocks,
                             float excitation_limit,
                             float smoothing_fast,
                             float smoothing_slow,
                             float matching_filter_threshold)
    : data_dumper_(data_dumper),
      optimization_(optimization),
      sub_block_size_(sub_block_size),
      filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
      filters_(
          num_matched_filters,
          std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
      lag_estimates_(num_matched_filters),
      filters_offsets_(num_matched_filters, 0),
      excitation_limit_(excitation_limit),
      smoothing_fast_(smoothing_fast),
      smoothing_slow_(smoothing_slow),
      matching_filter_threshold_(matching_filter_threshold) {
  RTC_DCHECK(data_dumper);
  RTC_DCHECK_LT(0, window_size_sub_blocks);
  // Checked before the modulo below. A zero sub-block size is then reported
  // as a failed DCHECK instead of an integer division by zero.
  RTC_DCHECK_LT(0, sub_block_size);
  RTC_DCHECK((kBlockSize % sub_block_size) == 0);
  RTC_DCHECK((sub_block_size % 4) == 0);
}
// Defaulted out of line; destroying the members requires no extra logic.
MatchedFilter::~MatchedFilter() = default;
void MatchedFilter::Reset() {
for (auto& f : filters_) {
std::fill(f.begin(), f.end(), 0.f);
}
for (auto& l : lag_estimates_) {
l = MatchedFilter::LagEstimate();
}
}
// Runs one adaptation step of every matched filter on a capture sub-block.
//
// Filter n reads the render buffer circularly, starting at
// (render_buffer.read + n * filter_intra_lag_shift_ + sub_block_size_ - 1)
// modulo the buffer size. The core routine is chosen from optimization_.
//
// After adapting, the filter's lag estimate is taken as the index of its
// largest-magnitude coefficient, shifted by that filter's alignment offset.
// The estimate is flagged reliable when:
//  - the peak is not too close to either end of the filter, and
//  - the accumulated error is below matching_filter_threshold_ times the
//    capture energy.
//
// capture must contain exactly sub_block_size_ samples. At most ten filters
// are supported, because each one is dumped under its own name.
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
                           rtc::ArrayView<const float> capture,
                           bool use_slow_smoothing) {
  RTC_DCHECK_EQ(sub_block_size_, capture.size());
  auto& y = capture;

  const float smoothing =
      use_slow_smoothing ? smoothing_slow_ : smoothing_fast_;

  const float x2_sum_threshold =
      filters_[0].size() * excitation_limit_ * excitation_limit_;

  // Anchor for the matched filter error: the capture energy. It depends only
  // on the capture sub-block, so it is computed once for all filters.
  const float error_sum_anchor =
      std::inner_product(y.begin(), y.end(), y.begin(), 0.f);

  // Dump names for the filter coefficients, one per supported filter.
  static constexpr const char* kFilterDumpNames[] = {
      "aec3_correlator_0_h", "aec3_correlator_1_h", "aec3_correlator_2_h",
      "aec3_correlator_3_h", "aec3_correlator_4_h", "aec3_correlator_5_h",
      "aec3_correlator_6_h", "aec3_correlator_7_h", "aec3_correlator_8_h",
      "aec3_correlator_9_h"};
  constexpr size_t kNumFilterDumpNames =
      sizeof(kFilterDumpNames) / sizeof(kFilterDumpNames[0]);
  RTC_DCHECK_GE(kNumFilterDumpNames, filters_.size());

  // Apply all matched filters.
  size_t alignment_shift = 0;
  for (size_t n = 0; n < filters_.size(); ++n) {
    float error_sum = 0.f;
    bool filters_updated = false;

    const size_t x_start_index =
        (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
        render_buffer.buffer.size();

    switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
      case Aec3Optimization::kSse2:
        aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing,
                                     render_buffer.buffer, y, filters_[n],
                                     &filters_updated, &error_sum);
        break;
      case Aec3Optimization::kAvx2:
        aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing,
                                     render_buffer.buffer, y, filters_[n],
                                     &filters_updated, &error_sum);
        break;
#endif
#if defined(WEBRTC_HAS_NEON)
      case Aec3Optimization::kNeon:
        aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing,
                                     render_buffer.buffer, y, filters_[n],
                                     &filters_updated, &error_sum);
        break;
#endif
      default:
        aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
                                render_buffer.buffer, y, filters_[n],
                                &filters_updated, &error_sum);
        break;
    }

    // Estimate the lag in the matched filter as the distance to the portion in
    // the filter that contributes the most to the matched filter output. This
    // is detected as the peak of the matched filter.
    const size_t lag_estimate = std::distance(
        filters_[n].begin(),
        std::max_element(
            filters_[n].begin(), filters_[n].end(),
            [](float a, float b) -> bool { return a * a < b * b; }));

    // Update the lag estimates for the matched filter. The upper bound is
    // written as an addition so it cannot wrap around in size_t for filters
    // shorter than 10 taps. For longer filters it is equivalent to
    // lag_estimate < size - 10.
    lag_estimates_[n] = LagEstimate(
        error_sum_anchor - error_sum,
        (lag_estimate > 2 && lag_estimate + 10 < filters_[n].size() &&
         error_sum < matching_filter_threshold_ * error_sum_anchor),
        lag_estimate + alignment_shift, filters_updated);

    if (n < kNumFilterDumpNames) {
      data_dumper_->DumpRaw(kFilterDumpNames[n], filters_[n]);
    } else {
      RTC_NOTREACHED();
    }

    alignment_shift += filter_intra_lag_shift_;
  }
}
// Logs, at verbose level, the delay span that each matched filter covers in
// milliseconds. Each span is:
//  - scaled by downsampling_factor,
//  - shifted by `shift` samples,
//  - converted with 16 samples per millisecond (integer division).
// NOTE(review): sample_rate_hz is not read; the conversion always uses 16
// samples per millisecond. Confirm this is intended.
void MatchedFilter::LogFilterProperties(int sample_rate_hz,
                                        size_t shift,
                                        size_t downsampling_factor) const {
  constexpr int kFsBy1000 = 16;
  const int shift_samples = static_cast<int>(shift);

  size_t offset = 0;
  for (size_t idx = 0; idx < filters_.size();
       ++idx, offset += filter_intra_lag_shift_) {
    const int first_sample = static_cast<int>(offset * downsampling_factor);
    const int last_sample = static_cast<int>(
        (offset + filters_[idx].size()) * downsampling_factor);
    const int start_ms = (first_sample - shift_samples) / kFsBy1000;
    const int end_ms = (last_sample - shift_samples) / kFsBy1000;
    RTC_LOG(LS_VERBOSE) << "Filter " << idx << ": start: " << start_ms
                        << " ms, end: " << end_ms << " ms.";
  }
}
} // namespace webrtc