// blob: 152d5698238f813954433269074ffcb170522437 [file] [log] [blame]
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <array>
#include <tuple>
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {

// Inputs shared by the parametrized extended pitch search test in this file:
// a low and a high pitch period, both derived from `kMinPitch48kHz` and
// `kMaxPitch48kHz`, used as initial and last pitch periods.
constexpr int kTestPitchPeriodsLow = (3 * kMinPitch48kHz) / 2;
constexpr int kTestPitchPeriodsHigh =
    ((3 * kMinPitch48kHz) + kMaxPitch48kHz) / 2;

// A low and a high value used as the last pitch strength in the same test.
constexpr float kTestPitchGainsLow = 0.35f;
constexpr float kTestPitchGainsHigh = 0.75f;

}  // namespace
// Verifies that `ComputeSlidingFrameSquareEnergies24kHz()`, run on the pitch
// buffer stored in the test data, reproduces the square energies stored in the
// same test data within an absolute tolerance.
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
  constexpr float kAbsoluteTolerance = 1e-3f;
  PitchTestData pitch_test_data;
  std::array<float, kNumPitchBufSquareEnergies> square_energies;
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  ComputeSlidingFrameSquareEnergies24kHz(pitch_test_data.GetPitchBufView(),
                                         square_energies);
  // Expected values, compared element-wise against the computed ones.
  auto expected_view = pitch_test_data.GetPitchBufSquareEnergiesView();
  ExpectNearAbsolute({expected_view.data(), expected_view.size()},
                     square_energies, kAbsoluteTolerance);
}
// Verifies that the best and second best pitch period candidates returned by
// `ComputePitchPeriod12kHz()` exactly match the expected values for the test
// input data.
TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
  PitchTestData pitch_test_data;
  // The search runs on the output of `Decimate2x()` applied to the test pitch
  // buffer.
  std::array<float, kBufSize12kHz> decimated_pitch_buf;
  Decimate2x(pitch_test_data.GetPitchBufView(), decimated_pitch_buf);
  auto auto_correlation = pitch_test_data.GetPitchBufAutoCorrCoeffsView();
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  const CandidatePitchPeriods candidates =
      ComputePitchPeriod12kHz(decimated_pitch_buf, auto_correlation);
  EXPECT_EQ(candidates.best, 140);
  EXPECT_EQ(candidates.second_best, 142);
}
// Checks that the refined pitch period returned by `ComputePitchPeriod48kHz()`
// is bit-exact given test input data and two fixed pairs of pitch candidates.
TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
  PitchTestData test_data;
  // The buffer size is a compile-time constant, so a zero-initialized
  // `std::array` is used instead of a heap-allocated `std::vector`; it is
  // passed directly where a fixed-size view is expected, like the arrays in
  // the tests above.
  std::array<float, kRefineNumLags24kHz> y_energy{};
  ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
                                         y_energy);
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy,
                                    /*pitch_candidates=*/{280, 284}),
            560);
  EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy,
                                    /*pitch_candidates=*/{260, 284}),
            568);
}
// Fixture parametrized over a pair of candidate pitch periods. It exposes the
// pair as given and a copy in which `best` and `second_best` are exchanged.
class PitchCandidatesParametrization
    : public ::testing::TestWithParam<CandidatePitchPeriods> {
 protected:
  // Returns the test parameter unchanged.
  CandidatePitchPeriods GetPitchCandidates() const { return GetParam(); }

  // Returns the test parameter with its two fields exchanged.
  CandidatePitchPeriods GetSwappedPitchCandidates() const {
    const CandidatePitchPeriods& original = GetParam();
    return {original.second_best, original.best};
  }
};
// Checks that the result of `ComputePitchPeriod48kHz()` does not depend on the
// order of the input pitch candidates.
TEST_P(PitchCandidatesParametrization,
       ComputePitchPeriod48kHzOrderDoesNotMatter) {
  PitchTestData test_data;
  // The buffer size is a compile-time constant, so a zero-initialized
  // `std::array` is used instead of a heap-allocated `std::vector`; it is
  // passed directly where a fixed-size view is expected.
  std::array<float, kRefineNumLags24kHz> y_energy{};
  ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
                                         y_energy);
  // Same inputs except for the order of the two candidates.
  const auto pitch_period = ComputePitchPeriod48kHz(
      test_data.GetPitchBufView(), y_energy, GetPitchCandidates());
  const auto pitch_period_swapped = ComputePitchPeriod48kHz(
      test_data.GetPitchBufView(), y_energy, GetSwappedPitchCandidates());
  EXPECT_EQ(pitch_period, pitch_period_swapped);
}
INSTANTIATE_TEST_SUITE_P(RnnVadTest,
PitchCandidatesParametrization,
::testing::Values(CandidatePitchPeriods{0, 2},
CandidatePitchPeriods{260, 284},
CandidatePitchPeriods{280, 284},
CandidatePitchPeriods{
kInitialNumLags24kHz - 2,
kInitialNumLags24kHz - 1}));
// Fixture for the extended pitch period search test. The parameter is a tuple
// holding, in order: initial pitch period, last pitch period, last pitch
// strength, expected pitch period and expected pitch strength.
class ExtendedPitchPeriodSearchParametrizaion
    : public ::testing::TestWithParam<std::tuple<int, int, float, int, float>> {
 protected:
  int GetInitialPitchPeriod() const {
    return std::get<kInitialPitchPeriodIndex>(GetParam());
  }
  int GetLastPitchPeriod() const {
    return std::get<kLastPitchPeriodIndex>(GetParam());
  }
  float GetLastPitchStrength() const {
    return std::get<kLastPitchStrengthIndex>(GetParam());
  }
  int GetExpectedPitchPeriod() const {
    return std::get<kExpectedPitchPeriodIndex>(GetParam());
  }
  float GetExpectedPitchStrength() const {
    return std::get<kExpectedPitchStrengthIndex>(GetParam());
  }

 private:
  // Positions of the fields inside the parameter tuple.
  enum TupleFieldIndex {
    kInitialPitchPeriodIndex = 0,
    kLastPitchPeriodIndex = 1,
    kLastPitchStrengthIndex = 2,
    kExpectedPitchPeriodIndex = 3,
    kExpectedPitchStrengthIndex = 4,
  };
};
// Checks that the pitch period computed by `ComputeExtendedPitchPeriod48kHz()`
// is bit-exact and that the computed pitch strength is within tolerance given
// test input data and the parametrized initial/last pitch information.
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
       PeriodBitExactnessGainWithinTolerance) {
  PitchTestData test_data;
  // The buffer size is a compile-time constant, so a zero-initialized
  // `std::array` is used instead of a heap-allocated `std::vector`; it is
  // passed directly where a fixed-size view is expected.
  std::array<float, kRefineNumLags24kHz> y_energy{};
  ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
                                         y_energy);
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  const auto computed_output = ComputeExtendedPitchPeriod48kHz(
      test_data.GetPitchBufView(), y_energy, GetInitialPitchPeriod(),
      {GetLastPitchPeriod(), GetLastPitchStrength()});
  EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period);
  EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);
}
namespace {

// Expected outputs for the cases listed below. In these cases they vary only
// with the initial pitch period, not with the last pitch period or strength.
constexpr int kExpectedPeriodForLowInitialPeriod = 91;
constexpr float kExpectedStrengthForLowInitialPeriod = -0.0188608f;
constexpr int kExpectedPeriodForHighInitialPeriod = 475;
constexpr float kExpectedStrengthForHighInitialPeriod = -0.0904344f;

}  // namespace

// Every combination of low/high initial pitch period, low/high last pitch
// period and low/high last pitch strength. Tuple layout: initial pitch period,
// last pitch period, last pitch strength, expected pitch period, expected
// pitch strength.
INSTANTIATE_TEST_SUITE_P(
    RnnVadTest,
    ExtendedPitchPeriodSearchParametrizaion,
    ::testing::Values(
        // Low initial pitch period.
        std::make_tuple(kTestPitchPeriodsLow,
                        kTestPitchPeriodsLow,
                        kTestPitchGainsLow,
                        kExpectedPeriodForLowInitialPeriod,
                        kExpectedStrengthForLowInitialPeriod),
        std::make_tuple(kTestPitchPeriodsLow,
                        kTestPitchPeriodsLow,
                        kTestPitchGainsHigh,
                        kExpectedPeriodForLowInitialPeriod,
                        kExpectedStrengthForLowInitialPeriod),
        std::make_tuple(kTestPitchPeriodsLow,
                        kTestPitchPeriodsHigh,
                        kTestPitchGainsLow,
                        kExpectedPeriodForLowInitialPeriod,
                        kExpectedStrengthForLowInitialPeriod),
        std::make_tuple(kTestPitchPeriodsLow,
                        kTestPitchPeriodsHigh,
                        kTestPitchGainsHigh,
                        kExpectedPeriodForLowInitialPeriod,
                        kExpectedStrengthForLowInitialPeriod),
        // High initial pitch period.
        std::make_tuple(kTestPitchPeriodsHigh,
                        kTestPitchPeriodsLow,
                        kTestPitchGainsLow,
                        kExpectedPeriodForHighInitialPeriod,
                        kExpectedStrengthForHighInitialPeriod),
        std::make_tuple(kTestPitchPeriodsHigh,
                        kTestPitchPeriodsLow,
                        kTestPitchGainsHigh,
                        kExpectedPeriodForHighInitialPeriod,
                        kExpectedStrengthForHighInitialPeriod),
        std::make_tuple(kTestPitchPeriodsHigh,
                        kTestPitchPeriodsHigh,
                        kTestPitchGainsLow,
                        kExpectedPeriodForHighInitialPeriod,
                        kExpectedStrengthForHighInitialPeriod),
        std::make_tuple(kTestPitchPeriodsHigh,
                        kTestPitchPeriodsHigh,
                        kTestPitchGainsHigh,
                        kExpectedPeriodForHighInitialPeriod,
                        kExpectedStrengthForHighInitialPeriod)));
} // namespace test
} // namespace rnn_vad
} // namespace webrtc