Add parameterization for three multi channel AEC3 unit tests
Bug: webrtc:11295
Change-Id: I478aa02908c494cf9609db00021438a59a132b66
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/167202
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#30370}
diff --git a/modules/audio_processing/aec3/erl_estimator_unittest.cc b/modules/audio_processing/aec3/erl_estimator_unittest.cc
index 344551d..79e5465 100644
--- a/modules/audio_processing/aec3/erl_estimator_unittest.cc
+++ b/modules/audio_processing/aec3/erl_estimator_unittest.cc
@@ -34,67 +34,71 @@
} // namespace
+class ErlEstimatorMultiChannel  // Param: (num_render_channels, num_capture_channels).
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {};
+
+INSTANTIATE_TEST_SUITE_P(MultiChannel,
+ ErlEstimatorMultiChannel,
+ ::testing::Combine(::testing::Values(1, 2, 8),  // Render channels.
+ ::testing::Values(1, 2, 8)));  // Capture channels.
+
// Verifies that the correct ERL estimates are achieved.
-TEST(ErlEstimator, Estimates) {
- for (size_t num_render_channels : {1, 2, 8}) {
- for (size_t num_capture_channels : {1, 2, 8}) {
- SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels));
- std::vector<std::array<float, kFftLengthBy2Plus1>> X2(
- num_render_channels);
- for (auto& X2_ch : X2) {
- X2_ch.fill(0.f);
- }
- std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(
- num_capture_channels);
- for (auto& Y2_ch : Y2) {
- Y2_ch.fill(0.f);
- }
- std::vector<bool> converged_filters(num_capture_channels, false);
- const size_t converged_idx = num_capture_channels - 1;
- converged_filters[converged_idx] = true;
-
- ErlEstimator estimator(0);
-
- // Verifies that the ERL estimate is properly reduced to lower values.
- for (auto& X2_ch : X2) {
- X2_ch.fill(500 * 1000.f * 1000.f);
- }
- Y2[converged_idx].fill(10 * X2[0][0]);
- for (size_t k = 0; k < 200; ++k) {
- estimator.Update(converged_filters, X2, Y2);
- }
- VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
-
- // Verifies that the ERL is not immediately increased when the ERL in the
- // data increases.
- Y2[converged_idx].fill(10000 * X2[0][0]);
- for (size_t k = 0; k < 998; ++k) {
- estimator.Update(converged_filters, X2, Y2);
- }
- VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
-
- // Verifies that the rate of increase is 3 dB.
- estimator.Update(converged_filters, X2, Y2);
- VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f);
-
- // Verifies that the maximum ERL is achieved when there are no low RLE
- // estimates.
- for (size_t k = 0; k < 1000; ++k) {
- estimator.Update(converged_filters, X2, Y2);
- }
- VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
-
- // Verifies that the ERL estimate is is not updated for low-level signals
- for (auto& X2_ch : X2) {
- X2_ch.fill(1000.f * 1000.f);
- }
- Y2[converged_idx].fill(10 * X2[0][0]);
- for (size_t k = 0; k < 200; ++k) {
- estimator.Update(converged_filters, X2, Y2);
- }
- VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
- }
+TEST_P(ErlEstimatorMultiChannel, Estimates) {
+ const size_t num_render_channels = std::get<0>(GetParam());
+ const size_t num_capture_channels = std::get<1>(GetParam());
+ SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels));
+ std::vector<std::array<float, kFftLengthBy2Plus1>> X2(num_render_channels);
+ for (auto& X2_ch : X2) {
+ X2_ch.fill(0.f);
}
-}
+ std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels);
+ for (auto& Y2_ch : Y2) {
+ Y2_ch.fill(0.f);
+ }
+ std::vector<bool> converged_filters(num_capture_channels, false);
+ const size_t converged_idx = num_capture_channels - 1;  // Only this channel's filter is flagged as converged.
+ converged_filters[converged_idx] = true;
+ ErlEstimator estimator(0);
+
+ // Verifies that the ERL estimate is properly reduced to lower values.
+ for (auto& X2_ch : X2) {
+ X2_ch.fill(500 * 1000.f * 1000.f);
+ }
+ Y2[converged_idx].fill(10 * X2[0][0]);
+ for (size_t k = 0; k < 200; ++k) {
+ estimator.Update(converged_filters, X2, Y2);
+ }
+ VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
+
+ // Verifies that the ERL is not immediately increased when the ERL in the
+ // data increases.
+ Y2[converged_idx].fill(10000 * X2[0][0]);
+ for (size_t k = 0; k < 998; ++k) {
+ estimator.Update(converged_filters, X2, Y2);
+ }
+ VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f);
+
+ // Verifies that the rate of increase is 3 dB (10 -> 20) per update.
+ estimator.Update(converged_filters, X2, Y2);
+ VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f);
+
+ // Verifies that the maximum ERL is achieved when there are no low ERL
+ // estimates.
+ for (size_t k = 0; k < 1000; ++k) {
+ estimator.Update(converged_filters, X2, Y2);
+ }
+ VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
+
+ // Verifies that the ERL estimate is not updated for low-level signals.
+ for (auto& X2_ch : X2) {
+ X2_ch.fill(1000.f * 1000.f);
+ }
+ Y2[converged_idx].fill(10 * X2[0][0]);
+ for (size_t k = 0; k < 200; ++k) {
+ estimator.Update(converged_filters, X2, Y2);
+ }
+ VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f);
+}
} // namespace webrtc