RNN VAD optimizations: `VectorMath::DotProduct()` NEON arm64
Results: RNN VAD realtime factor improved from 140x to 195x (+55x)
Test device: Pixel 2 XL
Benchmark setup: max clock speed forced on all the cores by
setting "performance" as scaling governor
Bug: webrtc:10480
Change-Id: I3e92f643853ad1fe990db909c578ce78ee826c03
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/198842
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32888}
diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index c6667df..7b71f6a 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -162,6 +162,13 @@
"vad_with_level.cc",
"vad_with_level.h",
]
+
+ defines = []
+ if (rtc_build_with_neon && current_cpu != "arm64") {
+ suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
+ cflags = [ "-mfpu=neon" ]
+ }
+
deps = [
":common",
":cpu_features",
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 9895b76..4732efd 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -17,6 +17,7 @@
"rnn.h",
]
+ defines = []
if (rtc_build_with_neon && current_cpu != "arm64") {
suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
cflags = [ "-mfpu=neon" ]
@@ -84,6 +85,13 @@
"rnn_gru.cc",
"rnn_gru.h",
]
+
+ defines = []
+ if (rtc_build_with_neon && current_cpu != "arm64") {
+ suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
+ cflags = [ "-mfpu=neon" ]
+ }
+
deps = [
":rnn_vad_common",
":vector_math",
@@ -138,6 +146,13 @@
"pitch_search_internal.cc",
"pitch_search_internal.h",
]
+
+ defines = []
+ if (rtc_build_with_neon && current_cpu != "arm64") {
+ suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
+ cflags = [ "-mfpu=neon" ]
+ }
+
deps = [
":rnn_vad_auto_correlation",
":rnn_vad_common",
@@ -253,6 +268,13 @@
"symmetric_matrix_buffer_unittest.cc",
"vector_math_unittest.cc",
]
+
+ defines = []
+ if (rtc_build_with_neon && current_cpu != "arm64") {
+ suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
+ cflags = [ "-mfpu=neon" ]
+ }
+
deps = [
":rnn_vad",
":rnn_vad_auto_correlation",
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
index f223d58..989b235 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
@@ -166,6 +166,9 @@
if (available.sse2) {
v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
}
+ if (available.neon) {
+ v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/true});
+ }
return v;
}
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math.h b/modules/audio_processing/agc2/rnn_vad/vector_math.h
index 0600b90..47f6811 100644
--- a/modules/audio_processing/agc2/rnn_vad/vector_math.h
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math.h
@@ -14,6 +14,9 @@
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"
+#if defined(WEBRTC_HAS_NEON)
+#include <arm_neon.h>
+#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
@@ -70,8 +73,31 @@
}
return dot_product;
}
+#elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64)
+ if (cpu_features_.neon) {
+ float32x4_t accumulator = vdupq_n_f32(0.f);
+ constexpr int kBlockSizeLog2 = 2;
+ constexpr int kBlockSize = 1 << kBlockSizeLog2;
+ const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
+ << kBlockSizeLog2;
+ for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
+ RTC_DCHECK_LE(i + kBlockSize, x.size());
+ const float32x4_t x_i = vld1q_f32(&x[i]);
+ const float32x4_t y_i = vld1q_f32(&y[i]);
+ accumulator = vfmaq_f32(accumulator, x_i, y_i);
+ }
+ // Reduce `accumulator` by addition.
+ const float32x2_t tmp =
+ vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator));
+ float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0);
+ // Add the result for the last block if incomplete.
+ for (int i = incomplete_block_index;
+ i < rtc::dchecked_cast<int>(x.size()); ++i) {
+ dot_product += x[i] * y[i];
+ }
+ return dot_product;
+ }
#endif
- // TODO(bugs.webrtc.org/10480): Add NEON alternative implementation.
return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
}
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
index 9a2d5bc..45fd65d 100644
--- a/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
@@ -52,6 +52,9 @@
if (available.sse2) {
v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
}
+ if (available.neon) {
+ v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/true});
+ }
return v;
}