blob: 23ff49a2fcf191c8bc1b2d5cc07480a98363e49d [file] [log] [blame]
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <array>
#include <tuple>
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
constexpr float kTestPitchGainsLow = 0.35f;
constexpr float kTestPitchGainsHigh = 0.75f;
} // namespace
class ComputePitchGainThresholdTest
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<
/*candidate_pitch_period=*/size_t,
/*pitch_period_ratio=*/size_t,
/*initial_pitch_period=*/size_t,
/*initial_pitch_gain=*/float,
/*prev_pitch_period=*/size_t,
/*prev_pitch_gain=*/float,
/*threshold=*/float>> {};
// Checks that the computed pitch gain is within tolerance given test input
// data.
TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
const auto params = GetParam();
const size_t candidate_pitch_period = std::get<0>(params);
const size_t pitch_period_ratio = std::get<1>(params);
const size_t initial_pitch_period = std::get<2>(params);
const float initial_pitch_gain = std::get<3>(params);
const size_t prev_pitch_period = std::get<4>(params);
const float prev_pitch_gain = std::get<5>(params);
const float threshold = std::get<6>(params);
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
EXPECT_NEAR(
threshold,
ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
initial_pitch_period, initial_pitch_gain,
prev_pitch_period, prev_pitch_gain),
5e-7f);
}
}
INSTANTIATE_TEST_SUITE_P(
RnnVadTest,
ComputePitchGainThresholdTest,
::testing::Values(
std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f),
std::make_tuple(113,
2,
226,
0.20967799f,
219,
0.40392199f,
0.30000001f),
std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f),
std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f),
std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f),
std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
// Checks that the frame-wise sliding square energy function produces output
// within tolerance given test input data.
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
PitchTestData test_data;
std::array<float, kNumPitchBufSquareEnergies> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(),
computed_output);
}
auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
computed_output, 3e-2f);
}
// Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
std::array<size_t, 2> pitch_candidates_inv_lags;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
pitch_candidates_inv_lags =
FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
pitch_buf_decimated, kMaxPitch12kHz);
}
EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
}
// Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
PitchTestData test_data;
size_t pitch_inv_lag;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const std::array<size_t, 2> pitch_candidates_inv_lags = {280, 284};
pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
pitch_candidates_inv_lags);
}
EXPECT_EQ(560u, pitch_inv_lag);
}
class CheckLowerPitchPeriodsAndComputePitchGainTest
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<
/*initial_pitch_period=*/int,
/*prev_pitch_period=*/int,
/*prev_pitch_gain=*/float,
/*expected_pitch_period=*/int,
/*expected_pitch_gain=*/float>> {};
// Checks that the computed pitch period is bit-exact and that the computed
// pitch gain is within tolerance given test input data.
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
PeriodBitExactnessGainWithinTolerance) {
const auto params = GetParam();
const int initial_pitch_period = std::get<0>(params);
const int prev_pitch_period = std::get<1>(params);
const float prev_pitch_gain = std::get<2>(params);
const int expected_pitch_period = std::get<3>(params);
const float expected_pitch_gain = std::get<4>(params);
PitchTestData test_data;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain(
test_data.GetPitchBufView(), initial_pitch_period,
{prev_pitch_period, prev_pitch_gain});
EXPECT_EQ(expected_pitch_period, computed_output.period);
EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f);
}
}
INSTANTIATE_TEST_SUITE_P(
RnnVadTest,
CheckLowerPitchPeriodsAndComputePitchGainTest,
::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
kTestPitchPeriodsLow,
kTestPitchGainsLow,
91,
-0.0188608f),
std::make_tuple(kTestPitchPeriodsLow,
kTestPitchPeriodsLow,
kTestPitchGainsHigh,
91,
-0.0188608f),
std::make_tuple(kTestPitchPeriodsLow,
kTestPitchPeriodsHigh,
kTestPitchGainsLow,
91,
-0.0188608f),
std::make_tuple(kTestPitchPeriodsLow,
kTestPitchPeriodsHigh,
kTestPitchGainsHigh,
91,
-0.0188608f),
std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchPeriodsLow,
kTestPitchGainsLow,
475,
-0.0904344f),
std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchPeriodsLow,
kTestPitchGainsHigh,
475,
-0.0904344f),
std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchPeriodsHigh,
kTestPitchGainsLow,
475,
-0.0904344f),
std::make_tuple(kTestPitchPeriodsHigh,
kTestPitchPeriodsHigh,
kTestPitchGainsHigh,
475,
-0.0904344f)));
} // namespace test
} // namespace rnn_vad
} // namespace webrtc