BiQuadFilter: API improvements

Bug: webrtc:7494
Change-Id: If0270cddeb46fa53c0fbb385c85e48f28f9e1a5c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/236342
Reviewed-by: Hanna Silen <silen@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35274}
diff --git a/modules/audio_processing/agc2/biquad_filter.cc b/modules/audio_processing/agc2/biquad_filter.cc
index ccb7807..453125f 100644
--- a/modules/audio_processing/agc2/biquad_filter.cc
+++ b/modules/audio_processing/agc2/biquad_filter.cc
@@ -10,26 +10,37 @@
 
 #include "modules/audio_processing/agc2/biquad_filter.h"
 
-#include <stddef.h>
+#include "rtc_base/arraysize.h"
 
 namespace webrtc {
 
-// Transposed direct form I implementation of a bi-quad filter applied to an
-// input signal `x` to produce an output signal `y`.
+BiQuadFilter::BiQuadFilter(const Config& config)
+    : config_(config), state_({}) {}
+
+BiQuadFilter::~BiQuadFilter() = default;
+
+void BiQuadFilter::SetConfig(const Config& config) {
+  config_ = config;
+  state_ = {};
+}
+
+void BiQuadFilter::Reset() {
+  state_ = {};
+}
+
 void BiQuadFilter::Process(rtc::ArrayView<const float> x,
                            rtc::ArrayView<float> y) {
+  RTC_DCHECK_EQ(x.size(), y.size());
   for (size_t k = 0; k < x.size(); ++k) {
-    // Use temporary variable for x[k] to allow in-place function call
-    // (that x and y refer to the same array).
+    // Use a temporary variable for `x[k]` to allow in-place processing.
     const float tmp = x[k];
-    y[k] = coefficients_.b[0] * tmp + coefficients_.b[1] * biquad_state_.b[0] +
-           coefficients_.b[2] * biquad_state_.b[1] -
-           coefficients_.a[0] * biquad_state_.a[0] -
-           coefficients_.a[1] * biquad_state_.a[1];
-    biquad_state_.b[1] = biquad_state_.b[0];
-    biquad_state_.b[0] = tmp;
-    biquad_state_.a[1] = biquad_state_.a[0];
-    biquad_state_.a[0] = y[k];
+    y[k] = config_.b[0] * tmp + config_.b[1] * state_.b[0] +
+           config_.b[2] * state_.b[1] - config_.a[0] * state_.a[0] -
+           config_.a[1] * state_.a[1];
+    state_.b[1] = state_.b[0];
+    state_.b[0] = tmp;
+    state_.a[1] = state_.a[0];
+    state_.a[0] = y[k];
   }
 }
 
diff --git a/modules/audio_processing/agc2/biquad_filter.h b/modules/audio_processing/agc2/biquad_filter.h
index 7bf3301..5273ff9 100644
--- a/modules/audio_processing/agc2/biquad_filter.h
+++ b/modules/audio_processing/agc2/biquad_filter.h
@@ -11,54 +11,44 @@
 #ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
 #define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
 
-#include <algorithm>
-
 #include "api/array_view.h"
-#include "rtc_base/arraysize.h"
-#include "rtc_base/constructor_magic.h"
 
 namespace webrtc {
 
+// Transposed direct form I implementation of a bi-quad filter.
+//        b[0] + b[1] • z^(-1) + b[2] • z^(-2)
+// H(z) = ------------------------------------
+//          1 + a[1] • z^(-1) + a[2] • z^(-2)
 class BiQuadFilter {
  public:
   // Normalized filter coefficients.
-  //        b_0 + b_1 • z^(-1) + b_2 • z^(-2)
-  // H(z) = ---------------------------------
-  //         1 + a_1 • z^(-1) + a_2 • z^(-2)
-  struct BiQuadCoefficients {
-    float b[3];
-    float a[2];
+  // Computed as `[b, a] = scipy.signal.butter(N=2, Wn, btype)`.
+  struct Config {
+    float b[3];  // b[0], b[1], b[2].
+    float a[2];  // a[1], a[2].
   };
 
-  BiQuadFilter() = default;
+  explicit BiQuadFilter(const Config& config);
+  BiQuadFilter(const BiQuadFilter&) = delete;
+  BiQuadFilter& operator=(const BiQuadFilter&) = delete;
+  ~BiQuadFilter();
 
-  void Initialize(const BiQuadCoefficients& coefficients) {
-    coefficients_ = coefficients;
-  }
+  // Sets the filter configuration and resets the internal state.
+  void SetConfig(const Config& config);
 
-  void Reset() { biquad_state_.Reset(); }
+  // Zeroes the filter state.
+  void Reset();
 
-  // Produces a filtered output y of the input x. Both x and y need to
-  // have the same length. In-place modification is allowed.
+  // Filters `x` and writes the output in `y`, which must have the same length
+  // of `x`. In-place processing is supported.
   void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);
 
  private:
-  struct BiQuadState {
-    BiQuadState() { Reset(); }
-
-    void Reset() {
-      std::fill(b, b + arraysize(b), 0.f);
-      std::fill(a, a + arraysize(a), 0.f);
-    }
-
+  Config config_;
+  struct State {
     float b[2];
     float a[2];
-  };
-
-  BiQuadState biquad_state_;
-  BiQuadCoefficients coefficients_;
-
-  RTC_DISALLOW_COPY_AND_ASSIGN(BiQuadFilter);
+  } state_;
 };
 
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/biquad_filter_unittest.cc b/modules/audio_processing/agc2/biquad_filter_unittest.cc
index 55ca1a5..a53036b 100644
--- a/modules/audio_processing/agc2/biquad_filter_unittest.cc
+++ b/modules/audio_processing/agc2/biquad_filter_unittest.cc
@@ -19,11 +19,10 @@
 #include "rtc_base/gunit.h"
 
 namespace webrtc {
-namespace test {
 namespace {
 
-constexpr size_t kFrameSize = 8;
-constexpr size_t kNumFrames = 4;
+constexpr int kFrameSize = 8;
+constexpr int kNumFrames = 4;
 using FloatArraySequence =
     std::array<std::array<float, kFrameSize>, kNumFrames>;
 
@@ -37,8 +36,8 @@
      {{22.645832f, -64.597153f, 55.462521f, -109.393188f, 10.117825f,
        -40.019642f, -98.612228f, -8.330326f}}}};
 
-// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
-const BiQuadFilter::BiQuadCoefficients kBiQuadConfig = {
+// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
+constexpr BiQuadFilter::Config kBiQuadConfig{
     {0.99446179f, -1.98892358f, 0.99446179f},
     {-1.98889291f, 0.98895425f}};
 
@@ -57,22 +56,23 @@
      {{24.84286614f, -62.18094158f, 57.91488056f, -106.65685933f, 13.38760103f,
        -36.60367134f, -94.44880104f, -3.59920354f}}}};
 
-// Fail for every pair from two equally sized rtc::ArrayView<float> views such
+// Fails for every pair from two equally sized rtc::ArrayView<float> views such
 // that their relative error is above a given threshold. If the expected value
-// of a pair is 0, the tolerance is used to check the absolute error.
+// of a pair is 0, `tolerance` is used to check the absolute error.
 void ExpectNearRelative(rtc::ArrayView<const float> expected,
                         rtc::ArrayView<const float> computed,
                         const float tolerance) {
   // The relative error is undefined when the expected value is 0.
   // When that happens, check the absolute error instead. `safe_den` is used
   // below to implement such logic.
-  auto safe_den = [](float x) { return (x == 0.f) ? 1.f : std::fabs(x); };
+  auto safe_den = [](float x) { return (x == 0.0f) ? 1.0f : std::fabs(x); };
   ASSERT_EQ(expected.size(), computed.size());
   for (size_t i = 0; i < expected.size(); ++i) {
     const float abs_diff = std::fabs(expected[i] - computed[i]);
     // No failure when the values are equal.
-    if (abs_diff == 0.f)
+    if (abs_diff == 0.0f) {
       continue;
+    }
     SCOPED_TRACE(i);
     SCOPED_TRACE(expected[i]);
     SCOPED_TRACE(computed[i]);
@@ -80,32 +80,32 @@
   }
 }
 
-}  // namespace
-
+// Checks that filtering works when different containers are used both as input
+// and as output.
 TEST(BiQuadFilterTest, FilterNotInPlace) {
-  BiQuadFilter filter;
-  filter.Initialize(kBiQuadConfig);
+  BiQuadFilter filter(kBiQuadConfig);
   std::array<float, kFrameSize> samples;
 
   // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
   // FloatingPointExceptionObserver fpe_observer;
 
-  for (size_t i = 0; i < kNumFrames; ++i) {
+  for (int i = 0; i < kNumFrames; ++i) {
     SCOPED_TRACE(i);
     filter.Process(kBiQuadInputSeq[i], samples);
     ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f);
   }
 }
 
+// Checks that filtering works when the same container is used both as input and
+// as output.
 TEST(BiQuadFilterTest, FilterInPlace) {
-  BiQuadFilter filter;
-  filter.Initialize(kBiQuadConfig);
+  BiQuadFilter filter(kBiQuadConfig);
   std::array<float, kFrameSize> samples;
 
   // TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
   // FloatingPointExceptionObserver fpe_observer;
 
-  for (size_t i = 0; i < kNumFrames; ++i) {
+  for (int i = 0; i < kNumFrames; ++i) {
     SCOPED_TRACE(i);
     std::copy(kBiQuadInputSeq[i].begin(), kBiQuadInputSeq[i].end(),
               samples.begin());
@@ -114,23 +114,62 @@
   }
 }
 
-TEST(BiQuadFilterTest, Reset) {
-  BiQuadFilter filter;
-  filter.Initialize(kBiQuadConfig);
+// Checks that different configurations produce different outputs.
+TEST(BiQuadFilterTest, SetConfigDifferentOutput) {
+  BiQuadFilter filter(/*config=*/{{0.97803048f, -1.95606096f, 0.97803048f},
+                                  {-1.95557824f, 0.95654368f}});
 
   std::array<float, kFrameSize> samples1;
-  for (size_t i = 0; i < kNumFrames; ++i) {
+  for (int i = 0; i < kNumFrames; ++i) {
     filter.Process(kBiQuadInputSeq[i], samples1);
   }
 
-  filter.Reset();
+  filter.SetConfig(
+      {{0.09763107f, 0.19526215f, 0.09763107f}, {-0.94280904f, 0.33333333f}});
   std::array<float, kFrameSize> samples2;
-  for (size_t i = 0; i < kNumFrames; ++i) {
+  for (int i = 0; i < kNumFrames; ++i) {
+    filter.Process(kBiQuadInputSeq[i], samples2);
+  }
+
+  EXPECT_NE(samples1, samples2);
+}
+
+// Checks that when `SetConfig()` is called but the filter coefficients are the
+// same the filter state is reset.
+TEST(BiQuadFilterTest, SetConfigResetsState) {
+  BiQuadFilter filter(kBiQuadConfig);
+
+  std::array<float, kFrameSize> samples1;
+  for (int i = 0; i < kNumFrames; ++i) {
+    filter.Process(kBiQuadInputSeq[i], samples1);
+  }
+
+  filter.SetConfig(kBiQuadConfig);
+  std::array<float, kFrameSize> samples2;
+  for (int i = 0; i < kNumFrames; ++i) {
     filter.Process(kBiQuadInputSeq[i], samples2);
   }
 
   EXPECT_EQ(samples1, samples2);
 }
 
-}  // namespace test
+// Checks that when `Reset()` is called the filter state is reset.
+TEST(BiQuadFilterTest, Reset) {
+  BiQuadFilter filter(kBiQuadConfig);
+
+  std::array<float, kFrameSize> samples1;
+  for (int i = 0; i < kNumFrames; ++i) {
+    filter.Process(kBiQuadInputSeq[i], samples1);
+  }
+
+  filter.Reset();
+  std::array<float, kFrameSize> samples2;
+  for (int i = 0; i < kNumFrames; ++i) {
+    filter.Process(kBiQuadInputSeq[i], samples2);
+  }
+
+  EXPECT_EQ(samples1, samples2);
+}
+
+}  // namespace
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc
index 5c276c8..5020234 100644
--- a/modules/audio_processing/agc2/rnn_vad/features_extraction.cc
+++ b/modules/audio_processing/agc2/rnn_vad/features_extraction.cc
@@ -19,8 +19,8 @@
 namespace rnn_vad {
 namespace {
 
-// Generated via "B, A = scipy.signal.butter(2, 30/12000, btype='highpass')"
-const BiQuadFilter::BiQuadCoefficients kHpfConfig24k = {
+// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
+constexpr BiQuadFilter::Config kHpfConfig24k{
     {0.99446179f, -1.98892358f, 0.99446179f},
     {-1.98889291f, 0.98895425f}};
 
@@ -28,6 +28,7 @@
 
 FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features)
     : use_high_pass_filter_(false),
+      hpf_(kHpfConfig24k),
       pitch_buf_24kHz_(),
       pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
       lp_residual_(kBufSize24kHz),
@@ -35,7 +36,6 @@
       pitch_estimator_(cpu_features),
       reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
   RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
-  hpf_.Initialize(kHpfConfig24k);
   Reset();
 }
 
@@ -44,8 +44,9 @@
 void FeaturesExtractor::Reset() {
   pitch_buf_24kHz_.Reset();
   spectral_features_extractor_.Reset();
-  if (use_high_pass_filter_)
+  if (use_high_pass_filter_) {
     hpf_.Reset();
+  }
 }
 
 bool FeaturesExtractor::CheckSilenceComputeFeatures(