| /* |
| * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. |
| * |
| * Use of this source code is governed by a BSD-style license |
| * that can be found in the LICENSE file in the root of the source |
| * tree. An additional intellectual property rights grant can be found |
| * in the file PATENTS. All contributing project authors may |
| * be found in the AUTHORS file in the root of the source tree. |
| */ |
| |
| #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" |
| |
| #include <algorithm> |
| |
| #include "rtc_base/checks.h" |
| |
| namespace webrtc { |
| namespace rnn_vad { |
| namespace { |
| |
| constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. |
| static_assert(1 << kAutoCorrelationFftOrder > |
| kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz, |
| ""); |
| |
| } // namespace |
| |
| AutoCorrelationCalculator::AutoCorrelationCalculator() |
| : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal), |
| tmp_(fft_.CreateBuffer()), |
| X_(fft_.CreateBuffer()), |
| H_(fft_.CreateBuffer()) {} |
| |
| AutoCorrelationCalculator::~AutoCorrelationCalculator() = default; |
| |
| // The auto-correlations coefficients are computed as follows: |
| // |.........|...........| <- pitch buffer |
| // [ x (fixed) ] |
| // [ y_0 ] |
| // [ y_{m-1} ] |
| // x and y are sub-array of equal length; x is never moved, whereas y slides. |
| // The cross-correlation between y_0 and x corresponds to the auto-correlation |
| // for the maximum pitch period. Hence, the first value in |auto_corr| has an |
| // inverted lag equal to 0 that corresponds to a lag equal to the maximum |
| // pitch period. |
| void AutoCorrelationCalculator::ComputeOnPitchBuffer( |
| rtc::ArrayView<const float, kBufSize12kHz> pitch_buf, |
| rtc::ArrayView<float, kNumLags12kHz> auto_corr) { |
| RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); |
| RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); |
| constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder; |
| constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; |
| static_assert(kConvolutionLength == kFrameSize20ms12kHz, |
| "Mismatch between pitch buffer size, frame size and maximum " |
| "pitch period."); |
| static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength, |
| "The FFT length is not sufficiently big to avoid cyclic " |
| "convolution errors."); |
| auto tmp = tmp_->GetView(); |
| |
| // Compute the FFT for the reversed reference frame - i.e., |
| // pitch_buf[-kConvolutionLength:]. |
| std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(), |
| tmp.begin()); |
| std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f); |
| fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false); |
| |
| // Compute the FFT for the sliding frames chunk. The sliding frames are |
| // defined as pitch_buf[i:i+kConvolutionLength] where i in |
| // [0, kNumLags12kHz). The chunk includes all of them, hence it is |
| // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength]. |
| std::copy(pitch_buf.begin(), |
| pitch_buf.begin() + kConvolutionLength + kNumLags12kHz, |
| tmp.begin()); |
| std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f); |
| fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); |
| |
| // Convolve in the frequency domain. |
| constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize); |
| std::fill(tmp.begin(), tmp.end(), 0.f); |
| fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor); |
| fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false); |
| |
| // Extract the auto-correlation coefficients. |
| std::copy(tmp.begin() + kConvolutionLength - 1, |
| tmp.begin() + kConvolutionLength + kNumLags12kHz - 1, |
| auto_corr.begin()); |
| } |
| |
| } // namespace rnn_vad |
| } // namespace webrtc |