RNN VAD: unit test code clean-up
- test_utils.h/.cc simplified
- webrtc::rnnvad::test -> webrtc::rnnvad
- all unit test code inside the anonymous namespace
- names improved
Bug: webrtc:10480
Change-Id: I0a0f056f9728bb8a1b93006b95d7ed5bf5bd4adb
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/196509
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32789}
diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
index ef3748d..76001ed 100644
--- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
@@ -17,15 +17,15 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
+namespace {
// Checks that the auto correlation function produces output within tolerance
// given test input data.
TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
- Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
- std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
+ Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
+ std::array<float, kNumLags12kHz> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@@ -33,7 +33,7 @@
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
computed_output);
}
- auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
+ auto auto_corr_view = test_data.AutoCorrelation12kHzView();
ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
computed_output, 3e-3f);
}
@@ -44,7 +44,7 @@
// Create constant signal with no pitch.
std::array<float, kBufSize12kHz> pitch_buf_decimated;
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
- std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
+ std::array<float, kNumLags12kHz> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@@ -55,12 +55,12 @@
// The expected output is a vector filled with the same expected
// auto-correlation value. The latter equals the length of a 20 ms frame.
constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
- std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
+ std::array<float, kNumLags12kHz> expected_output;
std::fill(expected_output.begin(), expected_output.end(),
static_cast<float>(kFrameSize20ms12kHz));
ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc
index 0da971e..98da39e 100644
--- a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc
@@ -14,7 +14,6 @@
#include <vector>
#include "modules/audio_processing/agc2/cpu_features.h"
-#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -23,7 +22,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
constexpr int ceil(int n, int m) {
@@ -52,7 +50,7 @@
// Feeds |features_extractor| with |samples| splitting it in 10 ms frames.
// For every frame, the output is written into |feature_vector|. Returns true
// if silence is detected in the last frame.
-bool FeedTestData(FeaturesExtractor* features_extractor,
+bool FeedTestData(FeaturesExtractor& features_extractor,
rtc::ArrayView<const float> samples,
rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -60,15 +58,13 @@
bool is_silence = true;
const int num_frames = samples.size() / kFrameSize10ms24kHz;
for (int i = 0; i < num_frames; ++i) {
- is_silence = features_extractor->CheckSilenceComputeFeatures(
+ is_silence = features_extractor.CheckSilenceComputeFeatures(
{samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
feature_vector);
}
return is_silence;
}
-} // namespace
-
// Extracts the features for two pure tones and verifies that the pitch field
// values reflect the known tone frequencies.
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
@@ -91,17 +87,17 @@
constexpr int pitch_feature_index = kFeatureVectorSize - 2;
// Low frequency tone - i.e., high period.
CreatePureTone(amplitude, low_pitch_hz, samples);
- ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
+ ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
float high_pitch_period = feature_vector_view[pitch_feature_index];
// High frequency tone - i.e., low period.
features_extractor.Reset();
CreatePureTone(amplitude, high_pitch_hz, samples);
- ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
+ ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
float low_pitch_period = feature_vector_view[pitch_feature_index];
// Check.
EXPECT_LT(low_pitch_period, high_pitch_period);
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.h b/modules/audio_processing/agc2/rnn_vad/lp_residual.h
index 2e54dd9..380d9f6 100644
--- a/modules/audio_processing/agc2/rnn_vad/lp_residual.h
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.h
@@ -18,7 +18,7 @@
namespace webrtc {
namespace rnn_vad {
-// LPC inverse filter length.
+// Linear predictive coding (LPC) inverse filter length.
constexpr int kNumLpcCoefficients = 5;
// Given a frame |x|, computes a post-processed version of LPC coefficients
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
index 1779776..7b3a4a3 100644
--- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
@@ -22,7 +22,7 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
+namespace {
// Checks that the LP residual can be computed on an empty frame.
TEST(RnnVadTest, LpResidualOfEmptyFrame) {
@@ -33,55 +33,48 @@
std::array<float, kFrameSize10ms24kHz> empty_frame;
empty_frame.fill(0.f);
// Compute inverse filter coefficients.
- std::array<float, kNumLpcCoefficients> lpc_coeffs;
- ComputeAndPostProcessLpcCoefficients(empty_frame, lpc_coeffs);
+ std::array<float, kNumLpcCoefficients> lpc;
+ ComputeAndPostProcessLpcCoefficients(empty_frame, lpc);
// Compute LP residual.
std::array<float, kFrameSize10ms24kHz> lp_residual;
- ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual);
+ ComputeLpResidual(lpc, empty_frame, lp_residual);
}
// Checks that the computed LP residual is bit-exact given test input data.
TEST(RnnVadTest, LpResidualPipelineBitExactness) {
// Input and expected output readers.
- auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
- auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
+ ChunksFileReader pitch_buffer_reader = CreatePitchBuffer24kHzReader();
+ ChunksFileReader lp_pitch_reader = CreateLpResidualAndPitchInfoReader();
// Buffers.
- std::vector<float> pitch_buf_data(kBufSize24kHz);
- std::array<float, kNumLpcCoefficients> lpc_coeffs;
+ std::vector<float> pitch_buffer_24kHz(kBufSize24kHz);
+ std::array<float, kNumLpcCoefficients> lpc;
std::vector<float> computed_lp_residual(kBufSize24kHz);
std::vector<float> expected_lp_residual(kBufSize24kHz);
// Test length.
const int num_frames =
- std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s.
- ASSERT_GE(lp_residual_reader.second, num_frames);
+ std::min(pitch_buffer_reader.num_chunks, 300); // Max 3 s.
+ ASSERT_GE(lp_pitch_reader.num_chunks, num_frames);
- {
- // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
- // FloatingPointExceptionObserver fpe_observer;
- for (int i = 0; i < num_frames; ++i) {
- // Read input.
- ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
- // Read expected output (ignore pitch gain and period).
- ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual));
- float unused;
- ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
- ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
-
- // Check every 200 ms.
- if (i % 20 != 0) {
- continue;
- }
-
- SCOPED_TRACE(i);
- ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs);
- ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual);
+ // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+ // FloatingPointExceptionObserver fpe_observer;
+ for (int i = 0; i < num_frames; ++i) {
+ SCOPED_TRACE(i);
+ // Read input.
+ ASSERT_TRUE(pitch_buffer_reader.reader->ReadChunk(pitch_buffer_24kHz));
+ // Read expected output (ignore pitch gain and period).
+ ASSERT_TRUE(lp_pitch_reader.reader->ReadChunk(expected_lp_residual));
+ lp_pitch_reader.reader->SeekForward(2); // Pitch period and strength.
+ // Check every 200 ms.
+ if (i % 20 == 0) {
+ ComputeAndPostProcessLpcCoefficients(pitch_buffer_24kHz, lpc);
+ ComputeLpResidual(lpc, pitch_buffer_24kHz, computed_lp_residual);
ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin);
}
}
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index a4a4df1..8c336af 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -22,7 +22,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
@@ -63,12 +62,12 @@
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
PitchTestData test_data;
- std::array<float, kNumPitchBufSquareEnergies> computed_output;
+ std::array<float, kRefineNumLags24kHz> computed_output;
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
- ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
+ ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
computed_output, cpu_features);
- auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
+ auto square_energies_view = test_data.SquareEnergies24kHzView();
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
computed_output, 1e-3f);
}
@@ -79,13 +78,12 @@
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
- Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
+ Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
CandidatePitchPeriods pitch_candidates;
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
- auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
- pitch_candidates = ComputePitchPeriod12kHz(pitch_buf_decimated,
- auto_corr_view, cpu_features);
+ pitch_candidates = ComputePitchPeriod12kHz(
+ pitch_buf_decimated, test_data.AutoCorrelation12kHzView(), cpu_features);
EXPECT_EQ(pitch_candidates.best, 140);
EXPECT_EQ(pitch_candidates.second_best, 142);
}
@@ -98,16 +96,16 @@
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
- ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
+ ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
y_energy_view, cpu_features);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
EXPECT_EQ(
- ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
+ ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
/*pitch_candidates=*/{280, 284}, cpu_features),
560);
EXPECT_EQ(
- ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
+ ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
/*pitch_candidates=*/{260, 284}, cpu_features),
568);
}
@@ -132,12 +130,12 @@
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
- ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
+ ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
y_energy_view, params.cpu_features);
EXPECT_EQ(
- ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
+ ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
params.pitch_candidates, params.cpu_features),
- ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
+ ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
swapped_pitch_candidates, params.cpu_features));
}
@@ -179,13 +177,13 @@
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
- ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
+ ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
y_energy_view, params.cpu_features);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
- test_data.GetPitchBufView(), y_energy_view, params.initial_pitch_period,
- params.last_pitch, params.cpu_features);
+ test_data.PitchBuffer24kHzView(), y_energy_view,
+ params.initial_pitch_period, params.last_pitch, params.cpu_features);
EXPECT_EQ(params.expected_pitch.period, computed_output.period);
EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f);
}
@@ -219,6 +217,5 @@
PrintTestIndexAndCpuFeatures<ExtendedPitchPeriodSearchParameters>);
} // namespace
-} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
index fe9be5d..79b44b9 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
@@ -26,8 +26,8 @@
// Checks that the computed pitch period is bit-exact and that the computed
// pitch gain is within tolerance given test input data.
TEST(RnnVadTest, PitchSearchWithinTolerance) {
- auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader();
- const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
+ ChunksFileReader reader = CreateLpResidualAndPitchInfoReader();
+ const int num_frames = std::min(reader.num_chunks, 300); // Max 3 s.
std::vector<float> lp_residual(kBufSize24kHz);
float expected_pitch_period, expected_pitch_strength;
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
@@ -37,9 +37,9 @@
// FloatingPointExceptionObserver fpe_observer;
for (int i = 0; i < num_frames; ++i) {
SCOPED_TRACE(i);
- lp_residual_reader.first->ReadChunk(lp_residual);
- lp_residual_reader.first->ReadValue(&expected_pitch_period);
- lp_residual_reader.first->ReadValue(&expected_pitch_strength);
+ ASSERT_TRUE(reader.reader->ReadChunk(lp_residual));
+ ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_period));
+ ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_strength));
int pitch_period =
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
EXPECT_EQ(expected_pitch_period, pitch_period);
diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc
index 8b061a9..d11d4ea 100644
--- a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc
@@ -14,7 +14,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
// Compare the elements of two given array views.
@@ -64,8 +63,6 @@
}
}
-} // namespace
-
// Check that for different delays, different views are returned.
TEST(RnnVadTest, RingBufferArrayViews) {
constexpr int s = 3;
@@ -110,6 +107,6 @@
TestRingBuffer<float, 5, 5>();
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
index 1094832..c586ed2 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc
@@ -24,7 +24,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
using ::rnnoise::kInputDenseBias;
@@ -104,6 +103,5 @@
});
} // namespace
-} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
index 54e1cf5..4e8b524 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc
@@ -21,7 +21,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
void TestGatedRecurrentLayer(
@@ -135,6 +134,5 @@
}
} // namespace
-} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 1c314d1..4c5409a 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -17,7 +17,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
constexpr std::array<float, kFeatureVectorSize> kFeatures = {
@@ -67,6 +66,5 @@
}
} // namespace
-} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
index 81553b4..7eb699c 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
@@ -9,6 +9,7 @@
*/
#include <array>
+#include <memory>
#include <string>
#include <vector>
@@ -26,7 +27,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
constexpr int kFrameSize10ms48kHz = 480;
@@ -49,8 +49,6 @@
// constant below to true in order to write new expected output binary files.
constexpr bool kWriteComputedOutputToFile = false;
-} // namespace
-
// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
// when the expected output files are re-exported.
TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
@@ -71,12 +69,11 @@
RnnVad rnn_vad(cpu_features);
// Init input samples and expected output readers.
- auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
- auto expected_vad_prob_reader = CreateVadProbsReader();
+ std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
+ std::unique_ptr<FileReader> expected_vad_prob_reader = CreateVadProbsReader();
- // Input length.
- const int num_frames = samples_reader.second;
- ASSERT_GE(expected_vad_prob_reader.second, num_frames);
+ // Input length. The last incomplete frame is ignored.
+ const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
// Init buffers.
std::vector<float> samples_48k(kFrameSize10ms48kHz);
@@ -86,12 +83,12 @@
std::vector<float> expected_vad_prob(num_frames);
// Read expected output.
- ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
+ ASSERT_TRUE(expected_vad_prob_reader->ReadChunk(expected_vad_prob));
// Compute VAD probabilities on the downsampled input.
float cumulative_error = 0.f;
for (int i = 0; i < num_frames; ++i) {
- samples_reader.first->ReadChunk(samples_48k);
+ ASSERT_TRUE(samples_reader->ReadChunk(samples_48k));
decimator.Resample(samples_48k.data(), samples_48k.size(),
samples_24k.data(), samples_24k.size());
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
@@ -106,7 +103,7 @@
EXPECT_LT(cumulative_error / num_frames, 1e-4f);
if (kWriteComputedOutputToFile) {
- BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
+ FileWriter vad_prob_writer("new_vad_prob.dat");
vad_prob_writer.WriteChunk(computed_vad_prob);
}
}
@@ -118,15 +115,16 @@
// - on android: run the this unit test adding "--logcat-output-file".
TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
// PCM samples reader and buffers.
- auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
- const int num_frames = samples_reader.second;
+ std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
+ // The last incomplete frame is ignored.
+ const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
std::array<float, kFrameSize10ms48kHz> samples;
// Pre-fetch and decimate samples.
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
std::vector<float> prefetched_decimated_samples;
prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
for (int i = 0; i < num_frames; ++i) {
- samples_reader.first->ReadChunk(samples);
+ ASSERT_TRUE(samples_reader->ReadChunk(samples));
decimator.Resample(samples.data(), samples.size(),
&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
kFrameSize10ms24kHz);
@@ -151,7 +149,6 @@
rnn_vad.ComputeVadProbability(feature_vector, is_silence);
}
perf_timer.StopTimer();
- samples_reader.first->SeekBeginning();
}
DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
perf_timer.GetDurationAverage(),
@@ -180,6 +177,6 @@
return info.param.ToString();
});
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc
index 125f1b8..f577571 100644
--- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc
@@ -17,7 +17,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
template <typename T, int S, int N>
@@ -60,8 +59,6 @@
}
}
-} // namespace
-
TEST(RnnVadTest, SequenceBufferGetters) {
constexpr int buffer_size = 8;
constexpr int chunk_size = 8;
@@ -100,6 +97,6 @@
TestSequenceBufferPushOp<float, 23, 7>(); // Non-integer ratio.
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
index 461047d..11a44a5 100644
--- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
@@ -26,7 +26,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
// Generates the values for the array named |kOpusBandWeights24kHz20ms| in the
@@ -49,8 +48,6 @@
return weights;
}
-} // namespace
-
// Checks that the values returned by GetOpusScaleNumBins24kHz20ms() match the
// Opus scale frequency boundaries.
TEST(RnnVadTest, TestOpusScaleBoundaries) {
@@ -158,6 +155,6 @@
}
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
index fa376f2..9f41e96 100644
--- a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
@@ -21,7 +21,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
@@ -66,8 +65,6 @@
constexpr float kInitialFeatureVal = -9999.f;
-} // namespace
-
// Checks that silence is detected when the input signal is 0 and that the
// feature vector is written only if the input signal is not tagged as silence.
TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
@@ -159,6 +156,6 @@
feature_vector_last[kNumBands + 3 * kNumLowerBands]);
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc
index c1da8d1..6f61c87 100644
--- a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc
@@ -15,7 +15,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
template <typename T, int S>
@@ -44,8 +43,6 @@
return false;
}
-} // namespace
-
// Test that shows how to combine RingBuffer and SymmetricMatrixBuffer to
// efficiently compute pair-wise scores. This test verifies that the evolution
// of a SymmetricMatrixBuffer instance follows that of RingBuffer.
@@ -105,6 +102,6 @@
}
}
-} // namespace test
+} // namespace
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 75de109..3db6774 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -11,7 +11,10 @@
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include <algorithm>
+#include <fstream>
#include <memory>
+#include <type_traits>
+#include <vector>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
@@ -20,11 +23,46 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
namespace {
-using ReaderPairType =
- std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>;
+// File reader for binary files that contain a sequence of values with
+// arithmetic type `T`. The values of type `T` that are read are cast to float.
+template <typename T>
+class FloatFileReader : public FileReader {
+ public:
+ static_assert(std::is_arithmetic<T>::value, "");
+ FloatFileReader(const std::string& filename)
+ : is_(filename, std::ios::binary | std::ios::ate),
+ size_(is_.tellg() / sizeof(T)) {
+ RTC_CHECK(is_);
+ SeekBeginning();
+ }
+ FloatFileReader(const FloatFileReader&) = delete;
+ FloatFileReader& operator=(const FloatFileReader&) = delete;
+ ~FloatFileReader() = default;
+
+ int size() const override { return size_; }
+ bool ReadChunk(rtc::ArrayView<float> dst) override {
+ const std::streamsize bytes_to_read = dst.size() * sizeof(T);
+ if (std::is_same<T, float>::value) {
+ is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
+ } else {
+ buffer_.resize(dst.size());
+ is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
+ std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
+ [](const T& v) -> float { return static_cast<float>(v); });
+ }
+ return is_.gcount() == bytes_to_read;
+ }
+ bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
+ void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
+ void SeekBeginning() override { is_.seekg(0, is_.beg); }
+
+ private:
+ std::ifstream is_;
+ const int size_;
+ std::vector<T> buffer_;
+};
} // namespace
@@ -49,66 +87,49 @@
}
}
-std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
-CreatePcmSamplesReader(const int frame_length) {
- auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
- test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
- frame_length);
- // The last incomplete frame is ignored.
- return {std::move(ptr), ptr->data_length() / frame_length};
+std::unique_ptr<FileReader> CreatePcmSamplesReader() {
+ return std::make_unique<FloatFileReader<int16_t>>(
+ /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
+ "pcm"));
}
-ReaderPairType CreatePitchBuffer24kHzReader() {
- constexpr int cols = 864;
- auto ptr = std::make_unique<BinaryFileReader<float>>(
- ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
- return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
+ChunksFileReader CreatePitchBuffer24kHzReader() {
+ auto reader = std::make_unique<FloatFileReader<float>>(
+ /*filename=*/test::ResourcePath(
+ "audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
+ const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
+ return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
}
-ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
- constexpr int num_lp_residual_coeffs = 864;
- auto ptr = std::make_unique<BinaryFileReader<float>>(
- ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
- num_lp_residual_coeffs);
- return {std::move(ptr),
- rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
+ChunksFileReader CreateLpResidualAndPitchInfoReader() {
+ constexpr int kPitchInfoSize = 2; // Pitch period and strength.
+ constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
+ auto reader = std::make_unique<FloatFileReader<float>>(
+ /*filename=*/test::ResourcePath(
+ "audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
+ const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
+ return {kChunkSize, num_chunks, std::move(reader)};
}
-ReaderPairType CreateVadProbsReader() {
- auto ptr = std::make_unique<BinaryFileReader<float>>(
- test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
- return {std::move(ptr), ptr->data_length()};
+std::unique_ptr<FileReader> CreateVadProbsReader() {
+ return std::make_unique<FloatFileReader<float>>(
+ /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
+ "dat"));
}
PitchTestData::PitchTestData() {
- BinaryFileReader<float> test_data_reader(
- ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
- 1396);
- test_data_reader.ReadChunk(test_data_);
+ FloatFileReader<float> reader(
+ /*filename=*/ResourcePath(
+ "audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
+ reader.ReadChunk(pitch_buffer_24k_);
+ reader.ReadChunk(square_energies_24k_);
+ reader.ReadChunk(auto_correlation_12k_);
// Reverse the order of the squared energy values.
// Required after the WebRTC CL 191703 which switched to forward computation.
- std::reverse(test_data_.begin() + kBufSize24kHz,
- test_data_.begin() + kBufSize24kHz + kNumPitchBufSquareEnergies);
+ std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
}
PitchTestData::~PitchTestData() = default;
-rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
- const {
- return {test_data_.data(), kBufSize24kHz};
-}
-
-rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
-PitchTestData::GetPitchBufSquareEnergiesView() const {
- return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
-}
-
-rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
-PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
- return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
- kNumPitchBufAutoCorrCoeffs};
-}
-
-} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index 3d1ad25..86af5e0 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -11,15 +11,10 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
-#include <algorithm>
#include <array>
#include <fstream>
-#include <limits>
#include <memory>
#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@@ -28,7 +23,6 @@
namespace webrtc {
namespace rnn_vad {
-namespace test {
constexpr float kFloatMin = std::numeric_limits<float>::min();
@@ -43,98 +37,48 @@
rtc::ArrayView<const float> computed,
float tolerance);
-// Reader for binary files consisting of an arbitrary long sequence of elements
-// having type T. It is possible to read and cast to another type D at once.
-template <typename T, typename D = T>
-class BinaryFileReader {
+// File reader interface.
+class FileReader {
public:
- BinaryFileReader(const std::string& file_path, int chunk_size = 0)
- : is_(file_path, std::ios::binary | std::ios::ate),
- data_length_(is_.tellg() / sizeof(T)),
- chunk_size_(chunk_size) {
- RTC_CHECK(is_);
- SeekBeginning();
- buf_.resize(chunk_size_);
- }
- BinaryFileReader(const BinaryFileReader&) = delete;
- BinaryFileReader& operator=(const BinaryFileReader&) = delete;
- ~BinaryFileReader() = default;
- int data_length() const { return data_length_; }
- bool ReadValue(D* dst) {
- if (std::is_same<T, D>::value) {
- is_.read(reinterpret_cast<char*>(dst), sizeof(T));
- } else {
- T v;
- is_.read(reinterpret_cast<char*>(&v), sizeof(T));
- *dst = static_cast<D>(v);
- }
- return is_.gcount() == sizeof(T);
- }
- // If |chunk_size| was specified in the ctor, it will check that the size of
- // |dst| equals |chunk_size|.
- bool ReadChunk(rtc::ArrayView<D> dst) {
- RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size()));
- const std::streamsize bytes_to_read = dst.size() * sizeof(T);
- if (std::is_same<T, D>::value) {
- is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
- } else {
- is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
- std::transform(buf_.begin(), buf_.end(), dst.begin(),
- [](const T& v) -> D { return static_cast<D>(v); });
- }
- return is_.gcount() == bytes_to_read;
- }
- void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); }
- void SeekBeginning() { is_.seekg(0, is_.beg); }
-
- private:
- std::ifstream is_;
- const int data_length_;
- const int chunk_size_;
- std::vector<T> buf_;
+ virtual ~FileReader() = default;
+ // Number of values in the file.
+ virtual int size() const = 0;
+ // Reads `dst.size()` float values into `dst`, advances the internal file
+ // position according to the number of read bytes and returns true if the
+ // values are correctly read. If the number of remaining bytes in the file is
+ // not sufficient to read `dst.size()` float values, `dst` is partially
+ // modified and false is returned.
+ virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
+ // Reads a single float value, advances the internal file position according
+ // to the number of read bytes and returns true if the value is correctly
+ // read. If the number of remaining bytes in the file is not sufficient to
+ // read one float, `dst` is not modified and false is returned.
+ virtual bool ReadValue(float& dst) = 0;
+ // Advances the internal file position by `hop` float values.
+ virtual void SeekForward(int hop) = 0;
+ // Resets the internal file position to BOF.
+ virtual void SeekBeginning() = 0;
};
-// Writer for binary files.
-template <typename T>
-class BinaryFileWriter {
- public:
- explicit BinaryFileWriter(const std::string& file_path)
- : os_(file_path, std::ios::binary) {}
- BinaryFileWriter(const BinaryFileWriter&) = delete;
- BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
- ~BinaryFileWriter() = default;
- static_assert(std::is_arithmetic<T>::value, "");
- void WriteChunk(rtc::ArrayView<const T> value) {
- const std::streamsize bytes_to_write = value.size() * sizeof(T);
- os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
- }
-
- private:
- std::ofstream os_;
+// File reader for files that contain `num_chunks` chunks with size equal to
+// `chunk_size`.
+struct ChunksFileReader {
+ const int chunk_size;
+ const int num_chunks;
+ std::unique_ptr<FileReader> reader;
};
-// Factories for resource file readers.
-// The functions below return a pair where the first item is a reader unique
-// pointer and the second the number of chunks that can be read from the file.
-// Creates a reader for the PCM samples that casts from S16 to float and reads
-// chunks with length |frame_length|.
-std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
-CreatePcmSamplesReader(const int frame_length);
-// Creates a reader for the pitch buffer content at 24 kHz.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
-CreatePitchBuffer24kHzReader();
-// Creates a reader for the the LP residual coefficients and the pitch period
-// and gain values.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
-CreateLpResidualAndPitchPeriodGainReader();
-// Creates a reader for the VAD probabilities.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
-CreateVadProbsReader();
+// Creates a reader for the PCM S16 samples file.
+std::unique_ptr<FileReader> CreatePcmSamplesReader();
-constexpr int kNumPitchBufAutoCorrCoeffs = 147;
-constexpr int kNumPitchBufSquareEnergies = 385;
-constexpr int kPitchTestDataSize =
- kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
+// Creates a reader for the 24 kHz pitch buffer test data.
+ChunksFileReader CreatePitchBuffer24kHzReader();
+
+// Creates a reader for the LP residual and pitch information test data.
+ChunksFileReader CreateLpResidualAndPitchInfoReader();
+
+// Creates a reader for the VAD probabilities test data.
+std::unique_ptr<FileReader> CreateVadProbsReader();
// Class to retrieve a test pitch buffer content and the expected output for the
// analysis steps.
@@ -142,17 +86,40 @@
public:
PitchTestData();
~PitchTestData();
- rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
- rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
- GetPitchBufSquareEnergiesView() const;
- rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
- GetPitchBufAutoCorrCoeffsView() const;
+ rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
+ return pitch_buffer_24k_;
+ }
+ rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
+ const {
+ return square_energies_24k_;
+ }
+ rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
+ return auto_correlation_12k_;
+ }
private:
- std::array<float, kPitchTestDataSize> test_data_;
+ std::array<float, kBufSize24kHz> pitch_buffer_24k_;
+ std::array<float, kRefineNumLags24kHz> square_energies_24k_;
+ std::array<float, kNumLags12kHz> auto_correlation_12k_;
};
-} // namespace test
+// Writer for binary files.
+class FileWriter {
+ public:
+ explicit FileWriter(const std::string& file_path)
+ : os_(file_path, std::ios::binary) {}
+ FileWriter(const FileWriter&) = delete;
+ FileWriter& operator=(const FileWriter&) = delete;
+ ~FileWriter() = default;
+ void WriteChunk(rtc::ArrayView<const float> value) {
+ const std::streamsize bytes_to_write = value.size() * sizeof(float);
+ os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
+ }
+
+ private:
+ std::ofstream os_;
+};
+
} // namespace rnn_vad
} // namespace webrtc