| /* |
| * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. |
| * |
| * Use of this source code is governed by a BSD-style license |
| * that can be found in the LICENSE file in the root of the source |
| * tree. An additional intellectual property rights grant can be found |
| * in the file PATENTS. All contributing project authors may |
| * be found in the AUTHORS file in the root of the source tree. |
| */ |
| |
| #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" |
| |
| #include <algorithm> |
| #include <cmath> |
| #include <numeric> |
| #include <utility> |
| |
| #include "modules/audio_processing/agc2/rnn_vad/common.h" |
| #include "rtc_base/checks.h" |
| |
| namespace webrtc { |
| namespace rnn_vad { |
| namespace { |
| |
| // Converts a lag to an inverted lag (only for 24kHz). |
| size_t GetInvertedLag(size_t lag) { |
| RTC_DCHECK_LE(lag, kMaxPitch24kHz); |
| return kMaxPitch24kHz - lag; |
| } |
| |
| float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf, |
| size_t inv_lag, |
| size_t max_pitch_period) { |
| RTC_DCHECK_LT(inv_lag, pitch_buf.size()); |
| RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); |
| RTC_DCHECK_LE(inv_lag, max_pitch_period); |
| // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization. |
| return std::inner_product(pitch_buf.begin() + max_pitch_period, |
| pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f); |
| } |
| |
// Computes a pseudo-interpolation offset for an estimated pitch period |lag|
// from the auto-correlation coefficients around it: |lag_auto_corr| at |lag|,
// and |prev_auto_corr| / |next_auto_corr| at the neighboring lags. Returns +1
// when the "next" coefficient clearly dominates, otherwise -1 when the "prev"
// coefficient clearly dominates, otherwise 0. |lag| itself is not used.
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it
// is relevant only if the spectral analysis works at a sample rate that is
// twice as that of the pitch buffer (not so important instead for the estimated
// pitch period feature fed into the RNN).
int GetPitchPseudoInterpolationOffset(size_t lag,
                                      float prev_auto_corr,
                                      float lag_auto_corr,
                                      float next_auto_corr) {
  // Fraction of the gap to the central coefficient that a neighbor must exceed.
  constexpr float kDominanceRatio = 0.7f;
  // The "next" test has priority over the "prev" one.
  if (next_auto_corr - prev_auto_corr >
      kDominanceRatio * (lag_auto_corr - prev_auto_corr)) {
    return 1;
  }
  if (prev_auto_corr - next_auto_corr >
      kDominanceRatio * (lag_auto_corr - next_auto_corr)) {
    return -1;
  }
  return 0;
}
| |
| // Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The |
| // output sample rate is twice as that of |lag|. |
| size_t PitchPseudoInterpolationLagPitchBuf( |
| size_t lag, |
| rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) { |
| int offset = 0; |
| // Cannot apply pseudo-interpolation at the boundaries. |
| if (lag > 0 && lag < kMaxPitch24kHz) { |
| offset = GetPitchPseudoInterpolationOffset( |
| lag, |
| ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1), |
| kMaxPitch24kHz), |
| ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag), |
| kMaxPitch24kHz), |
| ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1), |
| kMaxPitch24kHz)); |
| } |
| return 2 * lag + offset; |
| } |
| |
| // Refines a pitch period |inv_lag| encoded as inverted lag with |
| // pseudo-interpolation. The output sample rate is twice as that of |
| // |inv_lag|. |
| size_t PitchPseudoInterpolationInvLagAutoCorr( |
| size_t inv_lag, |
| rtc::ArrayView<const float> auto_corr) { |
| int offset = 0; |
| // Cannot apply pseudo-interpolation at the boundaries. |
| if (inv_lag > 0 && inv_lag < auto_corr.size() - 1) { |
| offset = GetPitchPseudoInterpolationOffset(inv_lag, auto_corr[inv_lag + 1], |
| auto_corr[inv_lag], |
| auto_corr[inv_lag - 1]); |
| } |
| // TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should |
| // be subtracted since |inv_lag| is an inverted lag but offset is a lag. |
| return 2 * inv_lag + offset; |
| } |
| |
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
// looking for sub-harmonics. Entry i holds the multiplier n used for k = i + 2.
// The values have been chosen to serve the following algorithm. Given the
// initial pitch period T, we examine whether one of its harmonics is the true
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
// these harmonics, in addition to the correlation at T/k itself, we choose one
// multiple of its pitch period, n*T/k, to validate it (by averaging the
// correlation and energy terms of both periods before computing the gain).
// The multiplier n is chosen so that n*T/k is used only one time over all k.
// When for example k = 4, we should also expect a peak at 3*T/4. When k = 8
// instead we don't want to look at 2*T/8, since we have already checked T/4
// before. Instead, we look at T*3/8.
// The array can be generated in Python as follows:
//   from fractions import Fraction
//   # Smallest positive integer not in X.
//   def mex(X):
//     for i in range(1, int(max(X)+2)):
//       if i not in X:
//         return i
//   # Visited multiples of the period.
//   S = {1}
//   for n in range(2, 16):
//     sn = mex({n * i for i in S} | {1})
//     S = S | {Fraction(1, n), Fraction(sn, n)}
//     print(sn, end=', ')
constexpr std::array<size_t, 14> kSubHarmonicMultipliers = {
    {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};

// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz, indexed by k - 2 where k is the pitch period ratio.
// Computed as [5*k*k for k in range(2, 16)].
constexpr std::array<size_t, 14> kInitialPitchPeriodThresholds = {
    {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
| |
| } // namespace |
| |
| void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src, |
| rtc::ArrayView<float, kBufSize12kHz> dst) { |
| // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter. |
| static_assert(2 * dst.size() == src.size(), ""); |
| for (size_t i = 0; i < dst.size(); ++i) { |
| dst[i] = src[2 * i]; |
| } |
| } |
| |
| float ComputePitchGainThreshold(size_t candidate_pitch_period, |
| size_t pitch_period_ratio, |
| size_t initial_pitch_period, |
| float initial_pitch_gain, |
| size_t prev_pitch_period, |
| size_t prev_pitch_gain) { |
| // Map arguments to more compact aliases. |
| const size_t& t1 = candidate_pitch_period; |
| const size_t& k = pitch_period_ratio; |
| const size_t& t0 = initial_pitch_period; |
| const float& g0 = initial_pitch_gain; |
| const size_t& t_prev = prev_pitch_period; |
| const size_t& g_prev = prev_pitch_gain; |
| |
| // Validate input. |
| RTC_DCHECK_GE(k, 2); |
| |
| // Compute a term that lowers the threshold when |t1| is close to the last |
| // estimated period |t_prev| - i.e., pitch tracking. |
| float lower_threshold_term = 0; |
| if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) <= 1) { |
| // The candidate pitch period is within 1 sample from the previous one. |
| // Make the candidate at |t1| very easy to be accepted. |
| lower_threshold_term = g_prev; |
| } else if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) == 2 && |
| t0 > kInitialPitchPeriodThresholds[k - 2]) { |
| // The candidate pitch period is 2 samples far from the previous one and the |
| // period |t0| (from which |t1| has been derived) is greater than a |
| // threshold. Make |t1| easy to be accepted. |
| lower_threshold_term = 0.5f * g_prev; |
| } |
| // Set the threshold based on the gain of the initial estimate |t0|. Also |
| // reduce the chance of false positives caused by a bias towards high |
| // frequencies (originating from short-term correlations). |
| float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); |
| if (t1 < 3 * kMinPitch24kHz) { // High frequency. |
| threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); |
| } else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency. |
| threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); |
| } |
| return threshold; |
| } |
| |
| void ComputeSlidingFrameSquareEnergies( |
| rtc::ArrayView<const float, kBufSize24kHz> pitch_buf, |
| rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) { |
| float yy = |
| ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz); |
| yy_values[0] = yy; |
| for (size_t i = 1; i < yy_values.size(); ++i) { |
| RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz); |
| RTC_DCHECK_LE(i, kMaxPitch24kHz); |
| const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i]; |
| const float new_coeff = pitch_buf[kMaxPitch24kHz - i]; |
| yy -= old_coeff * old_coeff; |
| yy += new_coeff * new_coeff; |
| yy = std::max(0.f, yy); |
| yy_values[i] = yy; |
| } |
| } |
| |
// Computes, for each inverted lag i in [0, kNumInvertedLags12kHz), the
// correlation between the last |kBufSize12kHz - max_pitch_period| samples of
// |pitch_buf| and the equally long segment of |pitch_buf| starting at index i.
// The result is written into |auto_corr[i]|.
// Mathematically (up to floating-point error) this equals
// ComputeAutoCorrelationCoeff(pitch_buf, i, max_pitch_period). Here all lags
// are obtained at once as a linear convolution computed with |fft|, which must
// have length 2^kAutoCorrelationFftOrder.
// NOTE(review): the equivalence assumes RealFourier's Forward/Inverse round
// trip is unscaled; confirm against RealFourier's documentation.
void ComputePitchAutoCorrelation(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
    size_t max_pitch_period,
    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
    webrtc::RealFourier* fft) {
  RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
  RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
  RTC_DCHECK(fft);

  constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
  constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;

  RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
  RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
                freq_domain_fft_length);

  // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
  // x=pitch_buf[-convolution_length:] is equivalent to convolution of
  // y_i and reversed(x). New notation: h=reversed(x), x=y.
  // Both buffers are zero-initialized so the tails act as zero padding.
  std::array<float, time_domain_fft_length> h{};
  std::array<float, time_domain_fft_length> x{};

  const size_t convolution_length = kBufSize12kHz - max_pitch_period;
  // Check that the FFT-length is big enough to avoid cyclic
  // convolution errors.
  RTC_DCHECK_GT(time_domain_fft_length,
                kNumInvertedLags12kHz + convolution_length);

  // h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
  std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
                    h.begin());

  // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
  std::copy(pitch_buf.begin(),
            pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
            x.begin());

  // Shift to frequency domain.
  std::array<std::complex<float>, freq_domain_fft_length> X{};
  std::array<std::complex<float>, freq_domain_fft_length> H{};
  fft->Forward(&x[0], &X[0]);
  fft->Forward(&h[0], &H[0]);

  // Convolve in frequency domain.
  for (size_t i = 0; i < X.size(); ++i) {
    X[i] *= H[i];
  }

  // Shift back to time domain.
  std::array<float, time_domain_fft_length> x_conv_h;
  fft->Inverse(&X[0], &x_conv_h[0]);

  // Collect the result.
  // Sample convolution_length - 1 + i of the convolution holds the
  // correlation for inverted lag i.
  std::copy(x_conv_h.begin() + convolution_length - 1,
            x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
            auto_corr.begin());
}
| |
// Returns the inverted lags of the best and the second best pitch candidates in
// |auto_corr| (best first). Only inverted lags with a positive
// auto-correlation are candidates.
// Candidates are ranked by auto_corr[inv_lag]^2 / yy. Here yy starts at 1 plus
// the energy of the first samples of |pitch_buf| and is then updated as a
// sliding energy.
// If fewer than two candidates get selected, the defaults 0 (best) and 1
// (second best) are returned in the corresponding slots.
std::array<size_t, 2> FindBestPitchPeriods(
    rtc::ArrayView<const float> auto_corr,
    rtc::ArrayView<const float> pitch_buf,
    size_t max_pitch_period) {
  // Stores a pitch candidate period and strength information.
  struct PitchCandidate {
    // Pitch period encoded as inverted lag.
    size_t period_inverted_lag = 0;
    // Pitch strength encoded as a ratio.
    // The default (-1 / 0) is beaten by any candidate whose denominator is
    // positive.
    float strength_numerator = -1.f;
    float strength_denominator = 0.f;
    // Compare the strength of two pitch candidates.
    bool HasStrongerPitchThan(const PitchCandidate& b) const {
      // Comparing the numerator/denominator ratios without using divisions.
      return strength_numerator * b.strength_denominator >
             b.strength_numerator * strength_denominator;
    }
  };

  RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
  RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
  const size_t frame_size = pitch_buf.size() - max_pitch_period;
  // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
  // NOTE(review): this initial sum covers frame_size + 1 samples, while the
  // sliding update below keeps a window of frame_size samples. As a result,
  // every denominator carries an extra constant pitch_buf[frame_size]^2 term
  // on top of the 1.f bias (ignoring the clamp to zero). Fixing it changes
  // which candidates are selected; confirm against the reference
  // implementation first.
  float yy =
      std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
                         pitch_buf.begin(), 1.f);
  // Search best and second best pitches by looking at the scaled
  // auto-correlation.
  PitchCandidate candidate;
  PitchCandidate best;
  PitchCandidate second_best;
  second_best.period_inverted_lag = 1;
  for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
    // A pitch candidate must have positive correlation.
    if (auto_corr[inv_lag] > 0) {
      candidate.period_inverted_lag = inv_lag;
      candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
      candidate.strength_denominator = yy;
      if (candidate.HasStrongerPitchThan(second_best)) {
        if (candidate.HasStrongerPitchThan(best)) {
          second_best = best;
          best = candidate;
        } else {
          second_best = candidate;
        }
      }
    }
    // Update |yy| (sliding energy) for the next inverted lag: drop
    // pitch_buf[inv_lag], add pitch_buf[inv_lag + frame_size], clamp at zero.
    const float old_coeff = pitch_buf[inv_lag];
    const float new_coeff = pitch_buf[inv_lag + frame_size];
    yy -= old_coeff * old_coeff;
    yy += new_coeff * new_coeff;
    yy = std::max(0.f, yy);
  }
  return {{best.period_inverted_lag, second_best.period_inverted_lag}};
}
| |
| size_t RefinePitchPeriod48kHz( |
| rtc::ArrayView<const float, kBufSize24kHz> pitch_buf, |
| rtc::ArrayView<const size_t, 2> inv_lags) { |
| // Compute the auto-correlation terms only for neighbors of the given pitch |
| // candidates (similar to what is done in ComputePitchAutoCorrelation(), but |
| // for a few lag values). |
| std::array<float, kNumInvertedLags24kHz> auto_corr; |
| auto_corr.fill(0.f); // Zeros become ignored lags in FindBestPitchPeriods(). |
| auto is_neighbor = [](size_t i, size_t j) { |
| return ((i > j) ? (i - j) : (j - i)) <= 2; |
| }; |
| for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) { |
| if (is_neighbor(inv_lag, inv_lags[0]) || is_neighbor(inv_lag, inv_lags[1])) |
| auto_corr[inv_lag] = |
| ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, kMaxPitch24kHz); |
| } |
| // Find best pitch at 24 kHz. |
| const auto pitch_candidates_inv_lags = FindBestPitchPeriods( |
| {auto_corr.data(), auto_corr.size()}, |
| {pitch_buf.data(), pitch_buf.size()}, kMaxPitch24kHz); |
| const auto inv_lag = pitch_candidates_inv_lags[0]; // Refine the best. |
| // Pseudo-interpolation. |
| return PitchPseudoInterpolationInvLagAutoCorr(inv_lag, auto_corr); |
| } |
| |
// Starting from the pitch period estimate |initial_pitch_period_48kHz|, checks
// whether a shorter period T/k is a better estimate. T is the initial period
// halved to 24 kHz (capped at kMaxPitch24kHz - 1) and k is in {2, ..., 15}.
// Returns the selected period converted back to 48 kHz (pseudo-interpolated and
// at least kMinPitch48kHz) together with its pitch gain.
// |prev_pitch_48kHz| is the previous estimate; ComputePitchGainThreshold() uses
// it to lower the acceptance threshold of nearby candidates.
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
    size_t initial_pitch_period_48kHz,
    PitchInfo prev_pitch_48kHz) {
  RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
  RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
  // Stores information for a refined pitch candidate.
  // The default constructor leaves the fields uninitialized; every field of
  // |best_pitch| below is assigned before it is read.
  struct RefinedPitchCandidate {
    RefinedPitchCandidate() {}
    RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy)
        : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
    size_t period_24kHz;
    // Pitch strength information.
    float gain;
    // Additional pitch strength information used for the final estimation of
    // pitch gain.
    float xy;  // Cross-correlation.
    float yy;  // Auto-correlation.
  };

  // Initialize.
  // yy_values[lag] is the sliding energy at |lag| (24 kHz). yy_values[0] is
  // the energy of pitch_buf[kMaxPitch24kHz:] and is used as |xx|.
  std::array<float, kMaxPitch24kHz + 1> yy_values;
  ComputeSlidingFrameSquareEnergies(pitch_buf,
                                    {yy_values.data(), yy_values.size()});
  const float xx = yy_values[0];
  // Helper lambdas.
  // Normalized correlation; the 1.f term keeps the divisor >= 1.
  const auto pitch_gain = [](float xy, float yy, float xx) {
    RTC_DCHECK_LE(0.f, xx * yy);
    return xy / std::sqrt(1.f + xx * yy);
  };
  // Initial pitch candidate gain.
  RefinedPitchCandidate best_pitch;
  best_pitch.period_24kHz =
      std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
  best_pitch.xy = ComputeAutoCorrelationCoeff(
      pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
  best_pitch.yy = yy_values[best_pitch.period_24kHz];
  best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);

  // Store the initial pitch period information.
  const size_t initial_pitch_period = best_pitch.period_24kHz;
  const float initial_pitch_gain = best_pitch.gain;

  // Given the initial pitch estimation, check lower periods (i.e., harmonics).
  const auto alternative_period = [](size_t period, size_t k,
                                     size_t n) -> size_t {
    RTC_DCHECK_LT(0, k);
    return (2 * n * period + k) / (2 * k);  // Same as round(n*period/k).
  };
  for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) {
    size_t candidate_pitch_period =
        alternative_period(initial_pitch_period, k, 1);
    if (candidate_pitch_period < kMinPitch24kHz)
      break;
    // When looking at |candidate_pitch_period|, we also look at one of its
    // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
    // |k| == 2 is a special case since |candidate_pitch_secondary_period| might
    // be greater than the maximum pitch period.
    size_t candidate_pitch_secondary_period = alternative_period(
        initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
    if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz)
      candidate_pitch_secondary_period = initial_pitch_period;
    RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
        << "The lower pitch period and the additional sub-harmonic must not "
        << "coincide.";
    // Compute an auto-correlation score for the primary pitch candidate
    // |candidate_pitch_period| by also looking at its possible sub-harmonic
    // |candidate_pitch_secondary_period|.
    float xy_primary_period = ComputeAutoCorrelationCoeff(
        pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
    float xy_secondary_period = ComputeAutoCorrelationCoeff(
        pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
        kMaxPitch24kHz);
    float xy = 0.5f * (xy_primary_period + xy_secondary_period);
    float yy = 0.5f * (yy_values[candidate_pitch_period] +
                       yy_values[candidate_pitch_secondary_period]);
    float candidate_pitch_gain = pitch_gain(xy, yy, xx);

    // Maybe update best period.
    // NOTE(review): ComputePitchGainThreshold() declares |prev_pitch_gain| as
    // size_t, so |prev_pitch_48kHz.gain| is truncated toward zero here (any
    // gain in (-1, 1) becomes 0). This disables the pitch tracking term.
    float threshold = ComputePitchGainThreshold(
        candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain,
        prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain);
    if (candidate_pitch_gain > threshold) {
      best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy};
    }
  }

  // Final pitch gain and period.
  // |best_pitch.gain| was computed from the unclamped |xy|, so the std::min()
  // below can still yield a negative final gain.
  best_pitch.xy = std::max(0.f, best_pitch.xy);
  RTC_DCHECK_LE(0.f, best_pitch.yy);
  float final_pitch_gain = (best_pitch.yy <= best_pitch.xy)
                               ? 1.f
                               : best_pitch.xy / (best_pitch.yy + 1.f);
  final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
  size_t final_pitch_period_48kHz = std::max(
      kMinPitch48kHz,
      PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));

  return {final_pitch_period_48kHz, final_pitch_gain};
}
| |
| } // namespace rnn_vad |
| } // namespace webrtc |