RNN VAD: GRU layer isolated into rnn_gru.h/.cc
Refactoring done to more easily and cleanly add SIMD optimizations and
to remove `GatedRecurrentLayer` from the RNN VAD api.
Bug: webrtc:10480
Change-Id: Ie1dffdd9b19c57c03a0b634f6818c0780456a66c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195445
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Jakob Ivarsson <jakobi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32770}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index c57971a..29cdfeb 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -32,11 +32,9 @@
"..:biquad_filter",
"..:cpu_features",
"../../../../api:array_view",
- "../../../../api:function_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
- "../../../../rtc_base/system:arch",
"//third_party/rnnoise:rnn_vad",
]
}
@@ -83,6 +81,8 @@
sources = [
"rnn_fc.cc",
"rnn_fc.h",
+ "rnn_gru.cc",
+ "rnn_gru.h",
]
deps = [
":rnn_vad_common",
@@ -241,6 +241,7 @@
"pitch_search_unittest.cc",
"ring_buffer_unittest.cc",
"rnn_fc_unittest.cc",
+ "rnn_gru_unittest.cc",
"rnn_unittest.cc",
"rnn_vad_unittest.cc",
"sequence_buffer_unittest.cc",
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index 9d6d28f..c1bded1 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -10,208 +10,33 @@
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
-// Defines WEBRTC_ARCH_X86_FAMILY, used below.
-#include "rtc_base/system/arch.h"
-
-#if defined(WEBRTC_HAS_NEON)
-#include <arm_neon.h>
-#endif
-#if defined(WEBRTC_ARCH_X86_FAMILY)
-#include <emmintrin.h>
-#endif
-#include <algorithm>
-#include <array>
-#include <cmath>
-#include <numeric>
-
#include "rtc_base/checks.h"
-#include "rtc_base/numerics/safe_conversions.h"
-#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
-using rnnoise::kWeightsScale;
-
-using rnnoise::kInputLayerInputSize;
+using ::rnnoise::kInputLayerInputSize;
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
-using rnnoise::kInputDenseBias;
-using rnnoise::kInputDenseWeights;
-using rnnoise::kInputLayerOutputSize;
+using ::rnnoise::kInputDenseBias;
+using ::rnnoise::kInputDenseWeights;
+using ::rnnoise::kInputLayerOutputSize;
static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
-using rnnoise::kHiddenGruBias;
-using rnnoise::kHiddenGruRecurrentWeights;
-using rnnoise::kHiddenGruWeights;
-using rnnoise::kHiddenLayerOutputSize;
+using ::rnnoise::kHiddenGruBias;
+using ::rnnoise::kHiddenGruRecurrentWeights;
+using ::rnnoise::kHiddenGruWeights;
+using ::rnnoise::kHiddenLayerOutputSize;
static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
-using rnnoise::kOutputDenseBias;
-using rnnoise::kOutputDenseWeights;
-using rnnoise::kOutputLayerOutputSize;
+using ::rnnoise::kOutputDenseBias;
+using ::rnnoise::kOutputDenseWeights;
+using ::rnnoise::kOutputLayerOutputSize;
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
-using rnnoise::SigmoidApproximated;
-using rnnoise::TansigApproximated;
-
-inline float RectifiedLinearUnit(float x) {
- return x < 0.f ? 0.f : x;
-}
-
-constexpr int kNumGruGates = 3; // Update, reset, output.
-
-// TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this
-// function to improve setup time.
-// Casts and scales |tensor_src| for a GRU layer and re-arranges the layout.
-// It works both for weights, recurrent weights and bias.
-std::vector<float> GetPreprocessedGruTensor(
- rtc::ArrayView<const int8_t> tensor_src,
- int output_size) {
- // Transpose, cast and scale.
- // |n| is the size of the first dimension of the 3-dim tensor |weights|.
- const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
- output_size * kNumGruGates);
- const int stride_src = kNumGruGates * output_size;
- const int stride_dst = n * output_size;
- std::vector<float> tensor_dst(tensor_src.size());
- for (int g = 0; g < kNumGruGates; ++g) {
- for (int o = 0; o < output_size; ++o) {
- for (int i = 0; i < n; ++i) {
- tensor_dst[g * stride_dst + o * n + i] =
- rnnoise::kWeightsScale *
- static_cast<float>(
- tensor_src[i * stride_src + g * output_size + o]);
- }
- }
- }
- return tensor_dst;
-}
-
-void ComputeGruUpdateResetGates(int input_size,
- int output_size,
- rtc::ArrayView<const float> weights,
- rtc::ArrayView<const float> recurrent_weights,
- rtc::ArrayView<const float> bias,
- rtc::ArrayView<const float> input,
- rtc::ArrayView<const float> state,
- rtc::ArrayView<float> gate) {
- for (int o = 0; o < output_size; ++o) {
- gate[o] = bias[o];
- for (int i = 0; i < input_size; ++i) {
- gate[o] += input[i] * weights[o * input_size + i];
- }
- for (int s = 0; s < output_size; ++s) {
- gate[o] += state[s] * recurrent_weights[o * output_size + s];
- }
- gate[o] = SigmoidApproximated(gate[o]);
- }
-}
-
-void ComputeGruOutputGate(int input_size,
- int output_size,
- rtc::ArrayView<const float> weights,
- rtc::ArrayView<const float> recurrent_weights,
- rtc::ArrayView<const float> bias,
- rtc::ArrayView<const float> input,
- rtc::ArrayView<const float> state,
- rtc::ArrayView<const float> reset,
- rtc::ArrayView<float> gate) {
- for (int o = 0; o < output_size; ++o) {
- gate[o] = bias[o];
- for (int i = 0; i < input_size; ++i) {
- gate[o] += input[i] * weights[o * input_size + i];
- }
- for (int s = 0; s < output_size; ++s) {
- gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
- }
- gate[o] = RectifiedLinearUnit(gate[o]);
- }
-}
-
-// Gated recurrent unit (GRU) layer un-optimized implementation.
-void ComputeGruLayerOutput(int input_size,
- int output_size,
- rtc::ArrayView<const float> input,
- rtc::ArrayView<const float> weights,
- rtc::ArrayView<const float> recurrent_weights,
- rtc::ArrayView<const float> bias,
- rtc::ArrayView<float> state) {
- RTC_DCHECK_EQ(input_size, input.size());
- // Stride and offset used to read parameter arrays.
- const int stride_in = input_size * output_size;
- const int stride_out = output_size * output_size;
-
- // Update gate.
- std::array<float, kGruLayerMaxUnits> update;
- ComputeGruUpdateResetGates(
- input_size, output_size, weights.subview(0, stride_in),
- recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
- input, state, update);
-
- // Reset gate.
- std::array<float, kGruLayerMaxUnits> reset;
- ComputeGruUpdateResetGates(
- input_size, output_size, weights.subview(stride_in, stride_in),
- recurrent_weights.subview(stride_out, stride_out),
- bias.subview(output_size, output_size), input, state, reset);
-
- // Output gate.
- std::array<float, kGruLayerMaxUnits> output;
- ComputeGruOutputGate(
- input_size, output_size, weights.subview(2 * stride_in, stride_in),
- recurrent_weights.subview(2 * stride_out, stride_out),
- bias.subview(2 * output_size, output_size), input, state, reset, output);
-
- // Update output through the update gates and update the state.
- for (int o = 0; o < output_size; ++o) {
- output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
- state[o] = output[o];
- }
-}
-
} // namespace
-GatedRecurrentLayer::GatedRecurrentLayer(
- const int input_size,
- const int output_size,
- const rtc::ArrayView<const int8_t> bias,
- const rtc::ArrayView<const int8_t> weights,
- const rtc::ArrayView<const int8_t> recurrent_weights)
- : input_size_(input_size),
- output_size_(output_size),
- bias_(GetPreprocessedGruTensor(bias, output_size)),
- weights_(GetPreprocessedGruTensor(weights, output_size)),
- recurrent_weights_(
- GetPreprocessedGruTensor(recurrent_weights, output_size)) {
- RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
- << "Static over-allocation of recurrent layers state vectors is not "
- "sufficient.";
- RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
- << "Mismatching output size and bias terms array size.";
- RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
- << "Mismatching input-output size and weight coefficients array size.";
- RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
- recurrent_weights_.size())
- << "Mismatching input-output size and recurrent weight coefficients array"
- " size.";
- Reset();
-}
-
-GatedRecurrentLayer::~GatedRecurrentLayer() = default;
-
-void GatedRecurrentLayer::Reset() {
- state_.fill(0.f);
-}
-
-void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
- // TODO(bugs.chromium.org/10480): Add AVX2.
- // TODO(bugs.chromium.org/10480): Add Neon.
- ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
- recurrent_weights_, bias_, state_);
-}
-
RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
: input_(kInputLayerInputSize,
kInputLayerOutputSize,
@@ -224,7 +49,8 @@
kHiddenLayerOutputSize,
kHiddenGruBias,
kHiddenGruWeights,
- kHiddenGruRecurrentWeights),
+ kHiddenGruRecurrentWeights,
+ /*layer_name=*/"GRU1"),
output_(kHiddenLayerOutputSize,
kOutputLayerOutputSize,
kOutputDenseBias,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index df99c3c..3148f1b 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -18,56 +18,14 @@
#include <vector>
#include "api/array_view.h"
-#include "api/function_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
-#include "rtc_base/system/arch.h"
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
namespace webrtc {
namespace rnn_vad {
-// Maximum number of units for a GRU layer.
-constexpr int kGruLayerMaxUnits = 24;
-
-// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
-// activation functions for the update/reset and output gates respectively. It
-// owns the output buffer.
-class GatedRecurrentLayer {
- public:
- // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
- GatedRecurrentLayer(int input_size,
- int output_size,
- rtc::ArrayView<const int8_t> bias,
- rtc::ArrayView<const int8_t> weights,
- rtc::ArrayView<const int8_t> recurrent_weights);
- GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
- GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
- ~GatedRecurrentLayer();
-
- // Returns the size of the input vector.
- int input_size() const { return input_size_; }
- // Returns the pointer to the first element of the output buffer.
- const float* data() const { return state_.data(); }
- // Returns the size of the output buffer.
- int size() const { return output_size_; }
-
- // Resets the GRU state.
- void Reset();
- // Computes the recurrent layer output and updates the status.
- void ComputeOutput(rtc::ArrayView<const float> input);
-
- private:
- const int input_size_;
- const int output_size_;
- const std::vector<float> bias_;
- const std::vector<float> weights_;
- const std::vector<float> recurrent_weights_;
- // The state vector of a recurrent layer has length equal to |output_size_|.
- // However, to avoid dynamic allocation, over-allocation is used.
- std::array<float, kGruLayerMaxUnits> state_;
-};
-
// Recurrent network with hard-coded architecture and weights for voice activity
// detection.
class RnnVad {
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
new file mode 100644
index 0000000..f37fc2a
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
+
+#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
+#include "third_party/rnnoise/src/rnn_activations.h"
+#include "third_party/rnnoise/src/rnn_vad_weights.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+constexpr int kNumGruGates = 3; // Update, reset, output.
+
+std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
+ int output_size) {
+ // Transpose, cast and scale.
+ // |n| is the size of the first dimension of the 3-dim tensor |weights|.
+ const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
+ output_size * kNumGruGates);
+ const int stride_src = kNumGruGates * output_size;
+ const int stride_dst = n * output_size;
+ std::vector<float> tensor_dst(tensor_src.size());
+ for (int g = 0; g < kNumGruGates; ++g) {
+ for (int o = 0; o < output_size; ++o) {
+ for (int i = 0; i < n; ++i) {
+ tensor_dst[g * stride_dst + o * n + i] =
+ ::rnnoise::kWeightsScale *
+ static_cast<float>(
+ tensor_src[i * stride_src + g * output_size + o]);
+ }
+ }
+ }
+ return tensor_dst;
+}
+
+void ComputeGruUpdateResetGates(int input_size,
+ int output_size,
+ rtc::ArrayView<const float> weights,
+ rtc::ArrayView<const float> recurrent_weights,
+ rtc::ArrayView<const float> bias,
+ rtc::ArrayView<const float> input,
+ rtc::ArrayView<const float> state,
+ rtc::ArrayView<float> gate) {
+ for (int o = 0; o < output_size; ++o) {
+ gate[o] = bias[o];
+ for (int i = 0; i < input_size; ++i) {
+ gate[o] += input[i] * weights[o * input_size + i];
+ }
+ for (int s = 0; s < output_size; ++s) {
+ gate[o] += state[s] * recurrent_weights[o * output_size + s];
+ }
+ gate[o] = ::rnnoise::SigmoidApproximated(gate[o]);
+ }
+}
+
+void ComputeGruOutputGate(int input_size,
+ int output_size,
+ rtc::ArrayView<const float> weights,
+ rtc::ArrayView<const float> recurrent_weights,
+ rtc::ArrayView<const float> bias,
+ rtc::ArrayView<const float> input,
+ rtc::ArrayView<const float> state,
+ rtc::ArrayView<const float> reset,
+ rtc::ArrayView<float> gate) {
+ for (int o = 0; o < output_size; ++o) {
+ gate[o] = bias[o];
+ for (int i = 0; i < input_size; ++i) {
+ gate[o] += input[i] * weights[o * input_size + i];
+ }
+ for (int s = 0; s < output_size; ++s) {
+ gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
+ }
+ // Rectified linear unit.
+ if (gate[o] < 0.f) {
+ gate[o] = 0.f;
+ }
+ }
+}
+
+} // namespace
+
+GatedRecurrentLayer::GatedRecurrentLayer(
+ const int input_size,
+ const int output_size,
+ const rtc::ArrayView<const int8_t> bias,
+ const rtc::ArrayView<const int8_t> weights,
+ const rtc::ArrayView<const int8_t> recurrent_weights,
+ absl::string_view layer_name)
+ : input_size_(input_size),
+ output_size_(output_size),
+ bias_(PreprocessGruTensor(bias, output_size)),
+ weights_(PreprocessGruTensor(weights, output_size)),
+ recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)) {
+ RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
+ << "Insufficient GRU layer over-allocation (" << layer_name << ").";
+ RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
+ << "Mismatching output size and bias terms array size (" << layer_name
+ << ").";
+ RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
+ << "Mismatching input-output size and weight coefficients array size ("
+ << layer_name << ").";
+ RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
+ recurrent_weights_.size())
+ << "Mismatching input-output size and recurrent weight coefficients array"
+ " size ("
+ << layer_name << ").";
+ Reset();
+}
+
+GatedRecurrentLayer::~GatedRecurrentLayer() = default;
+
+void GatedRecurrentLayer::Reset() {
+ state_.fill(0.f);
+}
+
+void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
+ RTC_DCHECK_EQ(input.size(), input_size_);
+
+ // TODO(bugs.chromium.org/10480): Add AVX2.
+ // TODO(bugs.chromium.org/10480): Add Neon.
+
+ // Stride and offset used to read parameter arrays.
+ const int stride_in = input_size_ * output_size_;
+ const int stride_out = output_size_ * output_size_;
+
+ rtc::ArrayView<const float> bias(bias_);
+ rtc::ArrayView<const float> weights(weights_);
+ rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
+
+ // Update gate.
+ std::array<float, kGruLayerMaxUnits> update;
+ ComputeGruUpdateResetGates(
+ input_size_, output_size_, weights.subview(0, stride_in),
+ recurrent_weights.subview(0, stride_out), bias.subview(0, output_size_),
+ input, state_, update);
+
+ // Reset gate.
+ std::array<float, kGruLayerMaxUnits> reset;
+ ComputeGruUpdateResetGates(
+ input_size_, output_size_, weights.subview(stride_in, stride_in),
+ recurrent_weights.subview(stride_out, stride_out),
+ bias.subview(output_size_, output_size_), input, state_, reset);
+
+ // Output gate.
+ std::array<float, kGruLayerMaxUnits> output;
+ ComputeGruOutputGate(input_size_, output_size_,
+ weights.subview(2 * stride_in, stride_in),
+ recurrent_weights.subview(2 * stride_out, stride_out),
+ bias.subview(2 * output_size_, output_size_), input,
+ state_, reset, output);
+
+ // Update output through the update gates and update the state.
+ for (int o = 0; o < output_size_; ++o) {
+ output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
+ state_[o] = output[o];
+ }
+}
+
+} // namespace rnn_vad
+} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
new file mode 100644
index 0000000..f66b048
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
+#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
+
+#include <array>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "modules/audio_processing/agc2/cpu_features.h"
+
+namespace webrtc {
+namespace rnn_vad {
+
+// Maximum number of units for a GRU layer.
+constexpr int kGruLayerMaxUnits = 24;
+
+// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
+// activation functions for the update/reset and output gates respectively.
+class GatedRecurrentLayer {
+ public:
+ // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
+ GatedRecurrentLayer(int input_size,
+ int output_size,
+ rtc::ArrayView<const int8_t> bias,
+ rtc::ArrayView<const int8_t> weights,
+ rtc::ArrayView<const int8_t> recurrent_weights,
+ absl::string_view layer_name);
+ GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
+ GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
+ ~GatedRecurrentLayer();
+
+ // Returns the size of the input vector.
+ int input_size() const { return input_size_; }
+ // Returns the pointer to the first element of the output buffer.
+ const float* data() const { return state_.data(); }
+ // Returns the size of the output buffer.
+ int size() const { return output_size_; }
+
+ // Resets the GRU state.
+ void Reset();
+ // Computes the recurrent layer output and updates the status.
+ void ComputeOutput(rtc::ArrayView<const float> input);
+
+ private:
+ const int input_size_;
+ const int output_size_;
+ const std::vector<float> bias_;
+ const std::vector<float> weights_;
+ const std::vector<float> recurrent_weights_;
+ // Over-allocated array with size equal to `output_size_`.
+ std::array<float, kGruLayerMaxUnits> state_;
+};
+
+} // namespace rnn_vad
+} // namespace webrtc
+
+#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
new file mode 100644
index 0000000..54e1cf5
--- /dev/null
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
+
+#include <array>
+
+#include "api/array_view.h"
+#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
+#include "modules/audio_processing/test/performance_timer.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace test {
+namespace {
+
+void TestGatedRecurrentLayer(
+ GatedRecurrentLayer& gru,
+ rtc::ArrayView<const float> input_sequence,
+ rtc::ArrayView<const float> expected_output_sequence) {
+ const int input_sequence_length = rtc::CheckedDivExact(
+ rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
+ const int output_sequence_length = rtc::CheckedDivExact(
+ rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
+ ASSERT_EQ(input_sequence_length, output_sequence_length)
+ << "The test data length is invalid.";
+ // Feed the GRU layer and check the output at every step.
+ gru.Reset();
+ for (int i = 0; i < input_sequence_length; ++i) {
+ SCOPED_TRACE(i);
+ gru.ComputeOutput(
+ input_sequence.subview(i * gru.input_size(), gru.input_size()));
+ const auto expected_output =
+ expected_output_sequence.subview(i * gru.size(), gru.size());
+ ExpectNearAbsolute(expected_output, gru, 3e-6f);
+ }
+}
+
+// Gated recurrent units layer test data.
+constexpr int kGruInputSize = 5;
+constexpr int kGruOutputSize = 4;
+constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
+ -118, 68, -76, 91, 121, 125};
+constexpr std::array<int8_t, 60> kGruWeights = {
+ // Input 0.
+ 124, 9, 1, 116, // Update.
+ -66, -21, -118, -110, // Reset.
+ 104, 75, -23, -51, // Output.
+ // Input 1.
+ -72, -111, 47, 93, // Update.
+ 77, -98, 41, -8, // Reset.
+ 40, -23, -43, -107, // Output.
+ // Input 2.
+ 9, -73, 30, -32, // Update.
+ -2, 64, -26, 91, // Reset.
+ -48, -24, -28, -104, // Output.
+ // Input 3.
+ 74, -46, 116, 15, // Update.
+ 32, 52, -126, -38, // Reset.
+ -121, 12, -16, 110, // Output.
+ // Input 4.
+ -95, 66, -103, -35, // Update.
+ -38, 3, -126, -61, // Reset.
+ 28, 98, -117, -43 // Output.
+};
+constexpr std::array<int8_t, 48> kGruRecurrentWeights = {
+ // Output 0.
+ -3, 87, 50, 51, // Update.
+ -22, 27, -39, 62, // Reset.
+ 31, -83, -52, -48, // Output.
+ // Output 1.
+ -6, 83, -19, 104, // Update.
+ 105, 48, 23, 68, // Reset.
+ 23, 40, 7, -120, // Output.
+ // Output 2.
+ 64, -62, 117, 85, // Update.
+ 51, -43, 54, -105, // Reset.
+ 120, 56, -128, -107, // Output.
+ // Output 3.
+ 39, 50, -17, -47, // Update.
+ -117, 14, 108, 12, // Reset.
+ -7, -72, 103, -87, // Output.
+};
+constexpr std::array<float, 20> kGruInputSequence = {
+ 0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
+ 0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
+ 0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
+ 0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f};
+constexpr std::array<float, 16> kGruExpectedOutputSequence = {
+ 0.0239123f, 0.5773077f, 0.f, 0.f,
+ 0.01282811f, 0.64330572f, 0.f, 0.04863098f,
+ 0.00781069f, 0.75267816f, 0.f, 0.02579715f,
+ 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
+
+// Checks that the output of a GRU layer is within tolerance given test input
+// data.
+TEST(RnnVadTest, CheckGatedRecurrentLayer) {
+ GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
+ kGruRecurrentWeights, /*layer_name=*/"GRU");
+ TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
+}
+
+TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
+ GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
+ kGruRecurrentWeights, /*layer_name=*/"GRU");
+
+ rtc::ArrayView<const float> input_sequence(kGruInputSequence);
+ static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
+ constexpr int input_sequence_length =
+ kGruInputSequence.size() / kGruInputSize;
+
+ constexpr int kNumTests = 10000;
+ ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
+ for (int k = 0; k < kNumTests; ++k) {
+ perf_timer.StartTimer();
+ for (int i = 0; i < input_sequence_length; ++i) {
+ gru.ComputeOutput(
+ input_sequence.subview(i * gru.input_size(), gru.input_size()));
+ }
+ perf_timer.StopTimer();
+ }
+ RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- "
+ << (perf_timer.GetDurationStandardDeviation() / 1000)
+ << " ms";
+}
+
+} // namespace
+} // namespace test
+} // namespace rnn_vad
+} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 4f42d11..1c314d1 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -10,18 +10,10 @@
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
-#include <array>
-#include <memory>
-#include <vector>
-
+#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
-#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
-#include "modules/audio_processing/test/performance_timer.h"
-#include "rtc_base/checks.h"
-#include "rtc_base/logging.h"
-#include "rtc_base/numerics/safe_conversions.h"
+#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "test/gtest.h"
-#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
@@ -43,116 +35,6 @@
}
}
-void TestGatedRecurrentLayer(
- GatedRecurrentLayer& gru,
- rtc::ArrayView<const float> input_sequence,
- rtc::ArrayView<const float> expected_output_sequence) {
- const int input_sequence_length = rtc::CheckedDivExact(
- rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
- const int output_sequence_length = rtc::CheckedDivExact(
- rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
- ASSERT_EQ(input_sequence_length, output_sequence_length)
- << "The test data length is invalid.";
- // Feed the GRU layer and check the output at every step.
- gru.Reset();
- for (int i = 0; i < input_sequence_length; ++i) {
- SCOPED_TRACE(i);
- gru.ComputeOutput(
- input_sequence.subview(i * gru.input_size(), gru.input_size()));
- const auto expected_output =
- expected_output_sequence.subview(i * gru.size(), gru.size());
- ExpectNearAbsolute(expected_output, gru, 3e-6f);
- }
-}
-
-// Gated recurrent units layer test data.
-constexpr int kGruInputSize = 5;
-constexpr int kGruOutputSize = 4;
-constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
- -118, 68, -76, 91, 121, 125};
-constexpr std::array<int8_t, 60> kGruWeights = {
- // Input 0.
- 124, 9, 1, 116, // Update.
- -66, -21, -118, -110, // Reset.
- 104, 75, -23, -51, // Output.
- // Input 1.
- -72, -111, 47, 93, // Update.
- 77, -98, 41, -8, // Reset.
- 40, -23, -43, -107, // Output.
- // Input 2.
- 9, -73, 30, -32, // Update.
- -2, 64, -26, 91, // Reset.
- -48, -24, -28, -104, // Output.
- // Input 3.
- 74, -46, 116, 15, // Update.
- 32, 52, -126, -38, // Reset.
- -121, 12, -16, 110, // Output.
- // Input 4.
- -95, 66, -103, -35, // Update.
- -38, 3, -126, -61, // Reset.
- 28, 98, -117, -43 // Output.
-};
-constexpr std::array<int8_t, 48> kGruRecurrentWeights = {
- // Output 0.
- -3, 87, 50, 51, // Update.
- -22, 27, -39, 62, // Reset.
- 31, -83, -52, -48, // Output.
- // Output 1.
- -6, 83, -19, 104, // Update.
- 105, 48, 23, 68, // Reset.
- 23, 40, 7, -120, // Output.
- // Output 2.
- 64, -62, 117, 85, // Update.
- 51, -43, 54, -105, // Reset.
- 120, 56, -128, -107, // Output.
- // Output 3.
- 39, 50, -17, -47, // Update.
- -117, 14, 108, 12, // Reset.
- -7, -72, 103, -87, // Output.
-};
-constexpr std::array<float, 20> kGruInputSequence = {
- 0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
- 0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
- 0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
- 0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f};
-constexpr std::array<float, 16> kGruExpectedOutputSequence = {
- 0.0239123f, 0.5773077f, 0.f, 0.f,
- 0.01282811f, 0.64330572f, 0.f, 0.04863098f,
- 0.00781069f, 0.75267816f, 0.f, 0.02579715f,
- 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
-
-// Checks that the output of a GRU layer is within tolerance given test input
-// data.
-TEST(RnnVadTest, CheckGatedRecurrentLayer) {
- GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
- kGruRecurrentWeights);
- TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
-}
-
-TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
- GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
- kGruRecurrentWeights);
-
- rtc::ArrayView<const float> input_sequence(kGruInputSequence);
- static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
- constexpr int input_sequence_length =
- kGruInputSequence.size() / kGruInputSize;
-
- constexpr int kNumTests = 10000;
- ::webrtc::test::PerformanceTimer perf_timer(kNumTests);
- for (int k = 0; k < kNumTests; ++k) {
- perf_timer.StartTimer();
- for (int i = 0; i < input_sequence_length; ++i) {
- gru.ComputeOutput(
- input_sequence.subview(i * gru.input_size(), gru.input_size()));
- }
- perf_timer.StopTimer();
- }
- RTC_LOG(LS_INFO) << (perf_timer.GetDurationAverage() / 1000) << " +/- "
- << (perf_timer.GetDurationStandardDeviation() / 1000)
- << " ms";
-}
-
// Checks that the speech probability is zero with silence.
TEST(RnnVadTest, CheckZeroProbabilityWithSilence) {
RnnVad rnn_vad(GetAvailableCpuFeatures());