RNN VAD: make FC and GRU layers implicitly convertible to ArrayView

Plus a few minor code readability improvements.

Bug: webrtc:10480
Change-Id: I590d8e203b1d05959a8c15373841e37abe83237e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195334
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32764}
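
The layers now expose data() and size() instead of GetOutput(); rtc::ArrayView
has a templated constructor that accepts any object with those two accessors,
so a layer object can be passed directly wherever an
rtc::ArrayView<const float> is expected. Below is a minimal sketch of that
mechanism, assuming only the container constructor of rtc::ArrayView;
ToyLayer, Sum and Example are hypothetical names used for illustration and are
not part of this change.

#include <array>

#include "api/array_view.h"

// Hypothetical layer-like class, used only to illustrate the conversion.
class ToyLayer {
 public:
  // Exposing data() and size() is all rtc::ArrayView needs.
  const float* data() const { return output_.data(); }
  int size() const { return static_cast<int>(output_.size()); }

 private:
  std::array<float, 3> output_ = {0.1f, 0.2f, 0.3f};
};

// Takes a read-only view, like FullyConnectedLayer::ComputeOutput() does.
float Sum(rtc::ArrayView<const float> view) {
  float sum = 0.f;
  for (float v : view) {
    sum += v;
  }
  return sum;
}

float Example() {
  ToyLayer layer;
  // No GetOutput() call: `layer` converts implicitly to
  // rtc::ArrayView<const float>, which is what this change relies on when
  // chaining the FC and GRU layers.
  return Sum(layer);
}
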
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index dbba6c1..4351afd 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -84,6 +84,7 @@
     "..:cpu_features",
     "../../../../api:array_view",
     "../../../../rtc_base:checks",
+    "../../../../rtc_base:safe_conversions",
     "../../../../rtc_base/system:arch",
   ]
 }
@@ -103,6 +104,7 @@
       ":vector_math",
       "../../../../api:array_view",
       "../../../../rtc_base:checks",
+      "../../../../rtc_base:safe_conversions",
     ]
   }
 }
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index fb4962f..1c9b736 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -40,21 +40,18 @@
 using rnnoise::kInputDenseBias;
 using rnnoise::kInputDenseWeights;
 using rnnoise::kInputLayerOutputSize;
-static_assert(kInputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
-              "Increase kFullyConnectedLayersMaxUnits.");
+static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
 
 using rnnoise::kHiddenGruBias;
 using rnnoise::kHiddenGruRecurrentWeights;
 using rnnoise::kHiddenGruWeights;
 using rnnoise::kHiddenLayerOutputSize;
-static_assert(kHiddenLayerOutputSize <= kRecurrentLayersMaxUnits,
-              "Increase kRecurrentLayersMaxUnits.");
+static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
 
 using rnnoise::kOutputDenseBias;
 using rnnoise::kOutputDenseWeights;
 using rnnoise::kOutputLayerOutputSize;
-static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
-              "Increase kFullyConnectedLayersMaxUnits.");
+static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
 
 using rnnoise::SigmoidApproximated;
 using rnnoise::TansigApproximated;
@@ -178,21 +175,21 @@
   const int stride_out = output_size * output_size;
 
   // Update gate.
-  std::array<float, kRecurrentLayersMaxUnits> update;
+  std::array<float, kGruLayerMaxUnits> update;
   ComputeGruUpdateResetGates(
       input_size, output_size, weights.subview(0, stride_in),
       recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
       input, state, update);
 
   // Reset gate.
-  std::array<float, kRecurrentLayersMaxUnits> reset;
+  std::array<float, kGruLayerMaxUnits> reset;
   ComputeGruUpdateResetGates(
       input_size, output_size, weights.subview(stride_in, stride_in),
       recurrent_weights.subview(stride_out, stride_out),
       bias.subview(output_size, output_size), input, state, reset);
 
   // Output gate.
-  std::array<float, kRecurrentLayersMaxUnits> output;
+  std::array<float, kGruLayerMaxUnits> output;
   ComputeGruOutputGate(
       input_size, output_size, weights.subview(2 * stride_in, stride_in),
       recurrent_weights.subview(2 * stride_out, stride_out),
@@ -279,7 +276,7 @@
       weights_(GetPreprocessedFcWeights(weights, output_size)),
       activation_function_(activation_function),
       cpu_features_(cpu_features) {
-  RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
+  RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits)
       << "Static over-allocation of fully-connected layers output vectors is "
          "not sufficient.";
   RTC_DCHECK_EQ(output_size_, bias_.size())
@@ -290,10 +287,6 @@
 
 FullyConnectedLayer::~FullyConnectedLayer() = default;
 
-rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
-  return rtc::ArrayView<const float>(output_.data(), output_size_);
-}
-
 void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
   // TODO(bugs.chromium.org/10480): Add AVX2.
@@ -321,7 +314,7 @@
       weights_(GetPreprocessedGruTensor(weights, output_size)),
       recurrent_weights_(
           GetPreprocessedGruTensor(recurrent_weights, output_size)) {
-  RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
+  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
       << "Static over-allocation of recurrent layers state vectors is not "
          "sufficient.";
   RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
@@ -337,10 +330,6 @@
 
 GatedRecurrentLayer::~GatedRecurrentLayer() = default;
 
-rtc::ArrayView<const float> GatedRecurrentLayer::GetOutput() const {
-  return rtc::ArrayView<const float>(state_.data(), output_size_);
-}
-
 void GatedRecurrentLayer::Reset() {
   state_.fill(0.f);
 }
@@ -352,49 +341,49 @@
                         recurrent_weights_, bias_, state_);
 }
 
-RnnBasedVad::RnnBasedVad(const AvailableCpuFeatures& cpu_features)
-    : input_layer_(kInputLayerInputSize,
-                   kInputLayerOutputSize,
-                   kInputDenseBias,
-                   kInputDenseWeights,
-                   TansigApproximated,
-                   cpu_features),
-      hidden_layer_(kInputLayerOutputSize,
-                    kHiddenLayerOutputSize,
-                    kHiddenGruBias,
-                    kHiddenGruWeights,
-                    kHiddenGruRecurrentWeights),
-      output_layer_(kHiddenLayerOutputSize,
-                    kOutputLayerOutputSize,
-                    kOutputDenseBias,
-                    kOutputDenseWeights,
-                    SigmoidApproximated,
-                    cpu_features) {
+RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
+    : input_(kInputLayerInputSize,
+             kInputLayerOutputSize,
+             kInputDenseBias,
+             kInputDenseWeights,
+             TansigApproximated,
+             cpu_features),
+      hidden_(kInputLayerOutputSize,
+              kHiddenLayerOutputSize,
+              kHiddenGruBias,
+              kHiddenGruWeights,
+              kHiddenGruRecurrentWeights),
+      output_(kHiddenLayerOutputSize,
+              kOutputLayerOutputSize,
+              kOutputDenseBias,
+              kOutputDenseWeights,
+              SigmoidApproximated,
+              cpu_features) {
   // Input-output chaining size checks.
-  RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
+  RTC_DCHECK_EQ(input_.size(), hidden_.input_size())
       << "The input and the hidden layers sizes do not match.";
-  RTC_DCHECK_EQ(hidden_layer_.output_size(), output_layer_.input_size())
+  RTC_DCHECK_EQ(hidden_.size(), output_.input_size())
       << "The hidden and the output layers sizes do not match.";
 }
 
-RnnBasedVad::~RnnBasedVad() = default;
+RnnVad::~RnnVad() = default;
 
-void RnnBasedVad::Reset() {
-  hidden_layer_.Reset();
+void RnnVad::Reset() {
+  hidden_.Reset();
 }
 
-float RnnBasedVad::ComputeVadProbability(
+float RnnVad::ComputeVadProbability(
     rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
     bool is_silence) {
   if (is_silence) {
     Reset();
     return 0.f;
   }
-  input_layer_.ComputeOutput(feature_vector);
-  hidden_layer_.ComputeOutput(input_layer_.GetOutput());
-  output_layer_.ComputeOutput(hidden_layer_.GetOutput());
-  const auto vad_output = output_layer_.GetOutput();
-  return vad_output[0];
+  input_.ComputeOutput(feature_vector);
+  hidden_.ComputeOutput(input_);
+  output_.ComputeOutput(hidden_);
+  RTC_DCHECK_EQ(output_.size(), 1);
+  return output_.data()[0];
 }
 
 }  // namespace rnn_vad
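
The subview() arithmetic in the GRU gate computation above implies a fixed
layout for the flat tensors: the update, reset and output blocks are stored
back to back, with per-gate strides of input_size * output_size for `weights`,
output_size * output_size for `recurrent_weights` and output_size for `bias`.
The sketch below illustrates that partitioning under the same assumption;
GruGateViews and SplitGruTensor are illustrative names, not part of this
change.

#include "api/array_view.h"

// Read-only views over the three per-gate blocks of one flat GRU tensor.
struct GruGateViews {
  rtc::ArrayView<const float> update;
  rtc::ArrayView<const float> reset;
  rtc::ArrayView<const float> output;
};

// Splits a flat [3 x block_size] tensor into its update/reset/output blocks,
// mirroring the subview() offsets used in the gate computation above.
GruGateViews SplitGruTensor(rtc::ArrayView<const float> tensor,
                            int block_size) {
  return {tensor.subview(0, block_size),
          tensor.subview(block_size, block_size),
          tensor.subview(2 * block_size, block_size)};
}
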
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index 1ef4c76..c886034 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -26,21 +26,17 @@
 namespace webrtc {
 namespace rnn_vad {
 
-// Maximum number of units for a fully-connected layer. This value is used to
-// over-allocate space for fully-connected layers output vectors (implemented as
-// std::array). The value should equal the number of units of the largest
-// fully-connected layer.
-constexpr int kFullyConnectedLayersMaxUnits = 24;
+// Maximum number of units for an FC layer.
+constexpr int kFullyConnectedLayerMaxUnits = 24;
 
-// Maximum number of units for a recurrent layer. This value is used to
-// over-allocate space for recurrent layers state vectors (implemented as
-// std::array). The value should equal the number of units of the largest
-// recurrent layer.
-constexpr int kRecurrentLayersMaxUnits = 24;
+// Maximum number of units for a GRU layer.
+constexpr int kGruLayerMaxUnits = 24;
 
-// Fully-connected layer.
+// Fully-connected layer with a custom activation function; owns the output
+// buffer.
 class FullyConnectedLayer {
  public:
+  // Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
   FullyConnectedLayer(int input_size,
                       int output_size,
                       rtc::ArrayView<const int8_t> bias,
@@ -50,9 +46,14 @@
   FullyConnectedLayer(const FullyConnectedLayer&) = delete;
   FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
   ~FullyConnectedLayer();
+
+  // Returns the size of the input vector.
   int input_size() const { return input_size_; }
-  int output_size() const { return output_size_; }
-  rtc::ArrayView<const float> GetOutput() const;
+  // Returns the pointer to the first element of the output buffer.
+  const float* data() const { return output_.data(); }
+  // Returns the size of the output buffer.
+  int size() const { return output_size_; }
+
   // Computes the fully-connected layer output.
   void ComputeOutput(rtc::ArrayView<const float> input);
 
@@ -64,14 +65,16 @@
   rtc::FunctionView<float(float)> activation_function_;
   // The output vector of a recurrent layer has length equal to |output_size_|.
   // However, for efficiency, over-allocation is used.
-  std::array<float, kFullyConnectedLayersMaxUnits> output_;
+  std::array<float, kFullyConnectedLayerMaxUnits> output_;
   const AvailableCpuFeatures cpu_features_;
 };
 
 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
-// activation functions for the update/reset and output gates respectively.
+// activation functions for the update/reset and output gates respectively. It
+// owns the output buffer.
 class GatedRecurrentLayer {
  public:
+  // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
   GatedRecurrentLayer(int input_size,
                       int output_size,
                       rtc::ArrayView<const int8_t> bias,
@@ -80,9 +83,15 @@
   GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
   GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
   ~GatedRecurrentLayer();
+
+  // Returns the size of the input vector.
   int input_size() const { return input_size_; }
-  int output_size() const { return output_size_; }
-  rtc::ArrayView<const float> GetOutput() const;
+  // Returns the pointer to the first element of the output buffer.
+  const float* data() const { return state_.data(); }
+  // Returns the size of the output buffer.
+  int size() const { return output_size_; }
+
+  // Resets the GRU state.
   void Reset();
   // Computes the recurrent layer output and updates the status.
   void ComputeOutput(rtc::ArrayView<const float> input);
@@ -95,26 +104,28 @@
   const std::vector<float> recurrent_weights_;
   // The state vector of a recurrent layer has length equal to |output_size_|.
   // However, to avoid dynamic allocation, over-allocation is used.
-  std::array<float, kRecurrentLayersMaxUnits> state_;
+  std::array<float, kGruLayerMaxUnits> state_;
 };
 
-// Recurrent network based VAD.
-class RnnBasedVad {
+// Recurrent network with hard-coded architecture and weights for voice activity
+// detection.
+class RnnVad {
  public:
-  explicit RnnBasedVad(const AvailableCpuFeatures& cpu_features);
-  RnnBasedVad(const RnnBasedVad&) = delete;
-  RnnBasedVad& operator=(const RnnBasedVad&) = delete;
-  ~RnnBasedVad();
+  explicit RnnVad(const AvailableCpuFeatures& cpu_features);
+  RnnVad(const RnnVad&) = delete;
+  RnnVad& operator=(const RnnVad&) = delete;
+  ~RnnVad();
   void Reset();
-  // Compute and returns the probability of voice (range: [0.0, 1.0]).
+  // Observes `feature_vector` and `is_silence`, updates the RNN and returns the
+  // current voice probability.
   float ComputeVadProbability(
       rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
       bool is_silence);
 
  private:
-  FullyConnectedLayer input_layer_;
-  GatedRecurrentLayer hidden_layer_;
-  FullyConnectedLayer output_layer_;
+  FullyConnectedLayer input_;
+  GatedRecurrentLayer hidden_;
+  FullyConnectedLayer output_;
 };
 
 }  // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index c311b55..19e0afd 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -39,30 +39,20 @@
     -0.690268f,  -0.925327f, -0.541354f, 0.58455f,    -0.606726f, -0.0372358f,
     0.565991f,   0.435854f,  0.420812f,  0.162198f,   -2.13f,     10.0089f};
 
-void WarmUpRnnVad(RnnBasedVad& rnn_vad) {
+void WarmUpRnnVad(RnnVad& rnn_vad) {
   for (int i = 0; i < 10; ++i) {
     rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/false);
   }
 }
 
-void TestFullyConnectedLayer(FullyConnectedLayer* fc,
-                             rtc::ArrayView<const float> input_vector,
-                             rtc::ArrayView<const float> expected_output) {
-  RTC_CHECK(fc);
-  fc->ComputeOutput(input_vector);
-  ExpectNearAbsolute(expected_output, fc->GetOutput(), 1e-5f);
-}
-
 void TestGatedRecurrentLayer(
     GatedRecurrentLayer& gru,
     rtc::ArrayView<const float> input_sequence,
     rtc::ArrayView<const float> expected_output_sequence) {
-  auto gru_output_view = gru.GetOutput();
   const int input_sequence_length = rtc::CheckedDivExact(
       rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
   const int output_sequence_length = rtc::CheckedDivExact(
-      rtc::dchecked_cast<int>(expected_output_sequence.size()),
-      gru.output_size());
+      rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
   ASSERT_EQ(input_sequence_length, output_sequence_length)
       << "The test data length is invalid.";
   // Feed the GRU layer and check the output at every step.
@@ -71,9 +61,9 @@
     SCOPED_TRACE(i);
     gru.ComputeOutput(
         input_sequence.subview(i * gru.input_size(), gru.input_size()));
-    const auto expected_output = expected_output_sequence.subview(
-        i * gru.output_size(), gru.output_size());
-    ExpectNearAbsolute(expected_output, gru_output_view, 3e-6f);
+    const auto expected_output =
+        expected_output_sequence.subview(i * gru.size(), gru.size());
+    ExpectNearAbsolute(expected_output, gru, 3e-6f);
   }
 }
 
@@ -190,8 +180,8 @@
       rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
       rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
       rnnoise::TansigApproximated, /*cpu_features=*/GetParam());
-  TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
-                          kFullyConnectedExpectedOutput);
+  fc.ComputeOutput(kFullyConnectedInputVector);
+  ExpectNearAbsolute(kFullyConnectedExpectedOutput, fc, 1e-5f);
 }
 
 TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
@@ -237,7 +227,7 @@
 
 // Checks that the speech probability is zero with silence.
 TEST(RnnVadTest, CheckZeroProbabilityWithSilence) {
-  RnnBasedVad rnn_vad(GetAvailableCpuFeatures());
+  RnnVad rnn_vad(GetAvailableCpuFeatures());
   WarmUpRnnVad(rnn_vad);
   EXPECT_EQ(rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/true), 0.f);
 }
@@ -245,7 +235,7 @@
 // Checks that the same output is produced after reset given the same input
 // sequence.
 TEST(RnnVadTest, CheckRnnVadReset) {
-  RnnBasedVad rnn_vad(GetAvailableCpuFeatures());
+  RnnVad rnn_vad(GetAvailableCpuFeatures());
   WarmUpRnnVad(rnn_vad);
   float pre = rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/false);
   rnn_vad.Reset();
@@ -257,7 +247,7 @@
 // Checks that the same output is produced after silence is observed given the
 // same input sequence.
 TEST(RnnVadTest, CheckRnnVadSilence) {
-  RnnBasedVad rnn_vad(GetAvailableCpuFeatures());
+  RnnVad rnn_vad(GetAvailableCpuFeatures());
   WarmUpRnnVad(rnn_vad);
   float pre = rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/false);
   rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/true);
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
index 0f3ad5c..a0e1242 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
@@ -67,7 +67,7 @@
   const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
   FeaturesExtractor features_extractor(cpu_features);
   std::array<float, kFeatureVectorSize> feature_vector;
-  RnnBasedVad rnn_vad(cpu_features);
+  RnnVad rnn_vad(cpu_features);
 
   // Compute VAD probabilities.
   while (true) {
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
index fa7795c..81553b4 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
@@ -68,7 +68,7 @@
   PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
   const AvailableCpuFeatures cpu_features = GetParam();
   FeaturesExtractor features_extractor(cpu_features);
-  RnnBasedVad rnn_vad(cpu_features);
+  RnnVad rnn_vad(cpu_features);
 
   // Init input samples and expected output readers.
   auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
@@ -135,7 +135,7 @@
   const AvailableCpuFeatures cpu_features = GetParam();
   FeaturesExtractor features_extractor(cpu_features);
   std::array<float, kFeatureVectorSize> feature_vector;
-  RnnBasedVad rnn_vad(cpu_features);
+  RnnVad rnn_vad(cpu_features);
   constexpr int number_of_tests = 100;
   ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
   for (int k = 0; k < number_of_tests; ++k) {
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math.h b/modules/audio_processing/agc2/rnn_vad/vector_math.h
index 51bbbfb..0600b90 100644
--- a/modules/audio_processing/agc2/rnn_vad/vector_math.h
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math.h
@@ -23,6 +23,7 @@
 #include "api/array_view.h"
 #include "modules/audio_processing/agc2/cpu_features.h"
 #include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
 #include "rtc_base/system/arch.h"
 
 namespace webrtc {
@@ -63,8 +64,8 @@
       accumulator = _mm_add_ps(accumulator, high);
       float dot_product = _mm_cvtss_f32(accumulator);
       // Add the result for the last block if incomplete.
-      for (int i = incomplete_block_index; static_cast<size_t>(i) < x.size();
-           ++i) {
+      for (int i = incomplete_block_index;
+           i < rtc::dchecked_cast<int>(x.size()); ++i) {
         dot_product += x[i] * y[i];
       }
       return dot_product;
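
For context on the safe_conversions dependency added above: rtc::dchecked_cast
behaves like static_cast in release builds and, in debug builds, RTC_DCHECKs
that the value is representable in the destination type, so an out-of-range
size trips a check instead of silently converting. Below is a standalone toy
sketch of that intent; ToyDcheckedCast and TailSum are illustrative names, not
the real helpers.

#include <cassert>
#include <cstddef>

// Toy stand-in for rtc::dchecked_cast: checked in debug, plain cast otherwise.
template <typename Dst, typename Src>
Dst ToyDcheckedCast(Src value) {
  // The real helper RTC_DCHECKs that `value` fits in Dst; assert() plays that
  // role here.
  assert(static_cast<Src>(static_cast<Dst>(value)) == value);
  return static_cast<Dst>(value);
}

// Same pattern as the tail loops in this change: a signed index compared
// against a size_t length converted through the checked cast.
float TailSum(const float* x, const float* y, std::size_t size, int start) {
  float sum = 0.f;
  for (int i = start; i < ToyDcheckedCast<int>(size); ++i) {
    sum += x[i] * y[i];
  }
  return sum;
}
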
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
index 3b2c4ad..e4d246d 100644
--- a/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
@@ -14,6 +14,7 @@
 
 #include "api/array_view.h"
 #include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
 
 namespace webrtc {
 namespace rnn_vad {
@@ -43,7 +44,8 @@
   low = _mm_add_ss(high, low);
   float dot_product = _mm_cvtss_f32(low);
   // Add the result for the last block if incomplete.
-  for (int i = incomplete_block_index; static_cast<size_t>(i) < x.size(); ++i) {
+  for (int i = incomplete_block_index; i < rtc::dchecked_cast<int>(x.size());
+       ++i) {
     dot_product += x[i] * y[i];
   }
   return dot_product;
diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc
index da3bd0a..b54ae56 100644
--- a/modules/audio_processing/agc2/vad_with_level.cc
+++ b/modules/audio_processing/agc2/vad_with_level.cc
@@ -60,7 +60,7 @@
  private:
   PushResampler<float> resampler_;
   rnn_vad::FeaturesExtractor features_extractor_;
-  rnn_vad::RnnBasedVad rnn_vad_;
+  rnn_vad::RnnVad rnn_vad_;
 };
 
 // Returns an updated version of `p_old` by using instant decay and the given