Use robust variance computation in RollingAccumulator.

Previous one was failing NumericStabilityForVariance test.

Bug: webrtc:10412
Change-Id: I97ba321743ebb8cc0923ae13a17c9997779fc305
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/133029
Commit-Queue: Yves Gerey <yvesg@google.com>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27931}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index fc3a3f3..98ff5db 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -912,6 +912,8 @@
       "ssl_roots.h",
     ]
 
+    deps += [ ":rtc_numerics" ]
+
     if (is_win) {
       sources += [ "win32_socket_init.h" ]
       if (current_os != "winuwp") {
diff --git a/rtc_base/numerics/running_statistics.h b/rtc_base/numerics/running_statistics.h
index d71323e..f3aa8e3 100644
--- a/rtc_base/numerics/running_statistics.h
+++ b/rtc_base/numerics/running_statistics.h
@@ -17,6 +17,7 @@
 
 #include "absl/types/optional.h"
 
+#include "rtc_base/checks.h"
 #include "rtc_base/numerics/math_utils.h"
 
 namespace webrtc {
@@ -28,6 +29,10 @@
 // min, max, mean, variance and standard deviation.
 // If you need to get percentiles, please use webrtc::SamplesStatsCounter.
 //
+// Please note RemoveSample() won't affect min and max.
+// If you want a full-fledged moving window over N last samples,
+// please use webrtc::RollingAccumulator.
+//
 // The measures return absl::nullopt if no samples were fed (Size() == 0),
 // otherwise the returned optional is guaranteed to contain a value.
 //
@@ -54,6 +59,24 @@
     cumul_ += delta * delta2;
   }
 
+  // Remove a previously added value in O(1) time.
+  // Nb: This doesn't affect min or max.
+  // Calling RemoveSample when Size()==0 is incorrect.
+  void RemoveSample(T sample) {
+    RTC_DCHECK_GT(Size(), 0);
+    // In production, just saturate at 0.
+    if (Size() == 0) {
+      return;
+    }
+    // Since samples order doesn't matter, this is the
+    // exact reciprocal of Welford's incremental update.
+    --size_;
+    const double delta = sample - mean_;
+    mean_ -= delta / size_;
+    const double delta2 = sample - mean_;
+    cumul_ -= delta * delta2;
+  }
+
   // Merge other stats, as if samples were added one by one, but in O(1).
   void MergeStatistics(const RunningStatistics<T>& other) {
     if (other.size_ == 0) {
@@ -78,11 +101,12 @@
 
   // Get Measures ////////////////////////////////////////////
 
-  // Returns number of samples involved,
-  // that is number of times AddSample() was called.
+  // Returns number of samples involved via AddSample() or MergeStatistics(),
+  // minus number of times RemoveSample() was called.
   int64_t Size() const { return size_; }
 
-  // Returns min in O(1) time.
+  // Returns minimum among all seen samples, in O(1) time.
+  // This isn't affected by RemoveSample().
   absl::optional<T> GetMin() const {
     if (size_ == 0) {
       return absl::nullopt;
@@ -90,7 +114,8 @@
     return min_;
   }
 
-  // Returns max in O(1) time.
+  // Returns maximum among all seen samples, in O(1) time.
+  // This isn't affected by RemoveSample().
   absl::optional<T> GetMax() const {
     if (size_ == 0) {
       return absl::nullopt;
diff --git a/rtc_base/numerics/running_statistics_unittest.cc b/rtc_base/numerics/running_statistics_unittest.cc
index 806b1e3..d77280a 100644
--- a/rtc_base/numerics/running_statistics_unittest.cc
+++ b/rtc_base/numerics/running_statistics_unittest.cc
@@ -56,7 +56,7 @@
 
 }  // namespace
 
-TEST(RunningStatisticsTest, FullSimpleTest) {
+TEST(RunningStatistics, FullSimpleTest) {
   auto stats = CreateStatsFilledWithIntsFrom1ToN(100);
 
   EXPECT_DOUBLE_EQ(*stats.GetMin(), 1.0);
@@ -76,7 +76,49 @@
   EXPECT_DOUBLE_EQ(*stats.GetStandardDeviation(), sqrt(4.5));
 }
 
-TEST(RunningStatisticsTest, VarianceFromUniformDistribution) {
+TEST(RunningStatistics, RemoveSample) {
+  // We check that adding then removing sample is no-op,
+  // or so (due to loss of precision).
+  RunningStatistics<int> stats;
+  stats.AddSample(2);
+  stats.AddSample(2);
+  stats.AddSample(-1);
+  stats.AddSample(5);
+
+  constexpr int iterations = 1e5;
+  for (int i = 0; i < iterations; ++i) {
+    stats.AddSample(i);
+    stats.RemoveSample(i);
+
+    EXPECT_NEAR(*stats.GetMean(), 2.0, 1e-8);
+    EXPECT_NEAR(*stats.GetVariance(), 4.5, 1e-3);
+    EXPECT_NEAR(*stats.GetStandardDeviation(), sqrt(4.5), 1e-4);
+  }
+}
+
+TEST(RunningStatistics, RemoveSamplesSequence) {
+  // We check that adding then removing a sequence of samples is no-op,
+  // or so (due to loss of precision).
+  RunningStatistics<int> stats;
+  stats.AddSample(2);
+  stats.AddSample(2);
+  stats.AddSample(-1);
+  stats.AddSample(5);
+
+  constexpr int iterations = 1e4;
+  for (int i = 0; i < iterations; ++i) {
+    stats.AddSample(i);
+  }
+  for (int i = 0; i < iterations; ++i) {
+    stats.RemoveSample(i);
+  }
+
+  EXPECT_NEAR(*stats.GetMean(), 2.0, 1e-7);
+  EXPECT_NEAR(*stats.GetVariance(), 4.5, 1e-3);
+  EXPECT_NEAR(*stats.GetStandardDeviation(), sqrt(4.5), 1e-4);
+}
+
+TEST(RunningStatistics, VarianceFromUniformDistribution) {
   // Check variance converge to 1/12 for [0;1) uniform distribution.
   // Acts as a sanity check for NumericStabilityForVariance test.
   auto stats = CreateStatsFromUniformDistribution(1e6, 0, 1);
@@ -84,7 +126,7 @@
   EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
 }
 
-TEST(RunningStatisticsTest, NumericStabilityForVariance) {
+TEST(RunningStatistics, NumericStabilityForVariance) {
   // Same test as VarianceFromUniformDistribution,
   // except the range is shifted to [1e9;1e9+1).
   // Variance should also converge to 1/12.
@@ -96,6 +138,26 @@
   EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
 }
 
+TEST(RunningStatistics, MinRemainsUnchangedAfterRemove) {
+  // We don't want to recompute min (that's RollingAccumulator's role),
+  // check we get the overall min.
+  RunningStatistics<int> stats;
+  stats.AddSample(1);
+  stats.AddSample(2);
+  stats.RemoveSample(1);
+  EXPECT_EQ(stats.GetMin(), 1);
+}
+
+TEST(RunningStatistics, MaxRemainsUnchangedAfterRemove) {
+  // We don't want to recompute max (that's RollingAccumulator's role),
+  // check we get the overall max.
+  RunningStatistics<int> stats;
+  stats.AddSample(1);
+  stats.AddSample(2);
+  stats.RemoveSample(2);
+  EXPECT_EQ(stats.GetMax(), 2);
+}
+
 TEST_P(RunningStatisticsTest, MergeStatistics) {
   int data[SIZE_FOR_MERGE] = {2, 2, -1, 5, 10};
   // Split the data in different partitions.
diff --git a/rtc_base/rolling_accumulator.h b/rtc_base/rolling_accumulator.h
index dacceff..b630554 100644
--- a/rtc_base/rolling_accumulator.h
+++ b/rtc_base/rolling_accumulator.h
@@ -17,6 +17,7 @@
 
 #include "rtc_base/checks.h"
 #include "rtc_base/constructor_magic.h"
+#include "rtc_base/numerics/running_statistics.h"
 
 namespace rtc {
 
@@ -28,19 +29,18 @@
 class RollingAccumulator {
  public:
   explicit RollingAccumulator(size_t max_count) : samples_(max_count) {
+    RTC_DCHECK(max_count > 0);
     Reset();
   }
   ~RollingAccumulator() {}
 
   size_t max_count() const { return samples_.size(); }
 
-  size_t count() const { return count_; }
+  size_t count() const { return static_cast<size_t>(stats_.Size()); }
 
   void Reset() {
-    count_ = 0U;
+    stats_ = webrtc::RunningStatistics<T>();
     next_index_ = 0U;
-    sum_ = 0.0;
-    sum_2_ = 0.0;
     max_ = T();
     max_stale_ = false;
     min_ = T();
@@ -48,52 +48,40 @@
   }
 
   void AddSample(T sample) {
-    if (count_ == max_count()) {
+    if (count() == max_count()) {
       // Remove oldest sample.
       T sample_to_remove = samples_[next_index_];
-      sum_ -= sample_to_remove;
-      sum_2_ -= static_cast<double>(sample_to_remove) * sample_to_remove;
+      stats_.RemoveSample(sample_to_remove);
       if (sample_to_remove >= max_) {
         max_stale_ = true;
       }
       if (sample_to_remove <= min_) {
         min_stale_ = true;
       }
-    } else {
-      // Increase count of samples.
-      ++count_;
     }
     // Add new sample.
     samples_[next_index_] = sample;
-    sum_ += sample;
-    sum_2_ += static_cast<double>(sample) * sample;
-    if (count_ == 1 || sample >= max_) {
+    if (count() == 0 || sample >= max_) {
       max_ = sample;
       max_stale_ = false;
     }
-    if (count_ == 1 || sample <= min_) {
+    if (count() == 0 || sample <= min_) {
       min_ = sample;
       min_stale_ = false;
     }
+    stats_.AddSample(sample);
     // Update next_index_.
     next_index_ = (next_index_ + 1) % max_count();
   }
 
-  T ComputeSum() const { return static_cast<T>(sum_); }
-
-  double ComputeMean() const {
-    if (count_ == 0) {
-      return 0.0;
-    }
-    return sum_ / count_;
-  }
+  double ComputeMean() const { return stats_.GetMean().value_or(0); }
 
   T ComputeMax() const {
     if (max_stale_) {
-      RTC_DCHECK(count_ > 0)
-          << "It shouldn't be possible for max_stale_ && count_ == 0";
+      RTC_DCHECK(count() > 0)
+          << "It shouldn't be possible for max_stale_ && count() == 0";
       max_ = samples_[next_index_];
-      for (size_t i = 1u; i < count_; i++) {
+      for (size_t i = 1u; i < count(); i++) {
         max_ = std::max(max_, samples_[(next_index_ + i) % max_count()]);
       }
       max_stale_ = false;
@@ -103,10 +91,10 @@
 
   T ComputeMin() const {
     if (min_stale_) {
-      RTC_DCHECK(count_ > 0)
-          << "It shouldn't be possible for min_stale_ && count_ == 0";
+      RTC_DCHECK(count() > 0)
+          << "It shouldn't be possible for min_stale_ && count() == 0";
       min_ = samples_[next_index_];
-      for (size_t i = 1u; i < count_; i++) {
+      for (size_t i = 1u; i < count(); i++) {
         min_ = std::min(min_, samples_[(next_index_ + i) % max_count()]);
       }
       min_stale_ = false;
@@ -118,14 +106,14 @@
   // Weights nth sample with weight (learning_rate)^n. Learning_rate should be
   // between (0.0, 1.0], otherwise the non-weighted mean is returned.
   double ComputeWeightedMean(double learning_rate) const {
-    if (count_ < 1 || learning_rate <= 0.0 || learning_rate >= 1.0) {
+    if (count() < 1 || learning_rate <= 0.0 || learning_rate >= 1.0) {
       return ComputeMean();
     }
     double weighted_mean = 0.0;
     double current_weight = 1.0;
     double weight_sum = 0.0;
     const size_t max_size = max_count();
-    for (size_t i = 0; i < count_; ++i) {
+    for (size_t i = 0; i < count(); ++i) {
       current_weight *= learning_rate;
       weight_sum += current_weight;
       // Add max_size to prevent underflow.
@@ -137,22 +125,11 @@
 
   // Compute estimated variance.  Estimation is more accurate
   // as the number of samples grows.
-  double ComputeVariance() const {
-    if (count_ == 0) {
-      return 0.0;
-    }
-    // Var = E[x^2] - (E[x])^2
-    double count_inv = 1.0 / count_;
-    double mean_2 = sum_2_ * count_inv;
-    double mean = sum_ * count_inv;
-    return mean_2 - (mean * mean);
-  }
+  double ComputeVariance() const { return stats_.GetVariance().value_or(0); }
 
  private:
-  size_t count_;
+  webrtc::RunningStatistics<T> stats_;
   size_t next_index_;
-  double sum_;    // Sum(x) - double to avoid overflow
-  double sum_2_;  // Sum(x*x) - double to avoid overflow
   mutable T max_;
   mutable bool max_stale_;
   mutable T min_;
diff --git a/rtc_base/rolling_accumulator_unittest.cc b/rtc_base/rolling_accumulator_unittest.cc
index 7d5e70d..f6835aa 100644
--- a/rtc_base/rolling_accumulator_unittest.cc
+++ b/rtc_base/rolling_accumulator_unittest.cc
@@ -8,6 +8,8 @@
  *  be found in the AUTHORS file in the root of the source tree.
  */
 
+#include <random>
+
 #include "rtc_base/rolling_accumulator.h"
 
 #include "test/gtest.h"
@@ -18,6 +20,18 @@
 
 const double kLearningRate = 0.5;
 
+// Add |n| samples drawn from uniform distribution in [a;b].
+void FillStatsFromUniformDistribution(RollingAccumulator<double>& stats,
+                                      int n,
+                                      double a,
+                                      double b) {
+  std::mt19937 gen{std::random_device()()};
+  std::uniform_real_distribution<> dis(a, b);
+
+  for (int i = 1; i <= n; i++) {
+    stats.AddSample(dis(gen));
+  }
+}
 }  // namespace
 
 TEST(RollingAccumulatorTest, ZeroSamples) {
@@ -37,7 +51,6 @@
   }
 
   EXPECT_EQ(4U, accum.count());
-  EXPECT_EQ(6, accum.ComputeSum());
   EXPECT_DOUBLE_EQ(1.5, accum.ComputeMean());
   EXPECT_NEAR(2.26666, accum.ComputeWeightedMean(kLearningRate), 0.01);
   EXPECT_DOUBLE_EQ(1.25, accum.ComputeVariance());
@@ -52,7 +65,6 @@
   }
 
   EXPECT_EQ(10U, accum.count());
-  EXPECT_EQ(65, accum.ComputeSum());
   EXPECT_DOUBLE_EQ(6.5, accum.ComputeMean());
   EXPECT_NEAR(10.0, accum.ComputeWeightedMean(kLearningRate), 0.01);
   EXPECT_NEAR(9.0, accum.ComputeVariance(), 1.0);
@@ -79,7 +91,6 @@
   }
 
   EXPECT_EQ(5U, accum.count());
-  EXPECT_EQ(10, accum.ComputeSum());
   EXPECT_DOUBLE_EQ(2.0, accum.ComputeMean());
   EXPECT_EQ(0, accum.ComputeMin());
   EXPECT_EQ(4, accum.ComputeMax());
@@ -92,7 +103,6 @@
   }
 
   EXPECT_EQ(10u, accum.count());
-  EXPECT_DOUBLE_EQ(875.0, accum.ComputeSum());
   EXPECT_DOUBLE_EQ(87.5, accum.ComputeMean());
   EXPECT_NEAR(105.049, accum.ComputeWeightedMean(kLearningRate), 0.1);
   EXPECT_NEAR(229.166667, accum.ComputeVariance(), 25);
@@ -116,4 +126,25 @@
   EXPECT_NEAR(6.0, accum.ComputeWeightedMean(kLearningRate), 0.1);
 }
 
+TEST(RollingAccumulatorTest, VarianceFromUniformDistribution) {
+  // Check variance converge to 1/12 for [0;1) uniform distribution.
+  // Acts as a sanity check for NumericStabilityForVariance test.
+  RollingAccumulator<double> stats(/*max_count=*/0.5e6);
+  FillStatsFromUniformDistribution(stats, 1e6, 0, 1);
+
+  EXPECT_NEAR(stats.ComputeVariance(), 1. / 12, 1e-3);
+}
+
+TEST(RollingAccumulatorTest, NumericStabilityForVariance) {
+  // Same test as VarianceFromUniformDistribution,
+  // except the range is shifted to [1e9;1e9+1).
+  // Variance should also converge to 1/12.
+  // NB: Although we lose precision for the samples themselves, the fractional
+  //     part still enjoys 22 bits of mantissa and errors should even out,
+  //     so that couldn't explain a mismatch.
+  RollingAccumulator<double> stats(/*max_count=*/0.5e6);
+  FillStatsFromUniformDistribution(stats, 1e6, 1e9, 1e9 + 1);
+
+  EXPECT_NEAR(stats.ComputeVariance(), 1. / 12, 1e-3);
+}
 }  // namespace rtc