Further SSE2 optimizations for the AEC3 adaptive filter.
This CL adds further SSE2 optimizations for the AEC3
adaptive filter.
The changes are bitexact
BUG=webrtc:6018
Review-Url: https://codereview.webrtc.org/2810133002
Cr-Original-Commit-Position: refs/heads/master@{#17667}
Cr-Mirrored-From: https://chromium.googlesource.com/external/webrtc
Cr-Mirrored-Commit: 69ffdf49382cda8ae4baa7173ab5237cb815292a
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc
index 7f66ce5..3174fa7 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc
@@ -36,10 +36,15 @@
fft.Fft(&h, H);
}
+} // namespace
+
+namespace aec3 {
+
// Computes and stores the frequency response of the filter.
void UpdateFrequencyResponse(
rtc::ArrayView<const FftData> H,
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
+ RTC_DCHECK_EQ(H.size(), H2->size());
for (size_t k = 0; k < H.size(); ++k) {
std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(),
(*H2)[k].begin(),
@@ -47,6 +52,27 @@
}
}
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+// Computes and stores the frequency response of the filter.
+void UpdateFrequencyResponse_SSE2(
+ rtc::ArrayView<const FftData> H,
+ std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) {
+ RTC_DCHECK_EQ(H.size(), H2->size());
+ for (size_t k = 0; k < H.size(); ++k) {
+ for (size_t j = 0; j < kFftLengthBy2; j += 4) {
+ const __m128 re = _mm_loadu_ps(&H[k].re[j]);
+ const __m128 re2 = _mm_mul_ps(re, re);
+ const __m128 im = _mm_loadu_ps(&H[k].im[j]);
+ const __m128 im2 = _mm_mul_ps(im, im);
+ const __m128 H2_k_j = _mm_add_ps(re2, im2);
+ _mm_storeu_ps(&(*H2)[k][j], H2_k_j);
+ }
+ (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] +
+ H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2];
+ }
+}
+#endif
+
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator(
@@ -59,9 +85,24 @@
}
}
-} // namespace
-
-namespace aec3 {
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+// Computes and stores the echo return loss estimate of the filter, which is the
+// sum of the partition frequency responses.
+void UpdateErlEstimator_SSE2(
+ const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
+ std::array<float, kFftLengthBy2Plus1>* erl) {
+ erl->fill(0.f);
+ for (auto& H2_j : H2) {
+ for (size_t k = 0; k < kFftLengthBy2; k += 4) {
+ const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]);
+ __m128 erl_k = _mm_loadu_ps(&(*erl)[k]);
+ erl_k = _mm_add_ps(erl_k, H2_j_k);
+ _mm_storeu_ps(&(*erl)[k], erl_k);
+ }
+ (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2];
+ }
+}
+#endif
// Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
void AdaptPartitions(const RenderBuffer& render_buffer,
@@ -290,8 +331,17 @@
: 0;
// Update the frequency response and echo return loss for the filter.
- UpdateFrequencyResponse(H_, &H2_);
- UpdateErlEstimator(H2_, &erl_);
+ switch (optimization_) {
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+ case Aec3Optimization::kSse2:
+ aec3::UpdateFrequencyResponse_SSE2(H_, &H2_);
+ aec3::UpdateErlEstimator_SSE2(H2_, &erl_);
+ break;
+#endif
+ default:
+ aec3::UpdateFrequencyResponse(H_, &H2_);
+ aec3::UpdateErlEstimator(H2_, &erl_);
+ }
}
} // namespace webrtc
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h
index 4fe10ea..78b6422 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter.h
+++ b/modules/audio_processing/aec3/adaptive_fir_filter.h
@@ -25,6 +25,27 @@
namespace webrtc {
namespace aec3 {
+// Computes and stores the frequency response of the filter.
+void UpdateFrequencyResponse(
+ rtc::ArrayView<const FftData> H,
+ std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+void UpdateFrequencyResponse_SSE2(
+ rtc::ArrayView<const FftData> H,
+ std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
+#endif
+
+// Computes and stores the echo return loss estimate of the filter, which is the
+// sum of the partition frequency responses.
+void UpdateErlEstimator(
+ const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
+ std::array<float, kFftLengthBy2Plus1>* erl);
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+void UpdateErlEstimator_SSE2(
+ const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
+ std::array<float, kFftLengthBy2Plus1>* erl);
+#endif
+
// Adapts the filter partitions.
void AdaptPartitions(const RenderBuffer& render_buffer,
const FftData& G,
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
index 85d9769..6d1a582 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc
@@ -42,9 +42,9 @@
} // namespace
#if defined(WEBRTC_ARCH_X86_FAMILY)
-// Verifies that the optimized methods are bitexact to their reference
-// counterparts.
-TEST(AdaptiveFirFilter, TestOptimizations) {
+// Verifies that the optimized methods for filter adaptation are bitexact to
+// their reference counterparts.
+TEST(AdaptiveFirFilter, FilterAdaptationOptimizations) {
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
RenderBuffer render_buffer(Aec3Optimization::kNone, 3, 12,
@@ -93,6 +93,59 @@
}
}
+// Verifies that the optimized method for frequency response computation is
+// bitexact to the reference counterpart.
+TEST(AdaptiveFirFilter, UpdateFrequencyResponseOptimization) {
+ bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
+ if (use_sse2) {
+ const size_t kNumPartitions = 12;
+ std::vector<FftData> H(kNumPartitions);
+ std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
+ std::vector<std::array<float, kFftLengthBy2Plus1>> H2_SSE2(kNumPartitions);
+
+ for (size_t j = 0; j < H.size(); ++j) {
+ for (size_t k = 0; k < H[j].re.size(); ++k) {
+ H[j].re[k] = k + j / 3.f;
+ H[j].im[k] = j + k / 7.f;
+ }
+ }
+
+ UpdateFrequencyResponse(H, &H2);
+ UpdateFrequencyResponse_SSE2(H, &H2_SSE2);
+
+ for (size_t j = 0; j < H2.size(); ++j) {
+ for (size_t k = 0; k < H[j].re.size(); ++k) {
+ EXPECT_FLOAT_EQ(H2[j][k], H2_SSE2[j][k]);
+ }
+ }
+ }
+}
+
+// Verifies that the optimized method for echo return loss computation is
+// bitexact to the reference counterpart.
+TEST(AdaptiveFirFilter, UpdateErlOptimization) {
+ bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
+ if (use_sse2) {
+ const size_t kNumPartitions = 12;
+ std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
+ std::array<float, kFftLengthBy2Plus1> erl;
+ std::array<float, kFftLengthBy2Plus1> erl_SSE2;
+
+ for (size_t j = 0; j < H2.size(); ++j) {
+ for (size_t k = 0; k < H2[j].size(); ++k) {
+ H2[j][k] = k + j / 3.f;
+ }
+ }
+
+ UpdateErlEstimator(H2, &erl);
+ UpdateErlEstimator_SSE2(H2, &erl_SSE2);
+
+ for (size_t j = 0; j < erl.size(); ++j) {
+ EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]);
+ }
+ }
+}
+
#endif
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)