RNN VAD: GRU layer isolated into rnn_gru.h/.cc

Refactoring done to make it easier and cleaner to add SIMD optimizations and
to remove `GatedRecurrentLayer` from the RNN VAD API.

Bug: webrtc:10480
Change-Id: Ie1dffdd9b19c57c03a0b634f6818c0780456a66c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195445
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Jakob Ivarsson <jakobi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32770}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index c57971a..29cdfeb 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -32,11 +32,9 @@
     "..:biquad_filter",
     "..:cpu_features",
     "../../../../api:array_view",
-    "../../../../api:function_view",
     "../../../../rtc_base:checks",
     "../../../../rtc_base:safe_compare",
     "../../../../rtc_base:safe_conversions",
-    "../../../../rtc_base/system:arch",
     "//third_party/rnnoise:rnn_vad",
   ]
 }
@@ -83,6 +81,8 @@
   sources = [
     "rnn_fc.cc",
     "rnn_fc.h",
+    "rnn_gru.cc",
+    "rnn_gru.h",
   ]
   deps = [
     ":rnn_vad_common",
@@ -241,6 +241,7 @@
       "pitch_search_unittest.cc",
       "ring_buffer_unittest.cc",
       "rnn_fc_unittest.cc",
+      "rnn_gru_unittest.cc",
       "rnn_unittest.cc",
       "rnn_vad_unittest.cc",
       "sequence_buffer_unittest.cc",
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index 9d6d28f..c1bded1 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -10,208 +10,33 @@
 
 #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
 
-// Defines WEBRTC_ARCH_X86_FAMILY, used below.
-#include "rtc_base/system/arch.h"
-
-#if defined(WEBRTC_HAS_NEON)
-#include <arm_neon.h>
-#endif
-#if defined(WEBRTC_ARCH_X86_FAMILY)
-#include <emmintrin.h>
-#endif
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <numeric>
-
 #include "rtc_base/checks.h"
-#include "rtc_base/numerics/safe_conversions.h"
-#include "third_party/rnnoise/src/rnn_activations.h"
 #include "third_party/rnnoise/src/rnn_vad_weights.h"
 
 namespace webrtc {
 namespace rnn_vad {
 namespace {
 
-using rnnoise::kWeightsScale;
-
-using rnnoise::kInputLayerInputSize;
+using ::rnnoise::kInputLayerInputSize;
 static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
-using rnnoise::kInputDenseBias;
-using rnnoise::kInputDenseWeights;
-using rnnoise::kInputLayerOutputSize;
+using ::rnnoise::kInputDenseBias;
+using ::rnnoise::kInputDenseWeights;
+using ::rnnoise::kInputLayerOutputSize;
 static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
 
-using rnnoise::kHiddenGruBias;
-using rnnoise::kHiddenGruRecurrentWeights;
-using rnnoise::kHiddenGruWeights;
-using rnnoise::kHiddenLayerOutputSize;
+using ::rnnoise::kHiddenGruBias;
+using ::rnnoise::kHiddenGruRecurrentWeights;
+using ::rnnoise::kHiddenGruWeights;
+using ::rnnoise::kHiddenLayerOutputSize;
 static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
 
-using rnnoise::kOutputDenseBias;
-using rnnoise::kOutputDenseWeights;
-using rnnoise::kOutputLayerOutputSize;
+using ::rnnoise::kOutputDenseBias;
+using ::rnnoise::kOutputDenseWeights;
+using ::rnnoise::kOutputLayerOutputSize;
 static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
 
-using rnnoise::SigmoidApproximated;
-using rnnoise::TansigApproximated;
-
-inline float RectifiedLinearUnit(float x) {
-  return x < 0.f ? 0.f : x;
-}
-
-constexpr int kNumGruGates = 3;  // Update, reset, output.
-
-// TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this
-// function to improve setup time.
-// Casts and scales |tensor_src| for a GRU layer and re-arranges the layout.
-// It works both for weights, recurrent weights and bias.
-std::vector<float> GetPreprocessedGruTensor(
-    rtc::ArrayView<const int8_t> tensor_src,
-    int output_size) {
-  // Transpose, cast and scale.
-  // |n| is the size of the first dimension of the 3-dim tensor |weights|.
-  const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
-                                     output_size * kNumGruGates);
-  const int stride_src = kNumGruGates * output_size;
-  const int stride_dst = n * output_size;
-  std::vector<float> tensor_dst(tensor_src.size());
-  for (int g = 0; g < kNumGruGates; ++g) {
-    for (int o = 0; o < output_size; ++o) {
-      for (int i = 0; i < n; ++i) {
-        tensor_dst[g * stride_dst + o * n + i] =
-            rnnoise::kWeightsScale *
-            static_cast<float>(
-                tensor_src[i * stride_src + g * output_size + o]);
-      }
-    }
-  }
-  return tensor_dst;
-}
-
-void ComputeGruUpdateResetGates(int input_size,
-                                int output_size,
-                                rtc::ArrayView<const float> weights,
-                                rtc::ArrayView<const float> recurrent_weights,
-                                rtc::ArrayView<const float> bias,
-                                rtc::ArrayView<const float> input,
-                                rtc::ArrayView<const float> state,
-                                rtc::ArrayView<float> gate) {
-  for (int o = 0; o < output_size; ++o) {
-    gate[o] = bias[o];
-    for (int i = 0; i < input_size; ++i) {
-      gate[o] += input[i] * weights[o * input_size + i];
-    }
-    for (int s = 0; s < output_size; ++s) {
-      gate[o] += state[s] * recurrent_weights[o * output_size + s];
-    }
-    gate[o] = SigmoidApproximated(gate[o]);
-  }
-}
-
-void ComputeGruOutputGate(int input_size,
-                          int output_size,
-                          rtc::ArrayView<const float> weights,
-                          rtc::ArrayView<const float> recurrent_weights,
-                          rtc::ArrayView<const float> bias,
-                          rtc::ArrayView<const float> input,
-                          rtc::ArrayView<const float> state,
-                          rtc::ArrayView<const float> reset,
-                          rtc::ArrayView<float> gate) {
-  for (int o = 0; o < output_size; ++o) {
-    gate[o] = bias[o];
-    for (int i = 0; i < input_size; ++i) {
-      gate[o] += input[i] * weights[o * input_size + i];
-    }
-    for (int s = 0; s < output_size; ++s) {
-      gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
-    }
-    gate[o] = RectifiedLinearUnit(gate[o]);
-  }
-}
-
-// Gated recurrent unit (GRU) layer un-optimized implementation.
-void ComputeGruLayerOutput(int input_size,
-                           int output_size,
-                           rtc::ArrayView<const float> input,
-                           rtc::ArrayView<const float> weights,
-                           rtc::ArrayView<const float> recurrent_weights,
-                           rtc::ArrayView<const float> bias,
-                           rtc::ArrayView<float> state) {
-  RTC_DCHECK_EQ(input_size, input.size());
-  // Stride and offset used to read parameter arrays.
-  const int stride_in = input_size * output_size;
-  const int stride_out = output_size * output_size;
-
-  // Update gate.
-  std::array<float, kGruLayerMaxUnits> update;
-  ComputeGruUpdateResetGates(
-      input_size, output_size, weights.subview(0, stride_in),
-      recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
-      input, state, update);
-
-  // Reset gate.
-  std::array<float, kGruLayerMaxUnits> reset;
-  ComputeGruUpdateResetGates(
-      input_size, output_size, weights.subview(stride_in, stride_in),
-      recurrent_weights.subview(stride_out, stride_out),
-      bias.subview(output_size, output_size), input, state, reset);
-
-  // Output gate.
-  std::array<float, kGruLayerMaxUnits> output;
-  ComputeGruOutputGate(
-      input_size, output_size, weights.subview(2 * stride_in, stride_in),
-      recurrent_weights.subview(2 * stride_out, stride_out),
-      bias.subview(2 * output_size, output_size), input, state, reset, output);
-
-  // Update output through the update gates and update the state.
-  for (int o = 0; o < output_size; ++o) {
-    output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
-    state[o] = output[o];
-  }
-}
-
 }  // namespace
 
-GatedRecurrentLayer::GatedRecurrentLayer(
-    const int input_size,
-    const int output_size,
-    const rtc::ArrayView<const int8_t> bias,
-    const rtc::ArrayView<const int8_t> weights,
-    const rtc::ArrayView<const int8_t> recurrent_weights)
-    : input_size_(input_size),
-      output_size_(output_size),
-      bias_(GetPreprocessedGruTensor(bias, output_size)),
-      weights_(GetPreprocessedGruTensor(weights, output_size)),
-      recurrent_weights_(
-          GetPreprocessedGruTensor(recurrent_weights, output_size)) {
-  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
-      << "Static over-allocation of recurrent layers state vectors is not "
-         "sufficient.";
-  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
-      << "Mismatching output size and bias terms array size.";
-  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
-      << "Mismatching input-output size and weight coefficients array size.";
-  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
-                recurrent_weights_.size())
-      << "Mismatching input-output size and recurrent weight coefficients array"
-         " size.";
-  Reset();
-}
-
-GatedRecurrentLayer::~GatedRecurrentLayer() = default;
-
-void GatedRecurrentLayer::Reset() {
-  state_.fill(0.f);
-}
-
-void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
-  // TODO(bugs.chromium.org/10480): Add AVX2.
-  // TODO(bugs.chromium.org/10480): Add Neon.
-  ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
-                        recurrent_weights_, bias_, state_);
-}
-
 RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
     : input_(kInputLayerInputSize,
              kInputLayerOutputSize,
@@ -224,7 +49,8 @@
               kHiddenLayerOutputSize,
               kHiddenGruBias,
               kHiddenGruWeights,
-              kHiddenGruRecurrentWeights),
+              kHiddenGruRecurrentWeights,
+              /*layer_name=*/"GRU1"),
       output_(kHiddenLayerOutputSize,
               kOutputLayerOutputSize,
               kOutputDenseBias,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index df99c3c..3148f1b 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -18,56 +18,14 @@
 #include <vector>
 
 #include "api/array_view.h"
-#include "api/function_view.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
 #include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
-#include "rtc_base/system/arch.h"
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
 
 namespace webrtc {
 namespace rnn_vad {
 
-// Maximum number of units for a GRU layer.
-constexpr int kGruLayerMaxUnits = 24;
-
-// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
-// activation functions for the update/reset and output gates respectively. It
-// owns the output buffer.
-class GatedRecurrentLayer {
- public:
-  // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
-  GatedRecurrentLayer(int input_size,
-                      int output_size,
-                      rtc::ArrayView<const int8_t> bias,
-                      rtc::ArrayView<const int8_t> weights,
-                      rtc::ArrayView<const int8_t> recurrent_weights);
-  GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
-  GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
-  ~GatedRecurrentLayer();
-
-  // Returns the size of the input vector.
-  int input_size() const { return input_size_; }
-  // Returns the pointer to the first element of the output buffer.
-  const float* data() const { return state_.data(); }
-  // Returns the size of the output buffer.
-  int size() const { return output_size_; }
-
-  // Resets the GRU state.
-  void Reset();
-  // Computes the recurrent layer output and updates the status.
-  void ComputeOutput(rtc::ArrayView<const float> input);
-
- private:
-  const int input_size_;
-  const int output_size_;
-  const std::vector<float> bias_;
-  const std::vector<float> weights_;
-  const std::vector<float> recurrent_weights_;
-  // The state vector of a recurrent layer has length equal to |output_size_|.
-  // However, to avoid dynamic allocation, over-allocation is used.
-  std::array<float, kGruLayerMaxUnits> state_;
-};
-
 // Recurrent network with hard-coded architecture and weights for voice activity
 // detection.
 class RnnVad {
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
new file mode 100644
index 0000000..f37fc2a
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
@@ -0,0 +1,170 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
+
+#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
+#include "third_party/rnnoise/src/rnn_activations.h"
+#include "third_party/rnnoise/src/rnn_vad_weights.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+constexpr int kNumGruGates = 3;  // Update, reset, output.
+
+std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
+                                       int output_size) {
+  // Transpose, cast to float and scale by `::rnnoise::kWeightsScale`.
+  // `n` is the size of the first dimension of the 3-dim tensor `tensor_src`.
+  const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
+                                     output_size * kNumGruGates);
+  const int stride_src = kNumGruGates * output_size;
+  const int stride_dst = n * output_size;
+  std::vector<float> tensor_dst(tensor_src.size());
+  for (int g = 0; g < kNumGruGates; ++g) {
+    for (int o = 0; o < output_size; ++o) {
+      for (int i = 0; i < n; ++i) {
+        tensor_dst[g * stride_dst + o * n + i] =
+            ::rnnoise::kWeightsScale *
+            static_cast<float>(
+                tensor_src[i * stride_src + g * output_size + o]);
+      }
+    }
+  }
+  return tensor_dst;
+}
+
+void ComputeGruUpdateResetGates(int input_size,
+                                int output_size,
+                                rtc::ArrayView<const float> weights,
+                                rtc::ArrayView<const float> recurrent_weights,
+                                rtc::ArrayView<const float> bias,
+                                rtc::ArrayView<const float> input,
+                                rtc::ArrayView<const float> state,
+                                rtc::ArrayView<float> gate) {
+  for (int o = 0; o < output_size; ++o) {
+    gate[o] = bias[o];
+    for (int i = 0; i < input_size; ++i) {
+      gate[o] += input[i] * weights[o * input_size + i];
+    }
+    for (int s = 0; s < output_size; ++s) {
+      gate[o] += state[s] * recurrent_weights[o * output_size + s];
+    }
+    gate[o] = ::rnnoise::SigmoidApproximated(gate[o]);
+  }
+}
+
+void ComputeGruOutputGate(int input_size,
+                          int output_size,
+                          rtc::ArrayView<const float> weights,
+                          rtc::ArrayView<const float> recurrent_weights,
+                          rtc::ArrayView<const float> bias,
+                          rtc::ArrayView<const float> input,
+                          rtc::ArrayView<const float> state,
+                          rtc::ArrayView<const float> reset,
+                          rtc::ArrayView<float> gate) {
+  for (int o = 0; o < output_size; ++o) {
+    gate[o] = bias[o];
+    for (int i = 0; i < input_size; ++i) {
+      gate[o] += input[i] * weights[o * input_size + i];
+    }
+    for (int s = 0; s < output_size; ++s) {
+      gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
+    }
+    // Rectified linear unit.
+    if (gate[o] < 0.f) {
+      gate[o] = 0.f;
+    }
+  }
+}
+
+}  // namespace
+
+GatedRecurrentLayer::GatedRecurrentLayer(
+    const int input_size,
+    const int output_size,
+    const rtc::ArrayView<const int8_t> bias,
+    const rtc::ArrayView<const int8_t> weights,
+    const rtc::ArrayView<const int8_t> recurrent_weights,
+    absl::string_view layer_name)
+    : input_size_(input_size),
+      output_size_(output_size),
+      bias_(PreprocessGruTensor(bias, output_size)),
+      weights_(PreprocessGruTensor(weights, output_size)),
+      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)) {
+  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
+      << "Insufficient GRU layer over-allocation (" << layer_name << ").";
+  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
+      << "Mismatching output size and bias terms array size (" << layer_name
+      << ").";
+  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
+      << "Mismatching input-output size and weight coefficients array size ("
+      << layer_name << ").";
+  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
+                recurrent_weights_.size())
+      << "Mismatching input-output size and recurrent weight coefficients array"
+         " size ("
+      << layer_name << ").";
+  Reset();
+}
+
+GatedRecurrentLayer::~GatedRecurrentLayer() = default;
+
+void GatedRecurrentLayer::Reset() {
+  state_.fill(0.f);
+}
+
+void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
+  RTC_DCHECK_EQ(input.size(), input_size_);
+
+  // TODO(bugs.chromium.org/10480): Add AVX2.
+  // TODO(bugs.chromium.org/10480): Add Neon.
+
+  // Per-gate strides used to slice the parameter arrays.
+  const int stride_in = input_size_ * output_size_;
+  const int stride_out = output_size_ * output_size_;
+
+  rtc::ArrayView<const float> bias(bias_);
+  rtc::ArrayView<const float> weights(weights_);
+  rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
+
+  // Update gate.
+  std::array<float, kGruLayerMaxUnits> update;
+  ComputeGruUpdateResetGates(
+      input_size_, output_size_, weights.subview(0, stride_in),
+      recurrent_weights.subview(0, stride_out), bias.subview(0, output_size_),
+      input, state_, update);
+
+  // Reset gate.
+  std::array<float, kGruLayerMaxUnits> reset;
+  ComputeGruUpdateResetGates(
+      input_size_, output_size_, weights.subview(stride_in, stride_in),
+      recurrent_weights.subview(stride_out, stride_out),
+      bias.subview(output_size_, output_size_), input, state_, reset);
+
+  // Output gate.
+  std::array<float, kGruLayerMaxUnits> output;
+  ComputeGruOutputGate(input_size_, output_size_,
+                       weights.subview(2 * stride_in, stride_in),
+                       recurrent_weights.subview(2 * stride_out, stride_out),
+                       bias.subview(2 * output_size_, output_size_), input,
+                       state_, reset, output);
+
+  // Blend state and output gate via the update gate, then update the state.
+  for (int o = 0; o < output_size_; ++o) {
+    output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
+    state_[o] = output[o];
+  }
+}
+
+}  // namespace rnn_vad
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
new file mode 100644
index 0000000..f66b048
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
@@ -0,0 +1,67 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
+
+#include <array>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// Maximum number of units for a GRU layer.
+constexpr int kGruLayerMaxUnits = 24;
+
+// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
+// activation functions for the update/reset and output gates respectively.
+class GatedRecurrentLayer {
+ public:
+  // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
+  GatedRecurrentLayer(int input_size,
+                      int output_size,
+                      rtc::ArrayView<const int8_t> bias,
+                      rtc::ArrayView<const int8_t> weights,
+                      rtc::ArrayView<const int8_t> recurrent_weights,
+                      absl::string_view layer_name);
+  GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
+  GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
+  ~GatedRecurrentLayer();
+
+  // Returns the size of the input vector.
+  int input_size() const { return input_size_; }
+  // Returns the pointer to the first element of the output buffer (GRU state).
+  const float* data() const { return state_.data(); }
+  // Returns the size of the output buffer.
+  int size() const { return output_size_; }
+
+  // Resets the GRU state.
+  void Reset();
+  // Computes the recurrent layer output and updates the state.
+  void ComputeOutput(rtc::ArrayView<const float> input);
+
+ private:
+  const int input_size_;
+  const int output_size_;
+  const std::vector<float> bias_;
+  const std::vector<float> weights_;
+  const std::vector<float> recurrent_weights_;
+  // Over-allocated array; the first `output_size_` elements hold the state.
+  std::array<float, kGruLayerMaxUnits> state_;
+};
+
+}  // namespace rnn_vad
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
new file mode 100644
index 0000000..54e1cf5
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
@@ -0,0 +1,140 @@
+/*
+ *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
+
+#include <array>
+
+#include "api/array_view.h"
+#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
+#include "modules/audio_processing/test/performance_timer.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace test {
+namespace {
+
+void TestGatedRecurrentLayer(
+    GatedRecurrentLayer& gru,
+    rtc::ArrayView<const float> input_sequence,
+    rtc::ArrayView<const float> expected_output_sequence) {
+  const int input_sequence_length = rtc::CheckedDivExact(
+      rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
+  const int output_sequence_length = rtc::CheckedDivExact(
+      rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
+  ASSERT_EQ(input_sequence_length, output_sequence_length)
+      << "The test data length is invalid.";
+  // Feed the GRU layer and check the output at every step.
+  gru.Reset();
+  for (int i = 0; i < input_sequence_length; ++i) {
+    SCOPED_TRACE(i);
+    gru.ComputeOutput(
+        input_sequence.subview(i * gru.input_size(), gru.input_size()));
+    const auto expected_output =
+        expected_output_sequence.subview(i * gru.size(), gru.size());
+    ExpectNearAbsolute(expected_output, gru, 3e-6f);
+  }
+}
+
+// Gated recurrent units layer test data.
+constexpr int kGruInputSize = 5;
+constexpr int kGruOutputSize = 4;
+constexpr std::array<int8_t, 12> kGruBias = {96,   -99, -81, -114, 49,  119,
+                                             -118, 68,  -76, 91,   121, 125};
+constexpr std::array<int8_t, 60> kGruWeights = {
+    // Input 0.
+    124, 9, 1, 116,        // Update.
+    -66, -21, -118, -110,  // Reset.
+    104, 75, -23, -51,     // Output.
+    // Input 1.
+    -72, -111, 47, 93,   // Update.
+    77, -98, 41, -8,     // Reset.
+    40, -23, -43, -107,  // Output.
+    // Input 2.
+    9, -73, 30, -32,      // Update.
+    -2, 64, -26, 91,      // Reset.
+    -48, -24, -28, -104,  // Output.
+    // Input 3.
+    74, -46, 116, 15,    // Update.
+    32, 52, -126, -38,   // Reset.
+    -121, 12, -16, 110,  // Output.
+    // Input 4.
+    -95, 66, -103, -35,  // Update.
+    -38, 3, -126, -61,   // Reset.
+    28, 98, -117, -43    // Output.
+};
+constexpr std::array<int8_t, 48> kGruRecurrentWeights = {
+    // Output 0.
+    -3, 87, 50, 51,     // Update.
+    -22, 27, -39, 62,   // Reset.
+    31, -83, -52, -48,  // Output.
+    // Output 1.
+    -6, 83, -19, 104,  // Update.
+    105, 48, 23, 68,   // Reset.
+    23, 40, 7, -120,   // Output.
+    // Output 2.
+    64, -62, 117, 85,     // Update.
+    51, -43, 54, -105,    // Reset.
+    120, 56, -128, -107,  // Output.
+    // Output 3.
+    39, 50, -17, -47,   // Update.
+    -117, 14, 108, 12,  // Reset.
+    -7, -72, 103, -87,  // Output.
+};
+constexpr std::array<float, 20> kGruInputSequence = {
+    0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
+    0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
+    0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
+    0.24517593f, 0.47657707f, 0.57064998f, 0.435184f,   0.19319285f};
+constexpr std::array<float, 16> kGruExpectedOutputSequence = {
+    0.0239123f,  0.5773077f,  0.f,         0.f,
+    0.01282811f, 0.64330572f, 0.f,         0.04863098f,
+    0.00781069f, 0.75267816f, 0.f,         0.02579715f,
+    0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
+
+// Checks that the output of a GRU layer is within tolerance given test input
+// data.
+TEST(RnnVadTest, CheckGatedRecurrentLayer) {
+  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
+                          kGruRecurrentWeights, /*layer_name=*/"GRU");
+  TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
+}
+
+TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
+  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
+                          kGruRecurrentWeights, /*layer_name=*/"GRU");
+
+  rtc::ArrayView<const float> input_sequence(kGruInputSequence);
+  static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
+  constexpr int input_sequence_length =
+      kGruInputSequence.size() / kGruInputSize;
+
+  constexpr int kNumTests = 10000;
+  ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
+  for (int k = 0; k < kNumTests; ++k) {
+    perf_timer.StartTimer();
+    for (int i = 0; i < input_sequence_length; ++i) {
+      gru.ComputeOutput(
+          input_sequence.subview(i * gru.input_size(), gru.input_size()));
+    }
+    perf_timer.StopTimer();
+  }
+  RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- "
+                   << (perf_timer.GetDurationStandardDeviation() / 1000)
+                   << " ms";
+}
+
+}  // namespace
+}  // namespace test
+}  // namespace rnn_vad
+}  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 4f42d11..1c314d1 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -10,18 +10,10 @@
 
 #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
 
-#include <array>
-#include <memory>
-#include <vector>
-
+#include "api/array_view.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
-#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
-#include "modules/audio_processing/test/performance_timer.h"
-#include "rtc_base/checks.h"
-#include "rtc_base/logging.h"
-#include "rtc_base/numerics/safe_conversions.h"
+#include "modules/audio_processing/agc2/rnn_vad/common.h"
 #include "test/gtest.h"
-#include "third_party/rnnoise/src/rnn_vad_weights.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -43,116 +35,6 @@
   }
 }
 
-void TestGatedRecurrentLayer(
-    GatedRecurrentLayer& gru,
-    rtc::ArrayView<const float> input_sequence,
-    rtc::ArrayView<const float> expected_output_sequence) {
-  const int input_sequence_length = rtc::CheckedDivExact(
-      rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
-  const int output_sequence_length = rtc::CheckedDivExact(
-      rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
-  ASSERT_EQ(input_sequence_length, output_sequence_length)
-      << "The test data length is invalid.";
-  // Feed the GRU layer and check the output at every step.
-  gru.Reset();
-  for (int i = 0; i < input_sequence_length; ++i) {
-    SCOPED_TRACE(i);
-    gru.ComputeOutput(
-        input_sequence.subview(i * gru.input_size(), gru.input_size()));
-    const auto expected_output =
-        expected_output_sequence.subview(i * gru.size(), gru.size());
-    ExpectNearAbsolute(expected_output, gru, 3e-6f);
-  }
-}
-
-// Gated recurrent units layer test data.
-constexpr int kGruInputSize = 5;
-constexpr int kGruOutputSize = 4;
-constexpr std::array<int8_t, 12> kGruBias = {96,   -99, -81, -114, 49,  119,
-                                             -118, 68,  -76, 91,   121, 125};
-constexpr std::array<int8_t, 60> kGruWeights = {
-    // Input 0.
-    124, 9, 1, 116,        // Update.
-    -66, -21, -118, -110,  // Reset.
-    104, 75, -23, -51,     // Output.
-    // Input 1.
-    -72, -111, 47, 93,   // Update.
-    77, -98, 41, -8,     // Reset.
-    40, -23, -43, -107,  // Output.
-    // Input 2.
-    9, -73, 30, -32,      // Update.
-    -2, 64, -26, 91,      // Reset.
-    -48, -24, -28, -104,  // Output.
-    // Input 3.
-    74, -46, 116, 15,    // Update.
-    32, 52, -126, -38,   // Reset.
-    -121, 12, -16, 110,  // Output.
-    // Input 4.
-    -95, 66, -103, -35,  // Update.
-    -38, 3, -126, -61,   // Reset.
-    28, 98, -117, -43    // Output.
-};
-constexpr std::array<int8_t, 48> kGruRecurrentWeights = {
-    // Output 0.
-    -3, 87, 50, 51,     // Update.
-    -22, 27, -39, 62,   // Reset.
-    31, -83, -52, -48,  // Output.
-    // Output 1.
-    -6, 83, -19, 104,  // Update.
-    105, 48, 23, 68,   // Reset.
-    23, 40, 7, -120,   // Output.
-    // Output 2.
-    64, -62, 117, 85,     // Update.
-    51, -43, 54, -105,    // Reset.
-    120, 56, -128, -107,  // Output.
-    // Output 3.
-    39, 50, -17, -47,   // Update.
-    -117, 14, 108, 12,  // Reset.
-    -7, -72, 103, -87,  // Output.
-};
-constexpr std::array<float, 20> kGruInputSequence = {
-    0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
-    0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
-    0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
-    0.24517593f, 0.47657707f, 0.57064998f, 0.435184f,   0.19319285f};
-constexpr std::array<float, 16> kGruExpectedOutputSequence = {
-    0.0239123f,  0.5773077f,  0.f,         0.f,
-    0.01282811f, 0.64330572f, 0.f,         0.04863098f,
-    0.00781069f, 0.75267816f, 0.f,         0.02579715f,
-    0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
-
-// Checks that the output of a GRU layer is within tolerance given test input
-// data.
-TEST(RnnVadTest, CheckGatedRecurrentLayer) {
-  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights);
-  TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
-}
-
-TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
-  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
-                          kGruRecurrentWeights);
-
-  rtc::ArrayView<const float> input_sequence(kGruInputSequence);
-  static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
-  constexpr int input_sequence_length =
-      kGruInputSequence.size() / kGruInputSize;
-
-  constexpr int kNumTests = 10000;
-  ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
-  for (int k = 0; k < kNumTests; ++k) {
-    perf_timer.StartTimer();
-    for (int i = 0; i < input_sequence_length; ++i) {
-      gru.ComputeOutput(
-          input_sequence.subview(i * gru.input_size(), gru.input_size()));
-    }
-    perf_timer.StopTimer();
-  }
-  RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- "
-                   << (perf_timer.GetDurationStandardDeviation() / 1000)
-                   << " ms";
-}
-
 // Checks that the speech probability is zero with silence.
 TEST(RnnVadTest, CheckZeroProbabilityWithSilence) {
   RnnVad rnn_vad(GetAvailableCpuFeatures());