RNN VAD: cast and scale quantized weights at init
This CL has two goals: (i) avoid casting and scaling the NN weights
for every processed feature vector (do it once at init instead) and
(ii) prepare for SIMD optimizations.
Bug: webrtc:10480
Change-Id: Ice7bac5657123354714cc7c63b00abbb8a76c7d7
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/141413
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Fredrik Hernqvist <fhernqvist@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29675}
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index a5b34c4..94cc254 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -44,10 +44,26 @@
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
"Increase kFullyConnectedLayersMaxUnits.");
-using rnnoise::RectifiedLinearUnit;
using rnnoise::SigmoidApproximated;
using rnnoise::TansigApproximated;
+namespace {
+
+inline float RectifiedLinearUnit(float x) {
+ return x < 0.f ? 0.f : x;
+}
+
+std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
+ std::vector<float> scaled_params(params.size());
+ std::transform(params.begin(), params.end(), scaled_params.begin(),
+ [](int8_t x) -> float {
+ return rnnoise::kWeightsScale * static_cast<float>(x);
+ });
+ return scaled_params;
+}
+
+} // namespace
+
FullyConnectedLayer::FullyConnectedLayer(
const size_t input_size,
const size_t output_size,
@@ -56,8 +72,8 @@
float (*const activation_function)(float))
: input_size_(input_size),
output_size_(output_size),
- bias_(bias),
- weights_(weights),
+ bias_(GetScaledParams(bias)),
+ weights_(GetScaledParams(weights)),
activation_function_(activation_function) {
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
<< "Static over-allocation of fully-connected layers output vectors is "
@@ -84,7 +100,7 @@
for (size_t i = 0; i < input_size_; ++i) {
output_[o] += input[i] * weights_[i * output_size_ + o];
}
- output_[o] = (*activation_function_)(kWeightsScale * output_[o]);
+ output_[o] = (*activation_function_)(output_[o]);
}
}
@@ -93,14 +109,12 @@
const size_t output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
- const rtc::ArrayView<const int8_t> recurrent_weights,
- float (*const activation_function)(float))
+ const rtc::ArrayView<const int8_t> recurrent_weights)
: input_size_(input_size),
output_size_(output_size),
- bias_(bias),
- weights_(weights),
- recurrent_weights_(recurrent_weights),
- activation_function_(activation_function) {
+ bias_(GetScaledParams(bias)),
+ weights_(GetScaledParams(weights)),
+ recurrent_weights_(GetScaledParams(recurrent_weights)) {
RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
<< "Static over-allocation of recurrent layers state vectors is not "
<< "sufficient.";
@@ -144,7 +158,7 @@
for (size_t s = 0; s < output_size_; ++s) {
update[o] += state_[s] * recurrent_weights_[s * stride + o];
} // Add state.
- update[o] = SigmoidApproximated(kWeightsScale * update[o]);
+ update[o] = SigmoidApproximated(update[o]);
}
// Compute reset gates.
@@ -158,7 +172,7 @@
for (size_t s = 0; s < output_size_; ++s) { // Add state.
reset[o] += state_[s] * recurrent_weights_[offset + s * stride + o];
}
- reset[o] = SigmoidApproximated(kWeightsScale * reset[o]);
+ reset[o] = SigmoidApproximated(reset[o]);
}
// Compute output.
@@ -174,7 +188,7 @@
output[o] +=
state_[s] * recurrent_weights_[offset + s * stride + o] * reset[s];
}
- output[o] = (*activation_function_)(kWeightsScale * output[o]);
+ output[o] = RectifiedLinearUnit(output[o]);
// Update output through the update gates.
output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
}
@@ -194,8 +208,7 @@
kHiddenLayerOutputSize,
kHiddenGruBias,
kHiddenGruWeights,
- kHiddenGruRecurrentWeights,
- RectifiedLinearUnit),
+ kHiddenGruRecurrentWeights),
output_layer_(kHiddenLayerOutputSize,
kOutputLayerOutputSize,
kOutputDenseBias,
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index 1129464..c38ff01 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -15,6 +15,7 @@
#include <sys/types.h>
#include <array>
+#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -54,23 +55,23 @@
private:
const size_t input_size_;
const size_t output_size_;
- const rtc::ArrayView<const int8_t> bias_;
- const rtc::ArrayView<const int8_t> weights_;
+ const std::vector<float> bias_;
+ const std::vector<float> weights_;
float (*const activation_function_)(float);
// The output vector of a recurrent layer has length equal to |output_size_|.
// However, for efficiency, over-allocation is used.
std::array<float, kFullyConnectedLayersMaxUnits> output_;
};
-// Recurrent layer with gated recurrent units (GRUs).
+// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
+// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
public:
GatedRecurrentLayer(const size_t input_size,
const size_t output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
- const rtc::ArrayView<const int8_t> recurrent_weights,
- float (*const activation_function)(float));
+ const rtc::ArrayView<const int8_t> recurrent_weights);
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
~GatedRecurrentLayer();
@@ -84,10 +85,9 @@
private:
const size_t input_size_;
const size_t output_size_;
- const rtc::ArrayView<const int8_t> bias_;
- const rtc::ArrayView<const int8_t> weights_;
- const rtc::ArrayView<const int8_t> recurrent_weights_;
- float (*const activation_function_)(float);
+ const std::vector<float> bias_;
+ const std::vector<float> weights_;
+ const std::vector<float> recurrent_weights_;
// The state vector of a recurrent layer has length equal to |output_size_|.
// However, to avoid dynamic allocation, over-allocation is used.
std::array<float, kRecurrentLayersMaxUnits> state_;
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 40ac70b..61e6f26 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -123,8 +123,7 @@
64, -62, 117, 85, -51, -43, 54, -105, 120, 56, -128, -107,
39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87,
-66, 82, 84, 100, -98, 102, -49, 44, 122, 106, -20, -69};
- GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights,
- RectifiedLinearUnit);
+ GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights);
// Test on different inputs.
{
const std::array<float, 20> input_sequence = {