RNN VAD: `VectorMath::DotProduct` with AVX2 optimization

This CL adds a new library for the RNN VAD that provides (optimized)
vector math ops. It follows the same scheme as the `VectorMath` class of
AEC3 to ensure correct builds across different platforms.

Bug: webrtc:10480
Change-Id: I96bcfbf930ca27388ab5f2d52c022ddb73acf8e6
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/194326
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32741}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index a4285ba..fafea42 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -78,6 +78,35 @@
   ]
 }
 
+rtc_source_set("vector_math") {  # On x86, also depend on ":vector_math_avx2".
+  sources = [ "vector_math.h" ]
+  deps = [
+    "..:cpu_features",
+    "../../../../api:array_view",
+    "../../../../rtc_base:checks",
+    "../../../../rtc_base/system:arch",
+  ]
+}
+
+if (current_cpu == "x86" || current_cpu == "x64") {
+  rtc_library("vector_math_avx2") {  # Defines `VectorMath::DotProductAvx2()`.
+    sources = [ "vector_math_avx2.cc" ]
+    if (is_win) {
+      cflags = [ "/arch:AVX2" ]
+    } else {
+      cflags = [
+        "-mavx2",
+        "-mfma",  # Needed by the `_mm256_fmadd_ps` intrinsic.
+      ]
+    }
+    deps = [
+      ":vector_math",
+      "../../../../api:array_view",
+      "../../../../rtc_base:checks",
+    ]
+  }
+}
+
 rtc_library("rnn_vad_pitch") {
   sources = [
     "pitch_search.cc",
@@ -88,6 +117,7 @@
   deps = [
     ":rnn_vad_auto_correlation",
     ":rnn_vad_common",
+    ":vector_math",
     "..:cpu_features",
     "../../../../api:array_view",
     "../../../../rtc_base:checks",
@@ -95,6 +125,9 @@
     "../../../../rtc_base:safe_compare",
     "../../../../rtc_base:safe_conversions",
   ]
+  if (current_cpu == "x86" || current_cpu == "x64") {
+    deps += [ ":vector_math_avx2" ]
+  }
 }
 
 rtc_source_set("rnn_vad_ring_buffer") {
@@ -191,6 +224,7 @@
       "spectral_features_internal_unittest.cc",
       "spectral_features_unittest.cc",
       "symmetric_matrix_buffer_unittest.cc",
+      "vector_math_unittest.cc",
     ]
     deps = [
       ":rnn_vad",
@@ -203,6 +237,7 @@
       ":rnn_vad_spectral_features",
       ":rnn_vad_symmetric_matrix_buffer",
       ":test_utils",
+      ":vector_math",
       "..:cpu_features",
       "../..:audioproc_test_utils",
       "../../../../api:array_view",
@@ -216,6 +251,9 @@
       "../../utility:pffft_wrapper",
       "//third_party/rnnoise:rnn_vad",
     ]
+    if (current_cpu == "x86" || current_cpu == "x64") {
+      deps += [ ":vector_math_avx2" ]
+    }
     absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
     data = unittest_resources
     if (is_ios) {
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math.h b/modules/audio_processing/agc2/rnn_vad/vector_math.h
new file mode 100644
index 0000000..a989682
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math.h
@@ -0,0 +1,55 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
+
+#include <numeric>
+
+#include "api/array_view.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/system/arch.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// Provides vector math operations and selects, at run time, an optimized
+// implementation according to the CPU features given at construction.
+class VectorMath {
+ public:
+  explicit VectorMath(AvailableCpuFeatures cpu_features)
+      : cpu_features_(cpu_features) {}
+
+  // Returns the dot product of `x` and `y`, which must have the same size.
+  float DotProduct(rtc::ArrayView<const float> x,
+                   rtc::ArrayView<const float> y) const {
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+    if (cpu_features_.avx2) {
+      return DotProductAvx2(x, y);
+    }
+    // TODO(bugs.webrtc.org/10480): Add SSE2 alternative implementation.
+#endif
+    // TODO(bugs.webrtc.org/10480): Add NEON alternative implementation.
+    RTC_DCHECK_EQ(x.size(), y.size());  // Unoptimized fallback from here on.
+    return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
+  }
+
+ private:
+  float DotProductAvx2(rtc::ArrayView<const float> x,
+                       rtc::ArrayView<const float> y) const;
+
+  const AvailableCpuFeatures cpu_features_;  // Set once at construction.
+};
+
+}  // namespace rnn_vad
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
new file mode 100644
index 0000000..3b2c4ad
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
@@ -0,0 +1,53 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
+
+#include <immintrin.h>
+
+#include "api/array_view.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// AVX2 + FMA implementation of the dot product: 8 items are multiplied and
+// accumulated at a time; the remaining items are handled without SIMD.
+float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
+                                 rtc::ArrayView<const float> y) const {
+  RTC_DCHECK(cpu_features_.avx2);
+  RTC_DCHECK_EQ(x.size(), y.size());
+  constexpr int kBlockSizeLog2 = 3;
+  constexpr int kBlockSize = 1 << kBlockSizeLog2;
+  const int num_vectorized = (x.size() >> kBlockSizeLog2) << kBlockSizeLog2;
+  __m256 sums = _mm256_setzero_ps();
+  for (int k = 0; k < num_vectorized; k += kBlockSize) {
+    RTC_DCHECK_LE(k + kBlockSize, x.size());
+    sums = _mm256_fmadd_ps(_mm256_loadu_ps(&x[k]), _mm256_loadu_ps(&y[k]),
+                           sums);
+  }
+  // Horizontal sum of `sums`: fold the 8 lanes into 4, then into 2 and
+  // finally into the lowest lane.
+  const __m128 upper = _mm256_extractf128_ps(sums, 1);
+  const __m128 lower = _mm256_extractf128_ps(sums, 0);
+  const __m128 sum4 = _mm_add_ps(upper, lower);
+  const __m128 sum4_hi = _mm_movehl_ps(upper, sum4);
+  const __m128 sum2 = _mm_add_ps(sum4_hi, sum4);
+  const __m128 sum2_hi = _mm_shuffle_ps(sum2, sum2, 1);
+  float dot_product = _mm_cvtss_f32(_mm_add_ss(sum2_hi, sum2));
+  // Items left out of the last complete block, if any.
+  for (int k = num_vectorized; static_cast<size_t>(k) < x.size(); ++k) {
+    dot_product += x[k] * y[k];
+  }
+  return dot_product;
+}
+
+}  // namespace rnn_vad
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
new file mode 100644
index 0000000..19a8af0
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_unittest.cc
@@ -0,0 +1,67 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
+
+#include <vector>
+
+#include "modules/audio_processing/agc2/cpu_features.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+constexpr int kSizeOfX = 19;  // Not a multiple of 8, the AVX2 block size.
+constexpr float kX[kSizeOfX] = {
+    0.31593041f, 0.9350786f,   -0.25252445f, -0.86956251f, -0.9673632f,
+    0.54571901f, -0.72504495f, -0.79509912f, -0.25525012f, -0.73340473f,
+    0.15747377f, -0.04370565f, 0.76135145f,  -0.57239645f, 0.68616848f,
+    0.3740298f,  0.34710799f,  -0.92207423f, 0.10738454f};
+constexpr int kSizeOfXSubSpan = 16;  // A multiple of the AVX2 block size.
+static_assert(kSizeOfXSubSpan < kSizeOfX, "");
+constexpr float kEnergyOfX = 7.315563958160327f;  // Dot product (kX, kX).
+constexpr float kEnergyOfXSubspan = 6.333327669592963f;  // Same, sub-span.
+
+// Test fixture parametrized by the CPU features passed to `VectorMath`.
+class VectorMathParametrization
+    : public ::testing::TestWithParam<AvailableCpuFeatures> {};
+
+TEST_P(VectorMathParametrization, TestDotProduct) {
+  const VectorMath vector_math(/*cpu_features=*/GetParam());
+  EXPECT_FLOAT_EQ(vector_math.DotProduct(kX, kX), kEnergyOfX);
+  const rtc::ArrayView<const float> x_sub(kX, kSizeOfXSubSpan);
+  EXPECT_FLOAT_EQ(vector_math.DotProduct(x_sub, x_sub), kEnergyOfXSubspan);
+}
+
+// Returns the CPU features combinations to run the tests with.
+// - Always: every optimization disabled.
+// - Only if `GetAvailableCpuFeatures()` reports AVX2: AVX2 alone (SSE2 and
+//   NEON stay disabled).
+std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
+  std::vector<AvailableCpuFeatures> combinations;
+  combinations.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
+  if (GetAvailableCpuFeatures().avx2) {
+    combinations.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
+  }
+  return combinations;
+}
+
+// One instantiation per CPU features combination, named via `ToString()`.
+INSTANTIATE_TEST_SUITE_P(
+    RnnVadTest, VectorMathParametrization,
+    ::testing::ValuesIn(GetCpuFeaturesToTest()),
+    [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
+      return info.param.ToString();
+    });
+
+}  // namespace
+}  // namespace rnn_vad
+}  // namespace webrtc