/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"

#include <array>
#include <tuple>

#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"

namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {

// Initial and previous pitch periods fed to
// CheckLowerPitchPeriodsAndComputePitchGainTest; both values are derived from
// kMinPitch48kHz and kMaxPitch48kHz.
constexpr std::array<int, 2> kTestPitchPeriods = {
    (3 * kMinPitch48kHz) / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2};

// Previous pitch gains fed to the same parameterized test.
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};

}  // namespace

class ComputePitchGainThresholdTest
    : public ::testing::Test,
      public ::testing::WithParamInterface<
          std::tuple<size_t, size_t, size_t, float, size_t, float, float>> {};

// Checks ComputePitchGainThreshold() against reference thresholds. Each test
// parameter holds the six function arguments followed by the expected output.
TEST_P(ComputePitchGainThresholdTest, BitExactness) {
  const auto params = GetParam();
  const size_t candidate_pitch_period = std::get<0>(params);
  const size_t pitch_period_ratio = std::get<1>(params);
  const size_t initial_pitch_period = std::get<2>(params);
  const float initial_pitch_gain = std::get<3>(params);
  const size_t prev_pitch_period = std::get<4>(params);
  // Must be a float. The tuple element is a float gain, and storing it in an
  // integer type would silently truncate it (e.g. 0.604747f -> 0).
  const float prev_pitch_gain = std::get<5>(params);
  const float threshold = std::get<6>(params);

  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;

    EXPECT_NEAR(
        threshold,
        ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
                                  initial_pitch_period, initial_pitch_gain,
                                  prev_pitch_period, prev_pitch_gain),
        3e-6f);
  }
}

// Reference cases. Each tuple is laid out as:
//   (candidate period, period ratio, initial period, initial gain,
//    previous period, previous gain, expected threshold).
INSTANTIATE_TEST_SUITE_P(
    RnnVadTest,
    ComputePitchGainThresholdTest,
    ::testing::Values(std::make_tuple(31, 7, 219, 0.45649201f,
                                      199, 0.604747f, 0.40000001f),
                      std::make_tuple(113, 2, 226, 0.20967799f,
                                      219, 0.40392199f, 0.30000001f),
                      std::make_tuple(63, 2, 126, 0.210788f,
                                      364, 0.098519f, 0.40000001f),
                      std::make_tuple(30, 5, 152, 0.82356697f,
                                      149, 0.55535901f, 0.700032f),
                      std::make_tuple(76, 2, 151, 0.79522997f,
                                      151, 0.82356697f, 0.675946f),
                      std::make_tuple(31, 5, 153, 0.85069299f,
                                      150, 0.79073799f, 0.72308898f),
                      std::make_tuple(78, 2, 156, 0.72750503f,
                                      153, 0.85069299f, 0.618379f)));

// Compares the sliding-frame square energies computed on the test pitch buffer
// with the reference energies bundled in PitchTestData (abs tolerance 3e-2).
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
  PitchTestData test_data;
  std::array<float, kNumPitchBufSquareEnergies> energies;
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(), energies);
  }
  auto expected_energies = test_data.GetPitchBufSquareEnergiesView();
  ExpectNearAbsolute({expected_energies.data(), expected_energies.size()},
                     energies, 3e-2f);
}

// Runs FindBestPitchPeriods() on the 2x-decimated test pitch buffer and the
// reference auto-correlation coefficients. It then checks the two returned
// inverted lags.
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
  PitchTestData test_data;
  std::array<float, kBufSize12kHz> decimated_buf;
  Decimate2x(test_data.GetPitchBufView(), decimated_buf);

  std::array<size_t, 2> best_inv_lags;
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    auto auto_corr = test_data.GetPitchBufAutoCorrCoeffsView();
    best_inv_lags = FindBestPitchPeriods({auto_corr.data(), auto_corr.size()},
                                         decimated_buf, kMaxPitch12kHz);
  }

  const std::array<size_t, 2> expected_inv_lags = {140, 142};
  EXPECT_EQ(expected_inv_lags, best_inv_lags);
}

// Checks that RefinePitchPeriod48kHz() turns a fixed pair of candidate inverted
// lags into the expected refined inverted lag on the test pitch buffer.
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
  PitchTestData test_data;
  // RefinePitchPeriod48kHz() only reads the non-decimated pitch buffer view and
  // the candidates, so no decimated copy of the buffer is built here.
  size_t pitch_inv_lag = 0;
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    const std::array<size_t, 2> pitch_candidates_inv_lags = {280, 284};
    pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
                                           pitch_candidates_inv_lags);
  }
  EXPECT_EQ(560u, pitch_inv_lag);
}

class CheckLowerPitchPeriodsAndComputePitchGainTest
    : public ::testing::Test,
      public ::testing::WithParamInterface<
          std::tuple<int, int, float, int, float>> {};

// Runs CheckLowerPitchPeriodsAndComputePitchGain() on the test pitch buffer.
// It then compares the returned period exactly and the returned gain within
// 1e-6 against the reference values in the test parameter.
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
  const auto& test_case = GetParam();
  const int initial_period = std::get<0>(test_case);
  const int last_period = std::get<1>(test_case);
  const float last_gain = std::get<2>(test_case);
  const int expected_period = std::get<3>(test_case);
  const float expected_gain = std::get<4>(test_case);

  PitchTestData test_data;
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    const auto result = CheckLowerPitchPeriodsAndComputePitchGain(
        test_data.GetPitchBufView(), initial_period, {last_period, last_gain});
    EXPECT_EQ(expected_period, result.period);
    EXPECT_NEAR(expected_gain, result.gain, 1e-6f);
  }
}

// Every combination of the test initial period, previous period and previous
// gain. Each tuple is laid out as:
//   (initial period, previous period, previous gain,
//    expected period, expected gain).
INSTANTIATE_TEST_SUITE_P(
    RnnVadTest,
    CheckLowerPitchPeriodsAndComputePitchGainTest,
    ::testing::Values(
        std::make_tuple(kTestPitchPeriods[0], kTestPitchPeriods[0],
                        kTestPitchGains[0], 91, -0.0188608f),
        std::make_tuple(kTestPitchPeriods[0], kTestPitchPeriods[0],
                        kTestPitchGains[1], 91, -0.0188608f),
        std::make_tuple(kTestPitchPeriods[0], kTestPitchPeriods[1],
                        kTestPitchGains[0], 91, -0.0188608f),
        std::make_tuple(kTestPitchPeriods[0], kTestPitchPeriods[1],
                        kTestPitchGains[1], 91, -0.0188608f),
        std::make_tuple(kTestPitchPeriods[1], kTestPitchPeriods[0],
                        kTestPitchGains[0], 475, -0.0904344f),
        std::make_tuple(kTestPitchPeriods[1], kTestPitchPeriods[0],
                        kTestPitchGains[1], 475, -0.0904344f),
        std::make_tuple(kTestPitchPeriods[1], kTestPitchPeriods[1],
                        kTestPitchGains[0], 475, -0.0904344f),
        std::make_tuple(kTestPitchPeriods[1], kTestPitchPeriods[1],
                        kTestPitchGains[1], 475, -0.0904344f)));

}  // namespace test
}  // namespace rnn_vad
}  // namespace webrtc
