Use robust variance computation in RollingAccumulator.
The previous implementation was failing the NumericStabilityForVariance test.
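The old accumulator derived the variance from its running sums as
Var = E[x^2] - (E[x])^2. For samples near 1e9 (the range exercised by
the test), E[x^2] is about 1e18 while the true variance is about 1/12;
a double carries only ~16 significant decimal digits, so the
subtraction cancels nearly all of them and leaves rounding noise.
RunningStatistics instead accumulates squared deviations from the
running mean (Welford's algorithm), which avoids the cancellation.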
Bug: webrtc:10412
Change-Id: I97ba321743ebb8cc0923ae13a17c9997779fc305
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/133029
Commit-Queue: Yves Gerey <yvesg@google.com>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27931}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index fc3a3f3..98ff5db 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -912,6 +912,8 @@
"ssl_roots.h",
]
+ deps += [ ":rtc_numerics" ]
+
if (is_win) {
sources += [ "win32_socket_init.h" ]
if (current_os != "winuwp") {
diff --git a/rtc_base/numerics/running_statistics.h b/rtc_base/numerics/running_statistics.h
index d71323e..f3aa8e3 100644
--- a/rtc_base/numerics/running_statistics.h
+++ b/rtc_base/numerics/running_statistics.h
@@ -17,6 +17,7 @@
#include "absl/types/optional.h"
+#include "rtc_base/checks.h"
#include "rtc_base/numerics/math_utils.h"
namespace webrtc {
@@ -28,6 +29,10 @@
// min, max, mean, variance and standard deviation.
// If you need to get percentiles, please use webrtc::SamplesStatsCounter.
//
+// Please note that RemoveSample() doesn't affect min and max.
+// If you need a full-fledged moving window over the last N samples,
+// please use webrtc::RollingAccumulator.
+//
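+// A minimal sketch of the RemoveSample() semantics (mirrored by the
+// unit tests):
+//   RunningStatistics<int> stats;
+//   stats.AddSample(2);
+//   stats.AddSample(4);
+//   stats.RemoveSample(2);  // GetMean() is now 4.
+//   stats.GetMin();         // Still 2: min/max are never recomputed.
+//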
// The measures return absl::nullopt if no samples were fed (Size() == 0),
// otherwise the returned optional is guaranteed to contain a value.
//
@@ -54,6 +59,24 @@
cumul_ += delta * delta2;
}
+ // Remove a previously added value in O(1) time.
+ // NB: This doesn't affect min or max.
+ // Calling RemoveSample() when Size() == 0 is incorrect.
+ void RemoveSample(T sample) {
+ RTC_DCHECK_GT(Size(), 0);
+ // In production (DCHECKs disabled), just saturate at 0.
+ if (Size() == 0) {
+ return;
+ }
+ // Since sample order doesn't matter, this is the
+ // exact reciprocal of Welford's incremental update.
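+ // With n samples, AddSample computed m_n = m_{n-1} + (x - m_{n-1}) / n,
+ // which inverts to m_{n-1} = m_n - (x - m_n) / (n - 1). The two deltas
+ // swap roles between add and remove, so the same product delta * delta2
+ // (up to rounding) is subtracted from cumul_ below.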
+ --size_;
+ if (size_ == 0) {
+ // Removing the last sample: dividing by size_ == 0 below would
+ // poison mean_ and cumul_ with NaN/Inf, so restore the empty state.
+ mean_ = 0;
+ cumul_ = 0;
+ return;
+ }
+ const double delta = sample - mean_;
+ mean_ -= delta / size_;
+ const double delta2 = sample - mean_;
+ cumul_ -= delta * delta2;
+ }
+
// Merge other stats, as if samples were added one by one, but in O(1).
void MergeStatistics(const RunningStatistics<T>& other) {
if (other.size_ == 0) {
@@ -78,11 +101,12 @@
// Get Measures ////////////////////////////////////////////
- // Returns number of samples involved,
- // that is number of times AddSample() was called.
+ // Returns the number of samples added via AddSample() or MergeStatistics(),
+ // minus the number of times RemoveSample() was called.
int64_t Size() const { return size_; }
- // Returns min in O(1) time.
+ // Returns the minimum among all samples ever added, in O(1) time.
+ // It is not affected by RemoveSample().
absl::optional<T> GetMin() const {
if (size_ == 0) {
return absl::nullopt;
@@ -90,7 +114,8 @@
return min_;
}
- // Returns max in O(1) time.
+ // Returns the maximum among all samples ever added, in O(1) time.
+ // It is not affected by RemoveSample().
absl::optional<T> GetMax() const {
if (size_ == 0) {
return absl::nullopt;
diff --git a/rtc_base/numerics/running_statistics_unittest.cc b/rtc_base/numerics/running_statistics_unittest.cc
index 806b1e3..d77280a 100644
--- a/rtc_base/numerics/running_statistics_unittest.cc
+++ b/rtc_base/numerics/running_statistics_unittest.cc
@@ -56,7 +56,7 @@
} // namespace
-TEST(RunningStatisticsTest, FullSimpleTest) {
+TEST(RunningStatistics, FullSimpleTest) {
auto stats = CreateStatsFilledWithIntsFrom1ToN(100);
EXPECT_DOUBLE_EQ(*stats.GetMin(), 1.0);
@@ -76,7 +76,49 @@
EXPECT_DOUBLE_EQ(*stats.GetStandardDeviation(), sqrt(4.5));
}
-TEST(RunningStatisticsTest, VarianceFromUniformDistribution) {
+TEST(RunningStatistics, RemoveSample) {
+ // Check that adding and then removing a sample is a no-op,
+ // or nearly so (up to loss of precision).
+ RunningStatistics<int> stats;
+ stats.AddSample(2);
+ stats.AddSample(2);
+ stats.AddSample(-1);
+ stats.AddSample(5);
+
+ constexpr int iterations = 1e5;
+ for (int i = 0; i < iterations; ++i) {
+ stats.AddSample(i);
+ stats.RemoveSample(i);
+
+ EXPECT_NEAR(*stats.GetMean(), 2.0, 1e-8);
+ EXPECT_NEAR(*stats.GetVariance(), 4.5, 1e-3);
+ EXPECT_NEAR(*stats.GetStandardDeviation(), sqrt(4.5), 1e-4);
+ }
+}
+
+TEST(RunningStatistics, RemoveSamplesSequence) {
+ // Check that adding and then removing a sequence of samples is a no-op,
+ // or nearly so (up to loss of precision).
+ RunningStatistics<int> stats;
+ stats.AddSample(2);
+ stats.AddSample(2);
+ stats.AddSample(-1);
+ stats.AddSample(5);
+
+ constexpr int iterations = 1e4;
+ for (int i = 0; i < iterations; ++i) {
+ stats.AddSample(i);
+ }
+ for (int i = 0; i < iterations; ++i) {
+ stats.RemoveSample(i);
+ }
+
+ EXPECT_NEAR(*stats.GetMean(), 2.0, 1e-7);
+ EXPECT_NEAR(*stats.GetVariance(), 4.5, 1e-3);
+ EXPECT_NEAR(*stats.GetStandardDeviation(), sqrt(4.5), 1e-4);
+}
+
+TEST(RunningStatistics, VarianceFromUniformDistribution) {
// Check that the variance converges to 1/12 for a [0;1) uniform distribution.
// Acts as a sanity check for the NumericStabilityForVariance test.
auto stats = CreateStatsFromUniformDistribution(1e6, 0, 1);
@@ -84,7 +126,7 @@
EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
}
-TEST(RunningStatisticsTest, NumericStabilityForVariance) {
+TEST(RunningStatistics, NumericStabilityForVariance) {
// Same test as VarianceFromUniformDistribution,
// except the range is shifted to [1e9;1e9+1).
// Variance should also converge to 1/12.
@@ -96,6 +138,26 @@
EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
}
+TEST(RunningStatistics, MinRemainsUnchangedAfterRemove) {
+ // We don't want to recompute min (that's RollingAccumulator's role),
+ // so check that we still get the overall min.
+ RunningStatistics<int> stats;
+ stats.AddSample(1);
+ stats.AddSample(2);
+ stats.RemoveSample(1);
+ EXPECT_EQ(stats.GetMin(), 1);
+}
+
+TEST(RunningStatistics, MaxRemainsUnchangedAfterRemove) {
+ // We don't want to recompute max (that's RollingAccumulator's role),
+ // so check that we still get the overall max.
+ RunningStatistics<int> stats;
+ stats.AddSample(1);
+ stats.AddSample(2);
+ stats.RemoveSample(2);
+ EXPECT_EQ(stats.GetMax(), 2);
+}
+
TEST_P(RunningStatisticsTest, MergeStatistics) {
int data[SIZE_FOR_MERGE] = {2, 2, -1, 5, 10};
// Split the data in different partitions.
diff --git a/rtc_base/rolling_accumulator.h b/rtc_base/rolling_accumulator.h
index dacceff..b630554 100644
--- a/rtc_base/rolling_accumulator.h
+++ b/rtc_base/rolling_accumulator.h
@@ -17,6 +17,7 @@
#include "rtc_base/checks.h"
#include "rtc_base/constructor_magic.h"
+#include "rtc_base/numerics/running_statistics.h"
namespace rtc {
@@ -28,19 +29,18 @@
class RollingAccumulator {
public:
explicit RollingAccumulator(size_t max_count) : samples_(max_count) {
+ RTC_DCHECK_GT(max_count, 0);
Reset();
}
~RollingAccumulator() {}
size_t max_count() const { return samples_.size(); }
- size_t count() const { return count_; }
+ size_t count() const { return static_cast<size_t>(stats_.Size()); }
void Reset() {
- count_ = 0U;
+ stats_ = webrtc::RunningStatistics<T>();
next_index_ = 0U;
- sum_ = 0.0;
- sum_2_ = 0.0;
max_ = T();
max_stale_ = false;
min_ = T();
@@ -48,52 +48,40 @@
}
void AddSample(T sample) {
- if (count_ == max_count()) {
+ if (count() == max_count()) {
// Remove oldest sample.
T sample_to_remove = samples_[next_index_];
- sum_ -= sample_to_remove;
- sum_2_ -= static_cast<double>(sample_to_remove) * sample_to_remove;
+ stats_.RemoveSample(sample_to_remove);
if (sample_to_remove >= max_) {
max_stale_ = true;
}
if (sample_to_remove <= min_) {
min_stale_ = true;
}
- } else {
- // Increase count of samples.
- ++count_;
}
// Add new sample.
samples_[next_index_] = sample;
- sum_ += sample;
- sum_2_ += static_cast<double>(sample) * sample;
- if (count_ == 1 || sample >= max_) {
+ if (count() == 0 || sample >= max_) {
max_ = sample;
max_stale_ = false;
}
- if (count_ == 1 || sample <= min_) {
+ if (count() == 0 || sample <= min_) {
min_ = sample;
min_stale_ = false;
}
+ stats_.AddSample(sample);
// Update next_index_.
next_index_ = (next_index_ + 1) % max_count();
}
- T ComputeSum() const { return static_cast<T>(sum_); }
-
- double ComputeMean() const {
- if (count_ == 0) {
- return 0.0;
- }
- return sum_ / count_;
- }
+ double ComputeMean() const { return stats_.GetMean().value_or(0); }
T ComputeMax() const {
if (max_stale_) {
- RTC_DCHECK(count_ > 0)
- << "It shouldn't be possible for max_stale_ && count_ == 0";
+ RTC_DCHECK(count() > 0)
+ << "It shouldn't be possible for max_stale_ && count() == 0";
max_ = samples_[next_index_];
- for (size_t i = 1u; i < count_; i++) {
+ for (size_t i = 1u; i < count(); i++) {
max_ = std::max(max_, samples_[(next_index_ + i) % max_count()]);
}
max_stale_ = false;
@@ -103,10 +91,10 @@
T ComputeMin() const {
if (min_stale_) {
- RTC_DCHECK(count_ > 0)
- << "It shouldn't be possible for min_stale_ && count_ == 0";
+ RTC_DCHECK(count() > 0)
+ << "It shouldn't be possible for min_stale_ && count() == 0";
min_ = samples_[next_index_];
- for (size_t i = 1u; i < count_; i++) {
+ for (size_t i = 1u; i < count(); i++) {
min_ = std::min(min_, samples_[(next_index_ + i) % max_count()]);
}
min_stale_ = false;
@@ -118,14 +106,14 @@
// Weights the nth sample with weight (learning_rate)^n. learning_rate must
// lie in (0.0, 1.0); otherwise the non-weighted mean is returned.
double ComputeWeightedMean(double learning_rate) const {
- if (count_ < 1 || learning_rate <= 0.0 || learning_rate >= 1.0) {
+ if (count() < 1 || learning_rate <= 0.0 || learning_rate >= 1.0) {
return ComputeMean();
}
double weighted_mean = 0.0;
double current_weight = 1.0;
double weight_sum = 0.0;
const size_t max_size = max_count();
- for (size_t i = 0; i < count_; ++i) {
+ for (size_t i = 0; i < count(); ++i) {
current_weight *= learning_rate;
weight_sum += current_weight;
// Add max_size to prevent underflow.
@@ -137,22 +125,11 @@
// Compute estimated variance. Estimation is more accurate
// as the number of samples grows.
- double ComputeVariance() const {
- if (count_ == 0) {
- return 0.0;
- }
- // Var = E[x^2] - (E[x])^2
- double count_inv = 1.0 / count_;
- double mean_2 = sum_2_ * count_inv;
- double mean = sum_ * count_inv;
- return mean_2 - (mean * mean);
- }
+ double ComputeVariance() const { return stats_.GetVariance().value_or(0); }
private:
- size_t count_;
+ webrtc::RunningStatistics<T> stats_;
size_t next_index_;
- double sum_; // Sum(x) - double to avoid overflow
- double sum_2_; // Sum(x*x) - double to avoid overflow
mutable T max_;
mutable bool max_stale_;
mutable T min_;
diff --git a/rtc_base/rolling_accumulator_unittest.cc b/rtc_base/rolling_accumulator_unittest.cc
index 7d5e70d..f6835aa 100644
--- a/rtc_base/rolling_accumulator_unittest.cc
+++ b/rtc_base/rolling_accumulator_unittest.cc
@@ -8,6 +8,8 @@
* be found in the AUTHORS file in the root of the source tree.
*/
+#include <random>
+
#include "rtc_base/rolling_accumulator.h"
#include "test/gtest.h"
@@ -18,6 +20,18 @@
const double kLearningRate = 0.5;
+// Add |n| samples drawn from a uniform distribution over [a;b).
+void FillStatsFromUniformDistribution(RollingAccumulator<double>& stats,
+ int n,
+ double a,
+ double b) {
+ std::mt19937 gen{std::random_device()()};
+ std::uniform_real_distribution<> dis(a, b);
+
+ for (int i = 1; i <= n; i++) {
+ stats.AddSample(dis(gen));
+ }
+}
} // namespace
TEST(RollingAccumulatorTest, ZeroSamples) {
@@ -37,7 +51,6 @@
}
EXPECT_EQ(4U, accum.count());
- EXPECT_EQ(6, accum.ComputeSum());
EXPECT_DOUBLE_EQ(1.5, accum.ComputeMean());
EXPECT_NEAR(2.26666, accum.ComputeWeightedMean(kLearningRate), 0.01);
EXPECT_DOUBLE_EQ(1.25, accum.ComputeVariance());
@@ -52,7 +65,6 @@
}
EXPECT_EQ(10U, accum.count());
- EXPECT_EQ(65, accum.ComputeSum());
EXPECT_DOUBLE_EQ(6.5, accum.ComputeMean());
EXPECT_NEAR(10.0, accum.ComputeWeightedMean(kLearningRate), 0.01);
EXPECT_NEAR(9.0, accum.ComputeVariance(), 1.0);
@@ -79,7 +91,6 @@
}
EXPECT_EQ(5U, accum.count());
- EXPECT_EQ(10, accum.ComputeSum());
EXPECT_DOUBLE_EQ(2.0, accum.ComputeMean());
EXPECT_EQ(0, accum.ComputeMin());
EXPECT_EQ(4, accum.ComputeMax());
@@ -92,7 +103,6 @@
}
EXPECT_EQ(10u, accum.count());
- EXPECT_DOUBLE_EQ(875.0, accum.ComputeSum());
EXPECT_DOUBLE_EQ(87.5, accum.ComputeMean());
EXPECT_NEAR(105.049, accum.ComputeWeightedMean(kLearningRate), 0.1);
EXPECT_NEAR(229.166667, accum.ComputeVariance(), 25);
@@ -116,4 +126,25 @@
EXPECT_NEAR(6.0, accum.ComputeWeightedMean(kLearningRate), 0.1);
}
+TEST(RollingAccumulatorTest, VarianceFromUniformDistribution) {
+ // Check that the variance converges to 1/12 for a [0;1) uniform distribution.
+ // Acts as a sanity check for the NumericStabilityForVariance test.
+ RollingAccumulator<double> stats(/*max_count=*/0.5e6);
+ FillStatsFromUniformDistribution(stats, 1e6, 0, 1);
+
+ EXPECT_NEAR(stats.ComputeVariance(), 1. / 12, 1e-3);
+}
+
+TEST(RollingAccumulatorTest, NumericStabilityForVariance) {
+ // Same test as VarianceFromUniformDistribution,
+ // except the range is shifted to [1e9;1e9+1).
+ // Variance should also converge to 1/12.
+ // NB: Although we lose precision for the samples themselves, the fractional
+ // part still enjoys 22 bits of mantissa and errors should even out, so the
+ // lost precision alone couldn't explain a mismatch.
+ RollingAccumulator<double> stats(/*max_count=*/0.5e6);
+ FillStatsFromUniformDistribution(stats, 1e6, 1e9, 1e9 + 1);
+
+ EXPECT_NEAR(stats.ComputeVariance(), 1. / 12, 1e-3);
+}
} // namespace rtc