blob: 31324453886f4a00f7a1eb92f6fb4ef90e49404c [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <array>
#include <cmath>
#include <complex>
#include <cstddef>
#include <limits>
#include <tuple>
#include <vector>

#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/rnnoise/src/kiss_fft.h"
namespace rnnoise {
namespace test {
namespace {
// Pi as a compile-time constant. Computing it with std::acos(-1.0) at
// namespace scope would require a dynamic (static) initializer; this literal
// rounds to the same double value.
constexpr double kPi = 3.14159265358979323846;
// Copies |num_samples| real-valued |samples| into |input_buf| as complex values
// with a zero imaginary part. Both components are written, so the result does
// not depend on what |input_buf| held before (e.g. a buffer reused across FFT
// calls). |input_buf| must have room for at least |num_samples| elements.
void FillFftInputBuffer(const size_t num_samples,
                        const float* samples,
                        std::complex<float>* input_buf) {
  for (size_t i = 0; i < num_samples; ++i)
    input_buf[i] = std::complex<float>(samples[i], 0.f);
}
// Expects the first |num_fft_points| entries of |computed| to match the
// (|expected_real|, |expected_imag|) pairs within |tolerance|. Each check is
// traced with the index of the bin being compared.
void CheckFftResult(const size_t num_fft_points,
                    const float* expected_real,
                    const float* expected_imag,
                    const std::complex<float>* computed,
                    const float tolerance) {
  for (size_t bin = 0; bin < num_fft_points; ++bin) {
    SCOPED_TRACE(bin);
    const std::complex<float>& actual = computed[bin];
    EXPECT_NEAR(expected_real[bin], actual.real(), tolerance);
    EXPECT_NEAR(expected_imag[bin], actual.imag(), tolerance);
  }
}
} // namespace
class RnnVadTest
: public testing::Test,
public ::testing::WithParamInterface<std::tuple<size_t, float, float>> {};
// Verifies that IFFT(FFT(x)) == x up to round-off. The input x is a real
// sinusoid with 10 cycles over the frame, scaled by the parameterized
// amplitude. After the round trip, the real parts must match x and the
// imaginary parts must be zero.
TEST_P(RnnVadTest, KissFftForwardReverseCheckIdentity) {
  const auto params = GetParam();
  const float amplitude = std::get<0>(params);
  const size_t num_fft = std::get<1>(params);
  const float tolerance = std::get<2>(params);

  // Test signal, plus the expected all-zero imaginary part after the round
  // trip (value-initialized to 0.f).
  std::vector<float> samples(num_fft);
  const std::vector<float> zeros(num_fft, 0.f);
  for (size_t n = 0; n < samples.size(); ++n)
    samples[n] = amplitude * std::sin(2.f * kPi * 10 * n / num_fft);

  KissFft fft(num_fft);
  std::vector<std::complex<float>> time_buf(num_fft);
  std::vector<std::complex<float>> freq_buf(num_fft);
  FillFftInputBuffer(samples.size(), samples.data(), time_buf.data());
  {
    // TODO(alessiob): Underflow with non power of 2 frame sizes.
    // FloatingPointExceptionObserver fpe_observer;
    fft.ForwardFft(time_buf.size(), time_buf.data(), freq_buf.size(),
                   freq_buf.data());
    fft.ReverseFft(freq_buf.size(), freq_buf.data(), time_buf.size(),
                   time_buf.data());
  }
  CheckFftResult(samples.size(), samples.data(), zeros.data(), time_buf.data(),
                 tolerance);
}
// Round-trip test cases. Each tuple is (amplitude, number of FFT points,
// tolerance). Both power-of-2 sizes (256, 512, 1024) and non-power-of-2 sizes
// (240, 480, 960) are covered, at amplitudes 1 and 30. The amplitude-30 rows
// use looser tolerances, presumably because absolute round-off error scales
// with the signal amplitude.
// NOTE(review): newer googletest releases deprecate INSTANTIATE_TEST_CASE_P in
// favor of INSTANTIATE_TEST_SUITE_P; switch once the bundled gtest has it.
INSTANTIATE_TEST_CASE_P(FftPoints,
                        RnnVadTest,
                        ::testing::Values(std::make_tuple(1.f, 240, 3e-7f),
                                          std::make_tuple(1.f, 256, 3e-7f),
                                          std::make_tuple(1.f, 480, 3e-7f),
                                          std::make_tuple(1.f, 512, 3e-7f),
                                          std::make_tuple(1.f, 960, 4e-7f),
                                          std::make_tuple(1.f, 1024, 3e-7f),
                                          std::make_tuple(30.f, 240, 5e-6f),
                                          std::make_tuple(30.f, 256, 5e-6f),
                                          std::make_tuple(30.f, 480, 6e-6f),
                                          std::make_tuple(30.f, 512, 6e-6f),
                                          std::make_tuple(30.f, 960, 8e-6f),
                                          std::make_tuple(30.f, 1024, 6e-6f)));
// Runs a 32-point forward FFT on a fixed real-valued input and compares the
// first N/2 + 1 = 17 output bins against hard-coded expected values. The
// tolerance std::numeric_limits<float>::min() makes the comparison
// effectively exact.
TEST(RnnVadTest, KissFftBitExactness) {
  constexpr size_t kFftSize = 32;
  constexpr size_t kNumBins = kFftSize / 2 + 1;
  constexpr std::array<float, kFftSize> samples = {
      {0.3524301946163177490234375f, 0.891803801059722900390625f,
       0.07706542313098907470703125f, 0.699530780315399169921875f,
       0.3789891898632049560546875f, 0.5438187122344970703125f,
       0.332781612873077392578125f, 0.449340641498565673828125f,
       0.105229437351226806640625f, 0.722373783588409423828125f,
       0.13155306875705718994140625f, 0.340857982635498046875f,
       0.970204889774322509765625f, 0.53061950206756591796875f,
       0.91507828235626220703125f, 0.830274522304534912109375f,
       0.74468600749969482421875f, 0.24320767819881439208984375f,
       0.743998110294342041015625f, 0.17574800550937652587890625f,
       0.1834825575351715087890625f, 0.63317775726318359375f,
       0.11414264142513275146484375f, 0.1612723171710968017578125f,
       0.80316197872161865234375f, 0.4979794919490814208984375f,
       0.554282128810882568359375f, 0.67189347743988037109375f,
       0.06660757958889007568359375f, 0.89568817615509033203125f,
       0.29327380657196044921875f, 0.3472573757171630859375f}};
  constexpr std::array<float, kNumBins> expected_real = {
      {0.4813065826892852783203125f, -0.0246877372264862060546875f,
       0.04095232486724853515625f, -0.0401695556938648223876953125f,
       0.00500857271254062652587890625f, 0.0160773508250713348388671875f,
       -0.011385642923414707183837890625f, -0.008461721241474151611328125f,
       0.01383177936077117919921875f, 0.0117270611226558685302734375f,
       -0.0164460353553295135498046875f, 0.0585579685866832733154296875f,
       0.02038039825856685638427734375f, -0.0209107734262943267822265625f,
       0.01046995259821414947509765625f, -0.09019653499126434326171875f,
       -0.0583711564540863037109375f}};
  constexpr std::array<float, kNumBins> expected_imag = {
      {0.f, -0.010482530109584331512451171875f, 0.04762755334377288818359375f,
       -0.0558677613735198974609375f, 0.007908363826572895050048828125f,
       -0.0071932487189769744873046875f, 0.01322011835873126983642578125f,
       -0.011227893643081188201904296875f, -0.0400779247283935546875f,
       -0.0290451310575008392333984375f, 0.01519204117357730865478515625f,
       -0.09711246192455291748046875f, -0.00136523949913680553436279296875f,
       0.038602568209171295166015625f, -0.009693108499050140380859375f,
       -0.0183933563530445098876953125f, 0.f}};

  KissFft fft(kFftSize);
  // std::complex default-constructs to (0, 0), so both buffers start zeroed.
  std::array<std::complex<float>, kFftSize> time_domain;
  std::array<std::complex<float>, kFftSize> freq_domain;
  FillFftInputBuffer(samples.size(), samples.data(), time_domain.data());
  fft.ForwardFft(time_domain.size(), time_domain.data(), freq_domain.size(),
                 freq_domain.data());

  const float exact_tolerance = std::numeric_limits<float>::min();
  CheckFftResult(expected_real.size(), expected_real.data(),
                 expected_imag.data(), freq_domain.data(), exact_tolerance);
}
} // namespace test
} // namespace rnnoise