RNN VAD: FC layer with SSE2 impl

This CL adds the SSE2 optimized implementation for fully connected
(FC) layers. The change includes a re-arrangement (transposition) of
the weights, done once at construction time. It is required so that
the weights of each output unit are contiguous and can be loaded four
at a time into 128-bit registers.

This CL also includes unit test adaptations and a benchmark test
(disabled by default).

Bug: webrtc:10480
Change-Id: I5ed87f0a629faaaf4c8bffbce1cea5557518f8c8
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/141862
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29712}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 852abd8..f4613b1 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -44,6 +44,7 @@
   deps = [
     "..:biquad_filter",
     "../../../../api:array_view",
+    "../../../../api:function_view",
     "../../../../rtc_base:checks",
     "../../../../rtc_base:rtc_base_approved",
     "../../../../rtc_base/system:arch",
@@ -65,6 +66,8 @@
       "../../../../api:array_view",
       "../../../../api:scoped_refptr",
       "../../../../rtc_base:checks",
+      "../../../../rtc_base/system:arch",
+      "../../../../system_wrappers:cpu_features_api",
       "../../../../test:fileutils",
       "../../../../test:test_support",
     ]
@@ -113,8 +116,10 @@
       "../../../../common_audio/",
       "../../../../rtc_base:checks",
       "../../../../rtc_base:logging",
+      "../../../../rtc_base/system:arch",
       "../../../../test:test_support",
       "../../utility:pffft_wrapper",
+      "//third_party/abseil-cpp/absl/memory",
       "//third_party/rnnoise:rnn_vad",
     ]
     data = unittest_resources
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index e6ef2f3..a5f7b4b 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -22,6 +22,7 @@
 #include <algorithm>
 #include <array>
 #include <cmath>
+#include <numeric>
 
 #include "rtc_base/checks.h"
 #include "third_party/rnnoise/src/rnn_activations.h"
@@ -29,6 +30,7 @@
 
 namespace webrtc {
 namespace rnn_vad {
+namespace {
 
 using rnnoise::kWeightsScale;
 
@@ -56,8 +58,6 @@
 using rnnoise::SigmoidApproximated;
 using rnnoise::TansigApproximated;
 
-namespace {
-
 inline float RectifiedLinearUnit(float x) {
   return x < 0.f ? 0.f : x;
 }
@@ -71,6 +71,83 @@
   return scaled_params;
 }
 
+// Casts and scales |weights| and makes each output's weights contiguous.
+std::vector<float> GetPreprocessedWeights(rtc::ArrayView<const int8_t> weights,
+                                          const size_t output_size) {
+  if (output_size == 1) {  // Transposing a single column changes nothing.
+    return GetScaledParams(weights);
+  }
+  // Transpose so that each output's weights are contiguous; scale and cast.
+  const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
+  std::vector<float> w(weights.size());
+  for (size_t o = 0; o < output_size; ++o) {
+    for (size_t i = 0; i < input_size; ++i) {
+      w[o * input_size + i] = rnnoise::kWeightsScale *
+                              static_cast<float>(weights[i * output_size + o]);
+    }
+  }
+  return w;
+}
+
+// Fully connected layer, plain C++ implementation (no SIMD intrinsics).
+void ComputeFullyConnectedLayerOutput(
+    size_t input_size,
+    size_t output_size,
+    rtc::ArrayView<const float> input,
+    rtc::ArrayView<const float> bias,
+    rtc::ArrayView<const float> weights,
+    rtc::FunctionView<float(float)> activation_function,
+    rtc::ArrayView<float> output) {  // |output| size is not checked below.
+  RTC_DCHECK_EQ(input.size(), input_size);
+  RTC_DCHECK_EQ(bias.size(), output_size);
+  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+  for (size_t o = 0; o < output_size; ++o) {
+    output[o] = bias[o];
+    // TODO(bugs.chromium.org/9076): Benchmark how different layouts for
+    // |weights| change the performance across different platforms.
+    for (size_t i = 0; i < input_size; ++i) {
+      output[o] += input[i] * weights[o * input_size + i];
+    }
+    output[o] = activation_function(output[o]);
+  }
+}
+
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+// Fully connected layer SSE2 implementation (four floats per vector op).
+void ComputeFullyConnectedLayerOutputSse2(
+    size_t input_size,
+    size_t output_size,
+    rtc::ArrayView<const float> input,
+    rtc::ArrayView<const float> bias,
+    rtc::ArrayView<const float> weights,
+    rtc::FunctionView<float(float)> activation_function,
+    rtc::ArrayView<float> output) {
+  RTC_DCHECK_EQ(input.size(), input_size);
+  RTC_DCHECK_EQ(bias.size(), output_size);
+  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+  const size_t input_size_by_4 = input_size >> 2;  // Full groups of 4 items.
+  const size_t offset = input_size & ~3;  // Index of the first leftover item.
+  __m128 sum_wx_128;
+  const float* v = reinterpret_cast<const float*>(&sum_wx_128);
+  for (size_t o = 0; o < output_size; ++o) {
+    // Accumulate four products at a time in the lanes of |sum_wx_128|.
+    sum_wx_128 = _mm_set1_ps(0);
+    const float* x_p = input.data();
+    const float* w_p = weights.data() + o * input_size;
+    for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
+      sum_wx_128 = _mm_add_ps(sum_wx_128,
+                              _mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
+    }
+    // Handle the leftover items with scalar code, starting from the bias plus
+    // the four lanes of |sum_wx_128| (read via |v|); then apply the activation.
+    output[o] = activation_function(
+        std::inner_product(input.begin() + offset, input.end(),
+                           weights.begin() + o * input_size + offset,
+                           bias[o] + v[0] + v[1] + v[2] + v[3]));
+  }
+}
+#endif  // WEBRTC_ARCH_X86_FAMILY
+
 }  // namespace
 
 FullyConnectedLayer::FullyConnectedLayer(
@@ -78,12 +155,12 @@
     const size_t output_size,
     const rtc::ArrayView<const int8_t> bias,
     const rtc::ArrayView<const int8_t> weights,
-    float (*const activation_function)(float),
+    rtc::FunctionView<float(float)> activation_function,
     Optimization optimization)
     : input_size_(input_size),
       output_size_(output_size),
       bias_(GetScaledParams(bias)),
-      weights_(GetScaledParams(weights)),
+      weights_(GetPreprocessedWeights(weights, output_size)),
       activation_function_(activation_function),
       optimization_(optimization) {
   RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
@@ -105,31 +182,21 @@
   switch (optimization_) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
     case Optimization::kSse2:
-      // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
-      ComputeOutput_NONE(input);
+      ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
+                                           bias_, weights_,
+                                           activation_function_, output_);
       break;
 #endif
 #if defined(WEBRTC_HAS_NEON)
     case Optimization::kNeon:
       // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
-      ComputeOutput_NONE(input);
+      ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
+                                       weights_, activation_function_, output_);
       break;
 #endif
     default:
-      ComputeOutput_NONE(input);
-  }
-}
-
-void FullyConnectedLayer::ComputeOutput_NONE(
-    rtc::ArrayView<const float> input) {
-  for (size_t o = 0; o < output_size_; ++o) {
-    output_[o] = bias_[o];
-    // TODO(bugs.chromium.org/9076): Benchmark how different layouts for
-    // |weights_| change the performance across different platforms.
-    for (size_t i = 0; i < input_size_; ++i) {
-      output_[o] += input[i] * weights_[i * output_size_ + o];
-    }
-    output_[o] = (*activation_function_)(output_[o]);
+      ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
+                                       weights_, activation_function_, output_);
   }
 }
 
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index f53a093..29ee207 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -18,7 +18,9 @@
 #include <vector>
 
 #include "api/array_view.h"
+#include "api/function_view.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
+#include "rtc_base/system/arch.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -42,30 +44,28 @@
                       size_t output_size,
                       rtc::ArrayView<const int8_t> bias,
                       rtc::ArrayView<const int8_t> weights,
-                      float (*const activation_function)(float),
+                      rtc::FunctionView<float(float)> activation_function,
                       Optimization optimization);
   FullyConnectedLayer(const FullyConnectedLayer&) = delete;
   FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
   ~FullyConnectedLayer();
   size_t input_size() const { return input_size_; }
   size_t output_size() const { return output_size_; }
+  Optimization optimization() const { return optimization_; }
   rtc::ArrayView<const float> GetOutput() const;
   // Computes the fully-connected layer output.
   void ComputeOutput(rtc::ArrayView<const float> input);
 
  private:
-  // No SIMD optimizations.
-  void ComputeOutput_NONE(rtc::ArrayView<const float> input);
-
   const size_t input_size_;
   const size_t output_size_;
   const std::vector<float> bias_;
   const std::vector<float> weights_;
-  float (*const activation_function_)(float);
-  const Optimization optimization_;
+  rtc::FunctionView<float(float)> activation_function_;
   // The output vector of a recurrent layer has length equal to |output_size_|.
   // However, for efficiency, over-allocation is used.
   std::array<float, kFullyConnectedLayersMaxUnits> output_;
+  const Optimization optimization_;
 };
 
 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
@@ -83,6 +83,7 @@
   ~GatedRecurrentLayer();
   size_t input_size() const { return input_size_; }
   size_t output_size() const { return output_size_; }
+  Optimization optimization() const { return optimization_; }
   rtc::ArrayView<const float> GetOutput() const;
   void Reset();
   // Computes the recurrent layer output and updates the status.
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 97ede18..7497416 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -11,10 +11,14 @@
 #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
 
 #include <array>
+#include <memory>
+#include <vector>
 
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
+#include "modules/audio_processing/test/performance_timer.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
+#include "rtc_base/system/arch.h"
 #include "test/gtest.h"
 #include "third_party/rnnoise/src/rnn_activations.h"
 #include "third_party/rnnoise/src/rnn_vad_weights.h"
@@ -23,18 +27,14 @@
 namespace rnn_vad {
 namespace test {
 
-using rnnoise::RectifiedLinearUnit;
-using rnnoise::SigmoidApproximated;
-
 namespace {
 
 void TestFullyConnectedLayer(FullyConnectedLayer* fc,
                              rtc::ArrayView<const float> input_vector,
-                             const float expected_output) {
+                             rtc::ArrayView<const float> expected_output) {
   RTC_CHECK(fc);
   fc->ComputeOutput(input_vector);
-  const auto output = fc->GetOutput();
-  EXPECT_NEAR(expected_output, output[0], 3e-6f);
+  ExpectNearAbsolute(expected_output, fc->GetOutput(), 1e-5f);
 }
 
 void TestGatedRecurrentLayer(
@@ -62,32 +62,19 @@
 }
 
 // Fully connected layer test data.
-constexpr size_t kFullyConnectedInputSize = 24;
-constexpr size_t kFullyConnectedOutputSize = 1;
-constexpr std::array<int8_t, 1> kFullyConnectedBias = {-50};
-constexpr std::array<int8_t, 24> kFullyConnectedWeights = {
-    127,  127,  127, 127,  127,  20,  127,  -126, -126, -54, 14,  125,
-    -126, -126, 127, -125, -126, 127, -127, -127, -57,  -30, 127, 80};
-constexpr std::array<float, 24 * 3> kFullyConnectedInputVectors = {
-    // Input 1.
-    0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.215833917f, 0.290601075f, 0.238759011f,
-    0.244751841f, 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f,
-    0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, 0.0327451192f, 0.f,
-    0.136149868f, 0.446351469f,
-    // Input 2.
-    0.592162728f, 0.529089332f, 1.18205106f, 1.21736848f, 0.f, 0.470851123f,
-    0.130675942f, 0.320903003f, 0.305496395f, 0.0571633279f, 1.57001138f,
-    0.0182026215f, 0.0977443159f, 0.347477973f, 0.493206412f, 0.9688586f,
-    0.0320267938f, 0.244722098f, 0.312745273f, 0.f, 0.00650715502f,
-    0.312553257f, 1.62619662f, 0.782880902f,
-    // Input 3.
-    0.395022154f, 0.333681047f, 0.76302278f, 0.965480626f, 0.f, 0.941198349f,
-    0.0892967582f, 0.745046318f, 0.635769248f, 0.238564298f, 0.970656633f,
-    0.014159563f, 0.094203949f, 0.446816623f, 0.640755892f, 1.20532358f,
-    0.0254284926f, 0.283327013f, 0.726210058f, 0.0550272502f, 0.000344108557f,
-    0.369803518f, 1.56680179f, 0.997883797f};
-constexpr std::array<float, 3> kFullyConnectedExpectedOutputs = {
-    0.436567038f, 0.874741316f, 0.672785878f};
+constexpr std::array<float, 42> kFullyConnectedInputVector = {
+    -1.00131f,   -0.627069f, -7.81097f,  7.86285f,    -2.87145f,  3.32365f,
+    -0.653161f,  0.529839f,  -0.425307f, 0.25583f,    0.235094f,  0.230527f,
+    -0.144687f,  0.182785f,  0.57102f,   0.125039f,   0.479482f,  -0.0255439f,
+    -0.0073141f, -0.147346f, -0.217106f, -0.0846906f, -8.34943f,  3.09065f,
+    1.42628f,    -0.85235f,  -0.220207f, -0.811163f,  2.09032f,   -2.01425f,
+    -0.690268f,  -0.925327f, -0.541354f, 0.58455f,    -0.606726f, -0.0372358f,
+    0.565991f,   0.435854f,  0.420812f,  0.162198f,   -2.13f,     10.0089f};
+constexpr std::array<float, 24> kFullyConnectedExpectedOutput = {
+    -0.623293f, -0.988299f, 0.999378f,  0.967168f,  0.103087f,  -0.978545f,
+    -0.856347f, 0.346675f,  1.f,        -0.717442f, -0.544176f, 0.960363f,
+    0.983443f,  0.999991f,  -0.824335f, 0.984742f,  0.990208f,  0.938179f,
+    0.875092f,  0.999846f,  0.997707f,  -0.999382f, 0.973153f,  -0.966605f};
 
 // Gated recurrent units layer test data.
 constexpr size_t kGruInputSize = 5;
@@ -117,47 +104,94 @@
     0.00781069f, 0.75267816f, 0.f,         0.02579715f,
     0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
 
-}  // namespace
+std::string GetOptimizationName(Optimization optimization) {  // Log label.
+  switch (optimization) {
+    case Optimization::kSse2:
+      return "SSE2";
+    case Optimization::kNeon:
+      return "NEON";
+    case Optimization::kNone:
+      return "none";
+  }  // NOTE(review): nothing is returned if no case matches.
+}
 
-class OptimizationTest : public ::testing::Test,
-                         public ::testing::WithParamInterface<Optimization> {};
+}  // namespace
 
 // Checks that the output of a fully connected layer is within tolerance given
 // test input data.
-TEST_P(OptimizationTest, CheckFullyConnectedLayerOutput) {
-  const Optimization optimization = GetParam();
-  RTC_LOG(LS_VERBOSE) << optimization;
-  FullyConnectedLayer fc(kFullyConnectedInputSize, kFullyConnectedOutputSize,
-                         kFullyConnectedBias, kFullyConnectedWeights,
-                         SigmoidApproximated, optimization);
-  // Test on different inputs.
-  static_assert(
-      kFullyConnectedInputVectors.size() % kFullyConnectedInputSize == 0, "");
-  constexpr size_t kNumInputVectors =
-      kFullyConnectedInputVectors.size() / kFullyConnectedInputSize;
-  static_assert(kFullyConnectedExpectedOutputs.size() == kNumInputVectors, "");
-  for (size_t i = 0; i < kNumInputVectors; ++i) {
-    rtc::ArrayView<const float> input(
-        kFullyConnectedInputVectors.data() + kFullyConnectedInputSize * i,
-        kFullyConnectedInputSize);
-    TestFullyConnectedLayer(&fc, input, kFullyConnectedExpectedOutputs[i]);
-  }
+TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
+  FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
+                         rnnoise::kInputLayerOutputSize,
+                         rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
+                         rnnoise::TansigApproximated, Optimization::kNone);
+  TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
+                          kFullyConnectedExpectedOutput);
 }
 
 // Checks that the output of a GRU layer is within tolerance given test input
 // data.
-TEST_P(OptimizationTest, CheckGatedRecurrentLayer) {
-  const Optimization optimization = GetParam();
-  RTC_LOG(LS_VERBOSE) << optimization;
+TEST(RnnVadTest, CheckGatedRecurrentLayer) {
   GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights, optimization);
+                          kGruRecurrentWeights, Optimization::kNone);
   TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
 }
 
-INSTANTIATE_TEST_SUITE_P(RnnVadTest,
-                         OptimizationTest,
-                         ::testing::Values(Optimization::kNone,
-                                           DetectOptimization()));
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+
+// Like CheckFullyConnectedLayerOutput, but testing the SSE2 implementation.
+TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) {
+  if (!IsOptimizationAvailable(Optimization::kSse2)) {
+    return;  // Passes without checking anything when SSE2 is unavailable.
+  }
+
+  FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
+                         rnnoise::kInputLayerOutputSize,
+                         rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
+                         rnnoise::TansigApproximated, Optimization::kSse2);
+  TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
+                          kFullyConnectedExpectedOutput);
+}
+
+#endif  // WEBRTC_ARCH_X86_FAMILY
+
+TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
+  std::vector<std::unique_ptr<FullyConnectedLayer>> implementations;
+  implementations.emplace_back(std::make_unique<FullyConnectedLayer>(
+      rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
+      rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
+      rnnoise::TansigApproximated, Optimization::kNone));
+  if (IsOptimizationAvailable(Optimization::kSse2)) {  // Also time SSE2.
+    implementations.emplace_back(std::make_unique<FullyConnectedLayer>(
+        rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
+        rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
+        rnnoise::TansigApproximated, Optimization::kSse2));
+  }
+
+  struct Result {
+    Optimization optimization;
+    double average_us;
+    double std_dev_us;
+  };
+  std::vector<Result> results;  // One entry per benchmarked implementation.
+
+  constexpr size_t number_of_tests = 10000;  // Timed calls per implementation.
+  for (auto& fc : implementations) {
+    ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
+    for (size_t k = 0; k < number_of_tests; ++k) {
+      perf_timer.StartTimer();
+      fc->ComputeOutput(kFullyConnectedInputVector);
+      perf_timer.StopTimer();
+    }
+    results.push_back({fc->optimization(), perf_timer.GetDurationAverage(),
+                       perf_timer.GetDurationStandardDeviation()});
+  }
+
+  for (const auto& result : results) {
+    RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": "
+                     << (result.average_us / 1e3) << " +/- "
+                     << (result.std_dev_us / 1e3) << " ms";  // us -> ms.
+  }
+}
 
 }  // namespace test
 }  // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 6e0eb5b..1a8e1a2 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -13,6 +13,8 @@
 #include <memory>
 
 #include "rtc_base/checks.h"
+#include "rtc_base/system/arch.h"
+#include "system_wrappers/include/cpu_features_wrapper.h"
 #include "test/gtest.h"
 #include "test/testsupport/file_utils.h"
 
@@ -103,6 +105,25 @@
           kNumPitchBufAutoCorrCoeffs};
 }
 
+bool IsOptimizationAvailable(Optimization optimization) {
+  switch (optimization) {
+    case Optimization::kSse2:
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+      return WebRtc_GetCPUInfo(kSSE2) != 0;  // Run-time query.
+#else
+      return false;
+#endif
+    case Optimization::kNeon:
+#if defined(WEBRTC_HAS_NEON)
+      return true;  // Decided at build time only.
+#else
+      return false;
+#endif
+    case Optimization::kNone:
+      return true;
+  }  // NOTE(review): nothing is returned if no case matches.
+}
+
 }  // namespace test
 }  // namespace rnn_vad
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index fbb270f..db155e6 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -151,6 +151,9 @@
   std::array<float, kPitchTestDataSize> test_data_;
 };
 
+// Returns true if |optimization| can be used (kNone is always available).
+bool IsOptimizationAvailable(Optimization optimization);
+
 }  // namespace test
 }  // namespace rnn_vad
 }  // namespace webrtc