RNN VAD: FC and GRU layers implicit conversion to ArrayView
Plus a few minor code readability improvements.
Bug: webrtc:10480
Change-Id: I590d8e203b1d05959a8c15373841e37abe83237e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195334
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32764}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index dbba6c1..4351afd 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -84,6 +84,7 @@
"..:cpu_features",
"../../../../api:array_view",
"../../../../rtc_base:checks",
+ "../../../../rtc_base:safe_conversions",
"../../../../rtc_base/system:arch",
]
}
@@ -103,6 +104,7 @@
":vector_math",
"../../../../api:array_view",
"../../../../rtc_base:checks",
+ "../../../../rtc_base:safe_conversions",
]
}
}
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index fb4962f..1c9b736 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -40,21 +40,18 @@
using rnnoise::kInputDenseBias;
using rnnoise::kInputDenseWeights;
using rnnoise::kInputLayerOutputSize;
-static_assert(kInputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
- "Increase kFullyConnectedLayersMaxUnits.");
+static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
using rnnoise::kHiddenGruBias;
using rnnoise::kHiddenGruRecurrentWeights;
using rnnoise::kHiddenGruWeights;
using rnnoise::kHiddenLayerOutputSize;
-static_assert(kHiddenLayerOutputSize <= kRecurrentLayersMaxUnits,
- "Increase kRecurrentLayersMaxUnits.");
+static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
using rnnoise::kOutputDenseBias;
using rnnoise::kOutputDenseWeights;
using rnnoise::kOutputLayerOutputSize;
-static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
- "Increase kFullyConnectedLayersMaxUnits.");
+static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
using rnnoise::SigmoidApproximated;
using rnnoise::TansigApproximated;
@@ -178,21 +175,21 @@
const int stride_out = output_size * output_size;
// Update gate.
- std::array<float, kRecurrentLayersMaxUnits> update;
+ std::array<float, kGruLayerMaxUnits> update;
ComputeGruUpdateResetGates(
input_size, output_size, weights.subview(0, stride_in),
recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
input, state, update);
// Reset gate.
- std::array<float, kRecurrentLayersMaxUnits> reset;
+ std::array<float, kGruLayerMaxUnits> reset;
ComputeGruUpdateResetGates(
input_size, output_size, weights.subview(stride_in, stride_in),
recurrent_weights.subview(stride_out, stride_out),
bias.subview(output_size, output_size), input, state, reset);
// Output gate.
- std::array<float, kRecurrentLayersMaxUnits> output;
+ std::array<float, kGruLayerMaxUnits> output;
ComputeGruOutputGate(
input_size, output_size, weights.subview(2 * stride_in, stride_in),
recurrent_weights.subview(2 * stride_out, stride_out),
@@ -279,7 +276,7 @@
weights_(GetPreprocessedFcWeights(weights, output_size)),
activation_function_(activation_function),
cpu_features_(cpu_features) {
- RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
+ RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits)
<< "Static over-allocation of fully-connected layers output vectors is "
"not sufficient.";
RTC_DCHECK_EQ(output_size_, bias_.size())
@@ -290,10 +287,6 @@
FullyConnectedLayer::~FullyConnectedLayer() = default;
-rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
- return rtc::ArrayView<const float>(output_.data(), output_size_);
-}
-
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
// TODO(bugs.chromium.org/10480): Add AVX2.
@@ -321,7 +314,7 @@
weights_(GetPreprocessedGruTensor(weights, output_size)),
recurrent_weights_(
GetPreprocessedGruTensor(recurrent_weights, output_size)) {
- RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
+ RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
<< "Static over-allocation of recurrent layers state vectors is not "
"sufficient.";
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
@@ -337,10 +330,6 @@
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
-rtc::ArrayView<const float> GatedRecurrentLayer::GetOutput() const {
- return rtc::ArrayView<const float>(state_.data(), output_size_);
-}
-
void GatedRecurrentLayer::Reset() {
state_.fill(0.f);
}
@@ -352,49 +341,49 @@
recurrent_weights_, bias_, state_);
}
-RnnBasedVad::RnnBasedVad(const AvailableCpuFeatures& cpu_features)
- : input_layer_(kInputLayerInputSize,
- kInputLayerOutputSize,
- kInputDenseBias,
- kInputDenseWeights,
- TansigApproximated,
- cpu_features),
- hidden_layer_(kInputLayerOutputSize,
- kHiddenLayerOutputSize,
- kHiddenGruBias,
- kHiddenGruWeights,
- kHiddenGruRecurrentWeights),
- output_layer_(kHiddenLayerOutputSize,
- kOutputLayerOutputSize,
- kOutputDenseBias,
- kOutputDenseWeights,
- SigmoidApproximated,
- cpu_features) {
+RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
+ : input_(kInputLayerInputSize,
+ kInputLayerOutputSize,
+ kInputDenseBias,
+ kInputDenseWeights,
+ TansigApproximated,
+ cpu_features),
+ hidden_(kInputLayerOutputSize,
+ kHiddenLayerOutputSize,
+ kHiddenGruBias,
+ kHiddenGruWeights,
+ kHiddenGruRecurrentWeights),
+ output_(kHiddenLayerOutputSize,
+ kOutputLayerOutputSize,
+ kOutputDenseBias,
+ kOutputDenseWeights,
+ SigmoidApproximated,
+ cpu_features) {
// Input-output chaining size checks.
- RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
+ RTC_DCHECK_EQ(input_.size(), hidden_.input_size())
<< "The input and the hidden layers sizes do not match.";
- RTC_DCHECK_EQ(hidden_layer_.output_size(), output_layer_.input_size())
+ RTC_DCHECK_EQ(hidden_.size(), output_.input_size())
<< "The hidden and the output layers sizes do not match.";
}
-RnnBasedVad::~RnnBasedVad() = default;
+RnnVad::~RnnVad() = default;
-void RnnBasedVad::Reset() {
- hidden_layer_.Reset();
+void RnnVad::Reset() {
+ hidden_.Reset();
}
-float RnnBasedVad::ComputeVadProbability(
+float RnnVad::ComputeVadProbability(
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
bool is_silence) {
if (is_silence) {
Reset();
return 0.f;
}
- input_layer_.ComputeOutput(feature_vector);
- hidden_layer_.ComputeOutput(input_layer_.GetOutput());
- output_layer_.ComputeOutput(hidden_layer_.GetOutput());
- const auto vad_output = output_layer_.GetOutput();
- return vad_output[0];
+ input_.ComputeOutput(feature_vector);
+ hidden_.ComputeOutput(input_);
+ output_.ComputeOutput(hidden_);
+ RTC_DCHECK_EQ(output_.size(), 1);
+ return output_.data()[0];
}
} // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index 1ef4c76..c886034 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -26,21 +26,17 @@
namespace webrtc {
namespace rnn_vad {
-// Maximum number of units for a fully-connected layer. This value is used to
-// over-allocate space for fully-connected layers output vectors (implemented as
-// std::array). The value should equal the number of units of the largest
-// fully-connected layer.
-constexpr int kFullyConnectedLayersMaxUnits = 24;
+// Maximum number of units for an FC layer.
+constexpr int kFullyConnectedLayerMaxUnits = 24;
-// Maximum number of units for a recurrent layer. This value is used to
-// over-allocate space for recurrent layers state vectors (implemented as
-// std::array). The value should equal the number of units of the largest
-// recurrent layer.
-constexpr int kRecurrentLayersMaxUnits = 24;
+// Maximum number of units for a GRU layer.
+constexpr int kGruLayerMaxUnits = 24;
-// Fully-connected layer.
+// Fully-connected layer with a custom activation function which owns the output
+// buffer.
class FullyConnectedLayer {
public:
+ // Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
FullyConnectedLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
@@ -50,9 +46,14 @@
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
~FullyConnectedLayer();
+
+ // Returns the size of the input vector.
int input_size() const { return input_size_; }
- int output_size() const { return output_size_; }
- rtc::ArrayView<const float> GetOutput() const;
+ // Returns the pointer to the first element of the output buffer.
+ const float* data() const { return output_.data(); }
+ // Returns the size of the output buffer.
+ int size() const { return output_size_; }
+
// Computes the fully-connected layer output.
void ComputeOutput(rtc::ArrayView<const float> input);
@@ -64,14 +65,16 @@
rtc::FunctionView<float(float)> activation_function_;
// The output vector of a recurrent layer has length equal to |output_size_|.
// However, for efficiency, over-allocation is used.
- std::array<float, kFullyConnectedLayersMaxUnits> output_;
+ std::array<float, kFullyConnectedLayerMaxUnits> output_;
const AvailableCpuFeatures cpu_features_;
};
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
-// activation functions for the update/reset and output gates respectively.
+// activation functions for the update/reset and output gates respectively. It
+// owns the output buffer.
class GatedRecurrentLayer {
public:
+ // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
GatedRecurrentLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
@@ -80,9 +83,15 @@
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
~GatedRecurrentLayer();
+
+ // Returns the size of the input vector.
int input_size() const { return input_size_; }
- int output_size() const { return output_size_; }
- rtc::ArrayView<const float> GetOutput() const;
+ // Returns the pointer to the first element of the output buffer.
+ const float* data() const { return state_.data(); }
+ // Returns the size of the output buffer.
+ int size() const { return output_size_; }
+
+ // Resets the GRU state.
void Reset();
// Computes the recurrent layer output and updates the status.
void ComputeOutput(rtc::ArrayView<const float> input);
@@ -95,26 +104,28 @@
const std::vector<float> recurrent_weights_;
// The state vector of a recurrent layer has length equal to |output_size_|.
// However, to avoid dynamic allocation, over-allocation is used.
- std::array<float, kRecurrentLayersMaxUnits> state_;
+ std::array<float, kGruLayerMaxUnits> state_;
};
-// Recurrent network based VAD.
-class RnnBasedVad {
+// Recurrent network with hard-coded architecture and weights for voice activity
+// detection.
+class RnnVad {
public:
- explicit RnnBasedVad(const AvailableCpuFeatures& cpu_features);
- RnnBasedVad(const RnnBasedVad&) = delete;
- RnnBasedVad& operator=(const RnnBasedVad&) = delete;
- ~RnnBasedVad();
+ explicit RnnVad(const AvailableCpuFeatures& cpu_features);
+ RnnVad(const RnnVad&) = delete;
+ RnnVad& operator=(const RnnVad&) = delete;
+ ~RnnVad();
void Reset();
- // Compute and returns the probability of voice (range: [0.0, 1.0]).
+ // Observes `feature_vector` and `is_silence`, updates the RNN and returns the
+ // current voice probability.
float ComputeVadProbability(
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
bool is_silence);
private:
- FullyConnectedLayer input_layer_;
- GatedRecurrentLayer hidden_layer_;
- FullyConnectedLayer output_layer_;
+ FullyConnectedLayer input_;
+ GatedRecurrentLayer hidden_;
+ FullyConnectedLayer output_;
};
} // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index c311b55..19e0afd 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -39,30 +39,20 @@
-0.690268f, -0.925327f, -0.541354f, 0.58455f, -0.606726f, -0.0372358f,
0.565991f, 0.435854f, 0.420812f, 0.162198f, -2.13f, 10.0089f};
-void WarmUpRnnVad(RnnBasedVad& rnn_vad) {
+void WarmUpRnnVad(RnnVad& rnn_vad) {
for (int i = 0; i < 10; ++i) {
rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/false);
}
}
-void TestFullyConnectedLayer(FullyConnectedLayer* fc,
- rtc::ArrayView<const float> input_vector,
- rtc::ArrayView<const float> expected_output) {
- RTC_CHECK(fc);
- fc->ComputeOutput(input_vector);
- ExpectNearAbsolute(expected_output, fc->GetOutput(), 1e-5f);
-}
-
void TestGatedRecurrentLayer(
GatedRecurrentLayer& gru,
rtc::ArrayView<const float> input_sequence,
rtc::ArrayView<const float> expected_output_sequence) {
- auto gru_output_view = gru.GetOutput();
const int input_sequence_length = rtc::CheckedDivExact(
rtc::dchecked_cast<int>(input_sequence.size()), gru.input_size());
const int output_sequence_length = rtc::CheckedDivExact(
- rtc::dchecked_cast<int>(expected_output_sequence.size()),
- gru.output_size());
+ rtc::dchecked_cast<int>(expected_output_sequence.size()), gru.size());
ASSERT_EQ(input_sequence_length, output_sequence_length)
<< "The test data length is invalid.";
// Feed the GRU layer and check the output at every step.
@@ -71,9 +61,9 @@
SCOPED_TRACE(i);
gru.ComputeOutput(
input_sequence.subview(i * gru.input_size(), gru.input_size()));
- const auto expected_output = expected_output_sequence.subview(
- i * gru.output_size(), gru.output_size());
- ExpectNearAbsolute(expected_output, gru_output_view, 3e-6f);
+ const auto expected_output =
+ expected_output_sequence.subview(i * gru.size(), gru.size());
+ ExpectNearAbsolute(expected_output, gru, 3e-6f);
}
}
@@ -190,8 +180,8 @@
rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
rnnoise::TansigApproximated, /*cpu_features=*/GetParam());
- TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
- kFullyConnectedExpectedOutput);
+ fc.ComputeOutput(kFullyConnectedInputVector);
+ ExpectNearAbsolute(kFullyConnectedExpectedOutput, fc, 1e-5f);
}
TEST_P(RnnParametrization, DISABLED_BenchmarkFullyConnectedLayer) {
@@ -237,7 +227,7 @@
// Checks that the speech probability is zero with silence.
TEST(RnnVadTest, CheckZeroProbabilityWithSilence) {
- RnnBasedVad rnn_vad(GetAvailableCpuFeatures());
+ RnnVad rnn_vad(GetAvailableCpuFeatures());
WarmUpRnnVad(rnn_vad);
EXPECT_EQ(rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/true), 0.f);
}
@@ -245,7 +235,7 @@
// Checks that the same output is produced after reset given the same input
// sequence.
TEST(RnnVadTest, CheckRnnVadReset) {
- RnnBasedVad rnn_vad(GetAvailableCpuFeatures());
+ RnnVad rnn_vad(GetAvailableCpuFeatures());
WarmUpRnnVad(rnn_vad);
float pre = rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/false);
rnn_vad.Reset();
@@ -257,7 +247,7 @@
// Checks that the same output is produced after silence is observed given the
// same input sequence.
TEST(RnnVadTest, CheckRnnVadSilence) {
- RnnBasedVad rnn_vad(GetAvailableCpuFeatures());
+ RnnVad rnn_vad(GetAvailableCpuFeatures());
WarmUpRnnVad(rnn_vad);
float pre = rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/false);
rnn_vad.ComputeVadProbability(kFeatures, /*is_silence=*/true);
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
index 0f3ad5c..a0e1242 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_tool.cc
@@ -67,7 +67,7 @@
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
FeaturesExtractor features_extractor(cpu_features);
std::array<float, kFeatureVectorSize> feature_vector;
- RnnBasedVad rnn_vad(cpu_features);
+ RnnVad rnn_vad(cpu_features);
// Compute VAD probabilities.
while (true) {
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
index fa7795c..81553b4 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
@@ -68,7 +68,7 @@
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
const AvailableCpuFeatures cpu_features = GetParam();
FeaturesExtractor features_extractor(cpu_features);
- RnnBasedVad rnn_vad(cpu_features);
+ RnnVad rnn_vad(cpu_features);
// Init input samples and expected output readers.
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
@@ -135,7 +135,7 @@
const AvailableCpuFeatures cpu_features = GetParam();
FeaturesExtractor features_extractor(cpu_features);
std::array<float, kFeatureVectorSize> feature_vector;
- RnnBasedVad rnn_vad(cpu_features);
+ RnnVad rnn_vad(cpu_features);
constexpr int number_of_tests = 100;
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
for (int k = 0; k < number_of_tests; ++k) {
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math.h b/modules/audio_processing/agc2/rnn_vad/vector_math.h
index 51bbbfb..0600b90 100644
--- a/modules/audio_processing/agc2/rnn_vad/vector_math.h
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math.h
@@ -23,6 +23,7 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
@@ -63,8 +64,8 @@
accumulator = _mm_add_ps(accumulator, high);
float dot_product = _mm_cvtss_f32(accumulator);
// Add the result for the last block if incomplete.
- for (int i = incomplete_block_index; static_cast<size_t>(i) < x.size();
- ++i) {
+ for (int i = incomplete_block_index;
+ i < rtc::dchecked_cast<int>(x.size()); ++i) {
dot_product += x[i] * y[i];
}
return dot_product;
diff --git a/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
index 3b2c4ad..e4d246d 100644
--- a/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
+++ b/modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
@@ -14,6 +14,7 @@
#include "api/array_view.h"
#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
namespace webrtc {
namespace rnn_vad {
@@ -43,7 +44,8 @@
low = _mm_add_ss(high, low);
float dot_product = _mm_cvtss_f32(low);
// Add the result for the last block if incomplete.
- for (int i = incomplete_block_index; static_cast<size_t>(i) < x.size(); ++i) {
+ for (int i = incomplete_block_index; i < rtc::dchecked_cast<int>(x.size());
+ ++i) {
dot_product += x[i] * y[i];
}
return dot_product;
diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc
index da3bd0a..b54ae56 100644
--- a/modules/audio_processing/agc2/vad_with_level.cc
+++ b/modules/audio_processing/agc2/vad_with_level.cc
@@ -60,7 +60,7 @@
private:
PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
- rnn_vad::RnnBasedVad rnn_vad_;
+ rnn_vad::RnnVad rnn_vad_;
};
// Returns an updated version of `p_old` by using instant decay and the given