RNN VAD: GRU layer optimized

Using `VectorMath::DotProduct()` in GatedRecurrentLayer to reuse existing
SIMD optimizations. Results:
- When SSE2/AVX2 is avilable, the GRU layer takes 40% of the unoptimized
  code
- The realtime factor for the VAD improved as follows
  - SSE2: from 570x to 630x
  - AVX2: from 610x to 680x

This CL also improved the GRU layer benchmark by (i) benchmarking a GRU
layer havibng the same size of that used in the VAD and (ii) by prefetching
a long input sequence.

Bug: webrtc:10480
Change-Id: I9716b15661e4c6b81592b4cf7c172d90e41b5223
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195545
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32803}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 29cdfeb..ef2370c 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -86,6 +86,7 @@
   ]
   deps = [
     ":rnn_vad_common",
+    ":vector_math",
     "..:cpu_features",
     "../../../../api:array_view",
     "../../../../api:function_view",
@@ -94,6 +95,9 @@
     "../../../../rtc_base/system:arch",
     "//third_party/rnnoise:rnn_vad",
   ]
+  if (current_cpu == "x86" || current_cpu == "x64") {
+    deps += [ ":vector_math_avx2" ]
+  }
   absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
 }
 
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index c1bded1..f828a24 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -50,6 +50,7 @@
               kHiddenGruBias,
               kHiddenGruWeights,
               kHiddenGruRecurrentWeights,
+              cpu_features,
               /*layer_name=*/"GRU1"),
       output_(kHiddenLayerOutputSize,
               kOutputLayerOutputSize,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
index c586ed2..900ce63 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
@@ -46,12 +46,12 @@
     0.983443f,  0.999991f,  -0.824335f, 0.984742f,  0.990208f,  0.938179f,
     0.875092f,  0.999846f,  0.997707f,  -0.999382f, 0.973153f,  -0.966605f};
 
-class RnnParametrization
+class RnnFcParametrization
     : public ::testing::TestWithParam<AvailableCpuFeatures> {};
 
 // Checks that the output of a fully connected layer is within tolerance given
 // test input data.
-TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) {
+TEST_P(RnnFcParametrization, CheckFullyConnectedLayerOutput) {
   FullyConnectedLayer fc(kInputLayerInputSize, kInputLayerOutputSize,
                          kInputDenseBias, kInputDenseWeights,
                          ActivationFunction::kTansigApproximated,
@@ -61,7 +61,7 @@
   ExpectNearAbsolute(kFullyConnectedExpectedOutput, fc, 1e-5f);
 }
 
-TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
+TEST_P(RnnFcParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
   const AvailableCpuFeatures cpu_features = GetParam();
   FullyConnectedLayer fc(kInputLayerInputSize, kInputLayerOutputSize,
                          kInputDenseBias, kInputDenseWeights,
@@ -87,16 +87,14 @@
   v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
   AvailableCpuFeatures available = GetAvailableCpuFeatures();
   if (available.sse2) {
-    AvailableCpuFeatures features(
-        {/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
-    v.push_back(features);
+    v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
   }
   return v;
 }
 
 INSTANTIATE_TEST_SUITE_P(
     RnnVadTest,
-    RnnParametrization,
+    RnnFcParametrization,
     ::testing::ValuesIn(GetCpuFeaturesToTest()),
     [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
       return info.param.ToString();
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
index f37fc2a..482016e 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
@@ -43,47 +43,79 @@
   return tensor_dst;
 }
 
-void ComputeGruUpdateResetGates(int input_size,
-                                int output_size,
-                                rtc::ArrayView<const float> weights,
-                                rtc::ArrayView<const float> recurrent_weights,
-                                rtc::ArrayView<const float> bias,
-                                rtc::ArrayView<const float> input,
-                                rtc::ArrayView<const float> state,
-                                rtc::ArrayView<float> gate) {
+// Computes the output for the update or the reset gate.
+// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
+// - `g`: output gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `s`: state gate vector
+// - `b`: bias vector
+void ComputeUpdateResetGate(int input_size,
+                            int output_size,
+                            const VectorMath& vector_math,
+                            rtc::ArrayView<const float> input,
+                            rtc::ArrayView<const float> state,
+                            rtc::ArrayView<const float> bias,
+                            rtc::ArrayView<const float> weights,
+                            rtc::ArrayView<const float> recurrent_weights,
+                            rtc::ArrayView<float> gate) {
+  RTC_DCHECK_EQ(input.size(), input_size);
+  RTC_DCHECK_EQ(state.size(), output_size);
+  RTC_DCHECK_EQ(bias.size(), output_size);
+  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+  RTC_DCHECK_GE(gate.size(), output_size);  // `gate` is over-allocated.
   for (int o = 0; o < output_size; ++o) {
-    gate[o] = bias[o];
-    for (int i = 0; i < input_size; ++i) {
-      gate[o] += input[i] * weights[o * input_size + i];
-    }
-    for (int s = 0; s < output_size; ++s) {
-      gate[o] += state[s] * recurrent_weights[o * output_size + s];
-    }
-    gate[o] = ::rnnoise::SigmoidApproximated(gate[o]);
+    float x = bias[o];
+    x += vector_math.DotProduct(input,
+                                weights.subview(o * input_size, input_size));
+    x += vector_math.DotProduct(
+        state, recurrent_weights.subview(o * output_size, output_size));
+    gate[o] = ::rnnoise::SigmoidApproximated(x);
   }
 }
 
-void ComputeGruOutputGate(int input_size,
-                          int output_size,
-                          rtc::ArrayView<const float> weights,
-                          rtc::ArrayView<const float> recurrent_weights,
-                          rtc::ArrayView<const float> bias,
-                          rtc::ArrayView<const float> input,
-                          rtc::ArrayView<const float> state,
-                          rtc::ArrayView<const float> reset,
-                          rtc::ArrayView<float> gate) {
+// Computes the output for the state gate.
+// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
+// - `s'`: output state gate vector
+// - `s`: previous state gate vector
+// - `u`: update gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `r`: reset gate vector
+// - `b`: bias vector
+// - `.*` element-wise product
+void ComputeStateGate(int input_size,
+                      int output_size,
+                      const VectorMath& vector_math,
+                      rtc::ArrayView<const float> input,
+                      rtc::ArrayView<const float> update,
+                      rtc::ArrayView<const float> reset,
+                      rtc::ArrayView<const float> bias,
+                      rtc::ArrayView<const float> weights,
+                      rtc::ArrayView<const float> recurrent_weights,
+                      rtc::ArrayView<float> state) {
+  RTC_DCHECK_EQ(input.size(), input_size);
+  RTC_DCHECK_GE(update.size(), output_size);  // `update` is over-allocated.
+  RTC_DCHECK_GE(reset.size(), output_size);   // `reset` is over-allocated.
+  RTC_DCHECK_EQ(bias.size(), output_size);
+  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+  RTC_DCHECK_EQ(state.size(), output_size);
+  std::array<float, kGruLayerMaxUnits> reset_x_state;
   for (int o = 0; o < output_size; ++o) {
-    gate[o] = bias[o];
-    for (int i = 0; i < input_size; ++i) {
-      gate[o] += input[i] * weights[o * input_size + i];
-    }
-    for (int s = 0; s < output_size; ++s) {
-      gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
-    }
-    // Rectified linear unit.
-    if (gate[o] < 0.f) {
-      gate[o] = 0.f;
-    }
+    reset_x_state[o] = state[o] * reset[o];
+  }
+  for (int o = 0; o < output_size; ++o) {
+    float x = bias[o];
+    x += vector_math.DotProduct(input,
+                                weights.subview(o * input_size, input_size));
+    x += vector_math.DotProduct(
+        {reset_x_state.data(), static_cast<size_t>(output_size)},
+        recurrent_weights.subview(o * output_size, output_size));
+    state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
   }
 }
 
@@ -95,12 +127,14 @@
     const rtc::ArrayView<const int8_t> bias,
     const rtc::ArrayView<const int8_t> weights,
     const rtc::ArrayView<const int8_t> recurrent_weights,
+    const AvailableCpuFeatures& cpu_features,
     absl::string_view layer_name)
     : input_size_(input_size),
       output_size_(output_size),
       bias_(PreprocessGruTensor(bias, output_size)),
       weights_(PreprocessGruTensor(weights, output_size)),
-      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)) {
+      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
+      vector_math_(cpu_features) {
   RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
       << "Insufficient GRU layer over-allocation (" << layer_name << ").";
   RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
@@ -126,44 +160,38 @@
 void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
   RTC_DCHECK_EQ(input.size(), input_size_);
 
-  // TODO(bugs.chromium.org/10480): Add AVX2.
-  // TODO(bugs.chromium.org/10480): Add Neon.
-
-  // Stride and offset used to read parameter arrays.
-  const int stride_in = input_size_ * output_size_;
-  const int stride_out = output_size_ * output_size_;
-
+  // The tensors below are organized as a sequence of flattened tensors for the
+  // `update`, `reset` and `state` gates.
   rtc::ArrayView<const float> bias(bias_);
   rtc::ArrayView<const float> weights(weights_);
   rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
+  // Strides to access to the flattened tensors for a specific gate.
+  const int stride_weights = input_size_ * output_size_;
+  const int stride_recurrent_weights = output_size_ * output_size_;
+
+  rtc::ArrayView<float> state(state_.data(), output_size_);
 
   // Update gate.
   std::array<float, kGruLayerMaxUnits> update;
-  ComputeGruUpdateResetGates(
-      input_size_, output_size_, weights.subview(0, stride_in),
-      recurrent_weights.subview(0, stride_out), bias.subview(0, output_size_),
-      input, state_, update);
-
+  ComputeUpdateResetGate(
+      input_size_, output_size_, vector_math_, input, state,
+      bias.subview(0, output_size_), weights.subview(0, stride_weights),
+      recurrent_weights.subview(0, stride_recurrent_weights), update);
   // Reset gate.
   std::array<float, kGruLayerMaxUnits> reset;
-  ComputeGruUpdateResetGates(
-      input_size_, output_size_, weights.subview(stride_in, stride_in),
-      recurrent_weights.subview(stride_out, stride_out),
-      bias.subview(output_size_, output_size_), input, state_, reset);
-
-  // Output gate.
-  std::array<float, kGruLayerMaxUnits> output;
-  ComputeGruOutputGate(input_size_, output_size_,
-                       weights.subview(2 * stride_in, stride_in),
-                       recurrent_weights.subview(2 * stride_out, stride_out),
-                       bias.subview(2 * output_size_, output_size_), input,
-                       state_, reset, output);
-
-  // Update output through the update gates and update the state.
-  for (int o = 0; o < output_size_; ++o) {
-    output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
-    state_[o] = output[o];
-  }
+  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
+                         bias.subview(output_size_, output_size_),
+                         weights.subview(stride_weights, stride_weights),
+                         recurrent_weights.subview(stride_recurrent_weights,
+                                                   stride_recurrent_weights),
+                         reset);
+  // State gate.
+  ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
+                   reset, bias.subview(2 * output_size_, output_size_),
+                   weights.subview(2 * stride_weights, stride_weights),
+                   recurrent_weights.subview(2 * stride_recurrent_weights,
+                                             stride_recurrent_weights),
+                   state);
 }
 
 }  // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
index f66b048..3407dfcd 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
@@ -17,6 +17,7 @@
 #include "absl/strings/string_view.h"
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -34,6 +35,7 @@
                       rtc::ArrayView<const int8_t> bias,
                       rtc::ArrayView<const int8_t> weights,
                       rtc::ArrayView<const int8_t> recurrent_weights,
+                      const AvailableCpuFeatures& cpu_features,
                       absl::string_view layer_name);
   GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
   GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
@@ -57,6 +59,7 @@
   const std::vector<float> bias_;
   const std::vector<float> weights_;
   const std::vector<float> recurrent_weights_;
+  const VectorMath vector_math_;
   // Over-allocated array with size equal to `output_size_`.
   std::array<float, kGruLayerMaxUnits> state_;
 };
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
index 4e8b524..ee8bdac 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
@@ -11,6 +11,8 @@
 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
 
 #include <array>
+#include <memory>
+#include <vector>
 
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -18,6 +20,7 @@
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
 #include "test/gtest.h"
+#include "third_party/rnnoise/src/rnn_vad_weights.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -101,24 +104,44 @@
     0.00781069f, 0.75267816f, 0.f,         0.02579715f,
     0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
 
+class RnnGruParametrization
+    : public ::testing::TestWithParam<AvailableCpuFeatures> {};
+
 // Checks that the output of a GRU layer is within tolerance given test input
 // data.
-TEST(RnnVadTest, CheckGatedRecurrentLayer) {
+TEST_P(RnnGruParametrization, CheckGatedRecurrentLayer) {
   GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights, /*layer_name=*/"GRU");
+                          kGruRecurrentWeights,
+                          /*cpu_features=*/GetParam(),
+                          /*layer_name=*/"GRU");
   TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
 }
 
-TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
-  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights, /*layer_name=*/"GRU");
+TEST_P(RnnGruParametrization, DISABLED_BenchmarkGatedRecurrentLayer) {
+  // Prefetch test data.
+  std::unique_ptr<FileReader> reader = CreateGruInputReader();
+  std::vector<float> gru_input_sequence(reader->size());
+  reader->ReadChunk(gru_input_sequence);
 
-  rtc::ArrayView<const float> input_sequence(kGruInputSequence);
-  static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
-  constexpr int input_sequence_length =
-      kGruInputSequence.size() / kGruInputSize;
+  using ::rnnoise::kHiddenGruBias;
+  using ::rnnoise::kHiddenGruRecurrentWeights;
+  using ::rnnoise::kHiddenGruWeights;
+  using ::rnnoise::kHiddenLayerOutputSize;
+  using ::rnnoise::kInputLayerOutputSize;
 
-  constexpr int kNumTests = 10000;
+  GatedRecurrentLayer gru(kInputLayerOutputSize, kHiddenLayerOutputSize,
+                          kHiddenGruBias, kHiddenGruWeights,
+                          kHiddenGruRecurrentWeights,
+                          /*cpu_features=*/GetParam(),
+                          /*layer_name=*/"GRU");
+
+  rtc::ArrayView<const float> input_sequence(gru_input_sequence);
+  ASSERT_EQ(input_sequence.size() % kInputLayerOutputSize,
+            static_cast<size_t>(0));
+  const int input_sequence_length =
+      input_sequence.size() / kInputLayerOutputSize;
+
+  constexpr int kNumTests = 100;
   ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
   for (int k = 0; k < kNumTests; ++k) {
     perf_timer.StartTimer();
@@ -133,6 +156,28 @@
                    << " ms";
 }
 
+// Finds the relevant CPU features combinations to test.
+std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
+  std::vector<AvailableCpuFeatures> v;
+  AvailableCpuFeatures available = GetAvailableCpuFeatures();
+  v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
+  if (available.avx2) {
+    v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
+  }
+  if (available.sse2) {
+    v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
+  }
+  return v;
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RnnVadTest,
+    RnnGruParametrization,
+    ::testing::ValuesIn(GetCpuFeaturesToTest()),
+    [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
+      return info.param.ToString();
+    });
+
 }  // namespace
 }  // namespace rnn_vad
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 3db6774..b8ca9c3 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -111,6 +111,12 @@
   return {kChunkSize, num_chunks, std::move(reader)};
 }
 
+std::unique_ptr<FileReader> CreateGruInputReader() {
+  return std::make_unique<FloatFileReader<float>>(
+      /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
+                                      "dat"));
+}
+
 std::unique_ptr<FileReader> CreateVadProbsReader() {
   return std::make_unique<FloatFileReader<float>>(
       /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index 86af5e0..e366e18 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -77,6 +77,9 @@
 // Creates a reader for the LP residual and pitch information test data.
 ChunksFileReader CreateLpResidualAndPitchInfoReader();
 
+// Creates a reader for the sequence of GRU input vectors.
+std::unique_ptr<FileReader> CreateGruInputReader();
+
 // Creates a reader for the VAD probabilities test data.
 std::unique_ptr<FileReader> CreateVadProbsReader();
 
diff --git a/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1
new file mode 100644
index 0000000..f78c40e
--- /dev/null
+++ b/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1
@@ -0,0 +1 @@
+402abf7a4e5d35abb78906fff2b3f4d8d24aa629
\ No newline at end of file