RNN VAD: GRU layer optimized
Using `VectorMath::DotProduct()` in GatedRecurrentLayer to reuse existing
SIMD optimizations. Results:
- When SSE2/AVX2 is avilable, the GRU layer takes 40% of the unoptimized
code
- The realtime factor for the VAD improved as follows
- SSE2: from 570x to 630x
- AVX2: from 610x to 680x
This CL also improved the GRU layer benchmark by (i) benchmarking a GRU
layer havibng the same size of that used in the VAD and (ii) by prefetching
a long input sequence.
Bug: webrtc:10480
Change-Id: I9716b15661e4c6b81592b4cf7c172d90e41b5223
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195545
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32803}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 29cdfeb..ef2370c 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -86,6 +86,7 @@
]
deps = [
":rnn_vad_common",
+ ":vector_math",
"..:cpu_features",
"../../../../api:array_view",
"../../../../api:function_view",
@@ -94,6 +95,9 @@
"../../../../rtc_base/system:arch",
"//third_party/rnnoise:rnn_vad",
]
+ if (current_cpu == "x86" || current_cpu == "x64") {
+ deps += [ ":vector_math_avx2" ]
+ }
absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
}
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index c1bded1..f828a24 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -50,6 +50,7 @@
kHiddenGruBias,
kHiddenGruWeights,
kHiddenGruRecurrentWeights,
+ cpu_features,
/*layer_name=*/"GRU1"),
output_(kHiddenLayerOutputSize,
kOutputLayerOutputSize,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
index c586ed2..900ce63 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
@@ -46,12 +46,12 @@
0.983443f, 0.999991f, -0.824335f, 0.984742f, 0.990208f, 0.938179f,
0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f};
-class RnnParametrization
+class RnnFcParametrization
: public ::testing::TestWithParam<AvailableCpuFeatures> {};
// Checks that the output of a fully connected layer is within tolerance given
// test input data.
-TEST_P(RnnParametrization, CheckFullyConnectedLayerOutput) {
+TEST_P(RnnFcParametrization, CheckFullyConnectedLayerOutput) {
FullyConnectedLayer fc(kInputLayerInputSize, kInputLayerOutputSize,
kInputDenseBias, kInputDenseWeights,
ActivationFunction::kTansigApproximated,
@@ -61,7 +61,7 @@
ExpectNearAbsolute(kFullyConnectedExpectedOutput, fc, 1e-5f);
}
-TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
+TEST_P(RnnFcParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
const AvailableCpuFeatures cpu_features = GetParam();
FullyConnectedLayer fc(kInputLayerInputSize, kInputLayerOutputSize,
kInputDenseBias, kInputDenseWeights,
@@ -87,16 +87,14 @@
v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
AvailableCpuFeatures available = GetAvailableCpuFeatures();
if (available.sse2) {
- AvailableCpuFeatures features(
- {/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
- v.push_back(features);
+ v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
}
return v;
}
INSTANTIATE_TEST_SUITE_P(
RnnVadTest,
- RnnParametrization,
+ RnnFcParametrization,
::testing::ValuesIn(GetCpuFeaturesToTest()),
[](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
return info.param.ToString();
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
index f37fc2a..482016e 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
@@ -43,47 +43,79 @@
return tensor_dst;
}
-void ComputeGruUpdateResetGates(int input_size,
- int output_size,
- rtc::ArrayView<const float> weights,
- rtc::ArrayView<const float> recurrent_weights,
- rtc::ArrayView<const float> bias,
- rtc::ArrayView<const float> input,
- rtc::ArrayView<const float> state,
- rtc::ArrayView<float> gate) {
+// Computes the output for the update or the reset gate.
+// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
+// - `g`: output gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `s`: state gate vector
+// - `b`: bias vector
+void ComputeUpdateResetGate(int input_size,
+ int output_size,
+ const VectorMath& vector_math,
+ rtc::ArrayView<const float> input,
+ rtc::ArrayView<const float> state,
+ rtc::ArrayView<const float> bias,
+ rtc::ArrayView<const float> weights,
+ rtc::ArrayView<const float> recurrent_weights,
+ rtc::ArrayView<float> gate) {
+ RTC_DCHECK_EQ(input.size(), input_size);
+ RTC_DCHECK_EQ(state.size(), output_size);
+ RTC_DCHECK_EQ(bias.size(), output_size);
+ RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+ RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+ RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
for (int o = 0; o < output_size; ++o) {
- gate[o] = bias[o];
- for (int i = 0; i < input_size; ++i) {
- gate[o] += input[i] * weights[o * input_size + i];
- }
- for (int s = 0; s < output_size; ++s) {
- gate[o] += state[s] * recurrent_weights[o * output_size + s];
- }
- gate[o] = ::rnnoise::SigmoidApproximated(gate[o]);
+ float x = bias[o];
+ x += vector_math.DotProduct(input,
+ weights.subview(o * input_size, input_size));
+ x += vector_math.DotProduct(
+ state, recurrent_weights.subview(o * output_size, output_size));
+ gate[o] = ::rnnoise::SigmoidApproximated(x);
}
}
-void ComputeGruOutputGate(int input_size,
- int output_size,
- rtc::ArrayView<const float> weights,
- rtc::ArrayView<const float> recurrent_weights,
- rtc::ArrayView<const float> bias,
- rtc::ArrayView<const float> input,
- rtc::ArrayView<const float> state,
- rtc::ArrayView<const float> reset,
- rtc::ArrayView<float> gate) {
+// Computes the output for the state gate.
+// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
+// - `s'`: output state gate vector
+// - `s`: previous state gate vector
+// - `u`: update gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `r`: reset gate vector
+// - `b`: bias vector
+// - `.*` element-wise product
+void ComputeStateGate(int input_size,
+ int output_size,
+ const VectorMath& vector_math,
+ rtc::ArrayView<const float> input,
+ rtc::ArrayView<const float> update,
+ rtc::ArrayView<const float> reset,
+ rtc::ArrayView<const float> bias,
+ rtc::ArrayView<const float> weights,
+ rtc::ArrayView<const float> recurrent_weights,
+ rtc::ArrayView<float> state) {
+ RTC_DCHECK_EQ(input.size(), input_size);
+ RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
+ RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
+ RTC_DCHECK_EQ(bias.size(), output_size);
+ RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+ RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+ RTC_DCHECK_EQ(state.size(), output_size);
+ std::array<float, kGruLayerMaxUnits> reset_x_state;
for (int o = 0; o < output_size; ++o) {
- gate[o] = bias[o];
- for (int i = 0; i < input_size; ++i) {
- gate[o] += input[i] * weights[o * input_size + i];
- }
- for (int s = 0; s < output_size; ++s) {
- gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
- }
- // Rectified linear unit.
- if (gate[o] < 0.f) {
- gate[o] = 0.f;
- }
+ reset_x_state[o] = state[o] * reset[o];
+ }
+ for (int o = 0; o < output_size; ++o) {
+ float x = bias[o];
+ x += vector_math.DotProduct(input,
+ weights.subview(o * input_size, input_size));
+ x += vector_math.DotProduct(
+ {reset_x_state.data(), static_cast<size_t>(output_size)},
+ recurrent_weights.subview(o * output_size, output_size));
+ state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
}
}
@@ -95,12 +127,14 @@
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
const rtc::ArrayView<const int8_t> recurrent_weights,
+ const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name)
: input_size_(input_size),
output_size_(output_size),
bias_(PreprocessGruTensor(bias, output_size)),
weights_(PreprocessGruTensor(weights, output_size)),
- recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)) {
+ recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
+ vector_math_(cpu_features) {
RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
<< "Insufficient GRU layer over-allocation (" << layer_name << ").";
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
@@ -126,44 +160,38 @@
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
RTC_DCHECK_EQ(input.size(), input_size_);
- // TODO(bugs.chromium.org/10480): Add AVX2.
- // TODO(bugs.chromium.org/10480): Add Neon.
-
- // Stride and offset used to read parameter arrays.
- const int stride_in = input_size_ * output_size_;
- const int stride_out = output_size_ * output_size_;
-
+ // The tensors below are organized as a sequence of flattened tensors for the
+ // `update`, `reset` and `state` gates.
rtc::ArrayView<const float> bias(bias_);
rtc::ArrayView<const float> weights(weights_);
rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
+ // Strides to access to the flattened tensors for a specific gate.
+ const int stride_weights = input_size_ * output_size_;
+ const int stride_recurrent_weights = output_size_ * output_size_;
+
+ rtc::ArrayView<float> state(state_.data(), output_size_);
// Update gate.
std::array<float, kGruLayerMaxUnits> update;
- ComputeGruUpdateResetGates(
- input_size_, output_size_, weights.subview(0, stride_in),
- recurrent_weights.subview(0, stride_out), bias.subview(0, output_size_),
- input, state_, update);
-
+ ComputeUpdateResetGate(
+ input_size_, output_size_, vector_math_, input, state,
+ bias.subview(0, output_size_), weights.subview(0, stride_weights),
+ recurrent_weights.subview(0, stride_recurrent_weights), update);
// Reset gate.
std::array<float, kGruLayerMaxUnits> reset;
- ComputeGruUpdateResetGates(
- input_size_, output_size_, weights.subview(stride_in, stride_in),
- recurrent_weights.subview(stride_out, stride_out),
- bias.subview(output_size_, output_size_), input, state_, reset);
-
- // Output gate.
- std::array<float, kGruLayerMaxUnits> output;
- ComputeGruOutputGate(input_size_, output_size_,
- weights.subview(2 * stride_in, stride_in),
- recurrent_weights.subview(2 * stride_out, stride_out),
- bias.subview(2 * output_size_, output_size_), input,
- state_, reset, output);
-
- // Update output through the update gates and update the state.
- for (int o = 0; o < output_size_; ++o) {
- output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
- state_[o] = output[o];
- }
+ ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
+ bias.subview(output_size_, output_size_),
+ weights.subview(stride_weights, stride_weights),
+ recurrent_weights.subview(stride_recurrent_weights,
+ stride_recurrent_weights),
+ reset);
+ // State gate.
+ ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
+ reset, bias.subview(2 * output_size_, output_size_),
+ weights.subview(2 * stride_weights, stride_weights),
+ recurrent_weights.subview(2 * stride_recurrent_weights,
+ stride_recurrent_weights),
+ state);
}
} // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
index f66b048..3407dfcd 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru.h
@@ -17,6 +17,7 @@
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
+#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
namespace webrtc {
namespace rnn_vad {
@@ -34,6 +35,7 @@
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
rtc::ArrayView<const int8_t> recurrent_weights,
+ const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name);
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
@@ -57,6 +59,7 @@
const std::vector<float> bias_;
const std::vector<float> weights_;
const std::vector<float> recurrent_weights_;
+ const VectorMath vector_math_;
// Over-allocated array with size equal to `output_size_`.
std::array<float, kGruLayerMaxUnits> state_;
};
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
index 4e8b524..ee8bdac 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
@@ -11,6 +11,8 @@
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
#include <array>
+#include <memory>
+#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -18,6 +20,7 @@
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "test/gtest.h"
+#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
@@ -101,24 +104,44 @@
0.00781069f, 0.75267816f, 0.f, 0.02579715f,
0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
+class RnnGruParametrization
+ : public ::testing::TestWithParam<AvailableCpuFeatures> {};
+
// Checks that the output of a GRU layer is within tolerance given test input
// data.
-TEST(RnnVadTest, CheckGatedRecurrentLayer) {
+TEST_P(RnnGruParametrization, CheckGatedRecurrentLayer) {
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
- kGruRecurrentWeights, /*layer_name=*/"GRU");
+ kGruRecurrentWeights,
+ /*cpu_features=*/GetParam(),
+ /*layer_name=*/"GRU");
TestGatedRecurrentLayer(gru, kGruInputSequence, kGruExpectedOutputSequence);
}
-TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
- GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
- kGruRecurrentWeights, /*layer_name=*/"GRU");
+TEST_P(RnnGruParametrization, DISABLED_BenchmarkGatedRecurrentLayer) {
+ // Prefetch test data.
+ std::unique_ptr<FileReader> reader = CreateGruInputReader();
+ std::vector<float> gru_input_sequence(reader->size());
+ reader->ReadChunk(gru_input_sequence);
- rtc::ArrayView<const float> input_sequence(kGruInputSequence);
- static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
- constexpr int input_sequence_length =
- kGruInputSequence.size() / kGruInputSize;
+ using ::rnnoise::kHiddenGruBias;
+ using ::rnnoise::kHiddenGruRecurrentWeights;
+ using ::rnnoise::kHiddenGruWeights;
+ using ::rnnoise::kHiddenLayerOutputSize;
+ using ::rnnoise::kInputLayerOutputSize;
- constexpr int kNumTests = 10000;
+ GatedRecurrentLayer gru(kInputLayerOutputSize, kHiddenLayerOutputSize,
+ kHiddenGruBias, kHiddenGruWeights,
+ kHiddenGruRecurrentWeights,
+ /*cpu_features=*/GetParam(),
+ /*layer_name=*/"GRU");
+
+ rtc::ArrayView<const float> input_sequence(gru_input_sequence);
+ ASSERT_EQ(input_sequence.size() % kInputLayerOutputSize,
+ static_cast<size_t>(0));
+ const int input_sequence_length =
+ input_sequence.size() / kInputLayerOutputSize;
+
+ constexpr int kNumTests = 100;
::webrtc::test::PerformanceTimer perf_timer(kNumTests);
for (int k = 0; k < kNumTests; ++k) {
perf_timer.StartTimer();
@@ -133,6 +156,28 @@
<< " ms";
}
+// Finds the relevant CPU features combinations to test.
+std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
+ std::vector<AvailableCpuFeatures> v;
+ AvailableCpuFeatures available = GetAvailableCpuFeatures();
+ v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
+ if (available.avx2) {
+ v.push_back({/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
+ }
+ if (available.sse2) {
+ v.push_back({/*sse2=*/true, /*avx2=*/false, /*neon=*/false});
+ }
+ return v;
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ RnnVadTest,
+ RnnGruParametrization,
+ ::testing::ValuesIn(GetCpuFeaturesToTest()),
+ [](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
+ return info.param.ToString();
+ });
+
} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 3db6774..b8ca9c3 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -111,6 +111,12 @@
return {kChunkSize, num_chunks, std::move(reader)};
}
+std::unique_ptr<FileReader> CreateGruInputReader() {
+ return std::make_unique<FloatFileReader<float>>(
+ /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
+ "dat"));
+}
+
std::unique_ptr<FileReader> CreateVadProbsReader() {
return std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index 86af5e0..e366e18 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -77,6 +77,9 @@
// Creates a reader for the LP residual and pitch information test data.
ChunksFileReader CreateLpResidualAndPitchInfoReader();
+// Creates a reader for the sequence of GRU input vectors.
+std::unique_ptr<FileReader> CreateGruInputReader();
+
// Creates a reader for the VAD probabilities test data.
std::unique_ptr<FileReader> CreateVadProbsReader();
diff --git a/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1
new file mode 100644
index 0000000..f78c40e
--- /dev/null
+++ b/resources/audio_processing/agc2/rnn_vad/gru_in.dat.sha1
@@ -0,0 +1 @@
+402abf7a4e5d35abb78906fff2b3f4d8d24aa629
\ No newline at end of file