Provide robust and efficient variance computation for online statistics.
This CL implements Welford's algorithm for a
numerically stable computation of the variance.
This implementation is plugged in SamplesStatsCounter class (adapter pattern).
A 'NumericalStability' unit test has been added,
whose previous implementation of SamplesStatsCounter failed to pass.
Follow-up CLs will factorize more occurences of duplicated and misbehaved
computations.
Bug: webrtc:10412
Change-Id: Id807c3d34e9c780fb1cbd769d30b655c575c88ac
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/131394
Commit-Queue: Yves Gerey <yvesg@google.com>
Reviewed-by: Artem Titov <titovartem@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27547}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index 2e7017d..6a2c5aa 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -588,10 +588,12 @@
sources = [
"numerics/exp_filter.cc",
"numerics/exp_filter.h",
+ "numerics/math_utils.h",
"numerics/moving_average.cc",
"numerics/moving_average.h",
"numerics/moving_median_filter.h",
"numerics/percentile_filter.h",
+ "numerics/running_statistics.h",
"numerics/samples_stats_counter.cc",
"numerics/samples_stats_counter.h",
"numerics/sequence_number_util.h",
@@ -1297,6 +1299,7 @@
"numerics/moving_average_unittest.cc",
"numerics/moving_median_filter_unittest.cc",
"numerics/percentile_filter_unittest.cc",
+ "numerics/running_statistics_unittest.cc",
"numerics/samples_stats_counter_unittest.cc",
"numerics/sequence_number_util_unittest.cc",
]
diff --git a/rtc_base/numerics/math_utils.h b/rtc_base/numerics/math_utils.h
index 8a91958..d5f3ee4 100644
--- a/rtc_base/numerics/math_utils.h
+++ b/rtc_base/numerics/math_utils.h
@@ -36,4 +36,39 @@
return static_cast<unsigned_type>(x) - static_cast<unsigned_type>(y);
}
+// Provide neutral element with respect to min().
+// Typically used as an initial value for running minimum.
+template <typename T,
+ typename std::enable_if<std::numeric_limits<T>::has_infinity>::type* =
+ nullptr>
+constexpr T infinity_or_max() {
+ return std::numeric_limits<T>::infinity();
+}
+
+template <typename T,
+ typename std::enable_if<
+ !std::numeric_limits<T>::has_infinity>::type* = nullptr>
+constexpr T infinity_or_max() {
+ // Fallback to max().
+ return std::numeric_limits<T>::max();
+}
+
+// Provide neutral element with respect to max().
+// Typically used as an initial value for running maximum.
+template <typename T,
+ typename std::enable_if<std::numeric_limits<T>::has_infinity>::type* =
+ nullptr>
+constexpr T minus_infinity_or_min() {
+ static_assert(std::is_signed<T>::value, "Unsupported. Please open a bug.");
+ return -std::numeric_limits<T>::infinity();
+}
+
+template <typename T,
+ typename std::enable_if<
+ !std::numeric_limits<T>::has_infinity>::type* = nullptr>
+constexpr T minus_infinity_or_min() {
+ // Fallback to min().
+ return std::numeric_limits<T>::min();
+}
+
#endif // RTC_BASE_NUMERICS_MATH_UTILS_H_
diff --git a/rtc_base/numerics/running_statistics.h b/rtc_base/numerics/running_statistics.h
new file mode 100644
index 0000000..d71323e
--- /dev/null
+++ b/rtc_base/numerics/running_statistics.h
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
+#define RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "absl/types/optional.h"
+
+#include "rtc_base/numerics/math_utils.h"
+
+namespace webrtc {
+
+// tl;dr: Robust and efficient online computation of statistics,
+// using Welford's method for variance. [1]
+//
+// This should be your go-to class if you ever need to compute
+// min, max, mean, variance and standard deviation.
+// If you need to get percentiles, please use webrtc::SamplesStatsCounter.
+//
+// The measures return absl::nullopt if no samples were fed (Size() == 0),
+// otherwise the returned optional is guaranteed to contain a value.
+//
+// [1]
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+
+// The type T is a scalar which must be convertible to double.
+// Rationale: we often need greater precision for measures
+// than for the samples themselves.
+template <typename T>
+class RunningStatistics {
+ public:
+ // Update stats ////////////////////////////////////////////
+
+ // Add a value participating in the statistics in O(1) time.
+ void AddSample(T sample) {
+ max_ = std::max(max_, sample);
+ min_ = std::min(min_, sample);
+ ++size_;
+ // Welford's incremental update.
+ const double delta = sample - mean_;
+ mean_ += delta / size_;
+ const double delta2 = sample - mean_;
+ cumul_ += delta * delta2;
+ }
+
+ // Merge other stats, as if samples were added one by one, but in O(1).
+ void MergeStatistics(const RunningStatistics<T>& other) {
+ if (other.size_ == 0) {
+ return;
+ }
+ max_ = std::max(max_, other.max_);
+ min_ = std::min(min_, other.min_);
+ const int64_t new_size = size_ + other.size_;
+ const double new_mean =
+ (mean_ * size_ + other.mean_ * other.size_) / new_size;
+ // Each cumulant must be corrected.
+ // * from: sum((x_i - mean_)²)
+ // * to: sum((x_i - new_mean)²)
+ auto delta = [new_mean](const RunningStatistics<T>& stats) {
+ return stats.size_ * (new_mean * (new_mean - 2 * stats.mean_) +
+ stats.mean_ * stats.mean_);
+ };
+ cumul_ = cumul_ + delta(*this) + other.cumul_ + delta(other);
+ mean_ = new_mean;
+ size_ = new_size;
+ }
+
+ // Get Measures ////////////////////////////////////////////
+
+ // Returns number of samples involved,
+ // that is number of times AddSample() was called.
+ int64_t Size() const { return size_; }
+
+ // Returns min in O(1) time.
+ absl::optional<T> GetMin() const {
+ if (size_ == 0) {
+ return absl::nullopt;
+ }
+ return min_;
+ }
+
+ // Returns max in O(1) time.
+ absl::optional<T> GetMax() const {
+ if (size_ == 0) {
+ return absl::nullopt;
+ }
+ return max_;
+ }
+
+ // Returns mean in O(1) time.
+ absl::optional<double> GetMean() const {
+ if (size_ == 0) {
+ return absl::nullopt;
+ }
+ return mean_;
+ }
+
+ // Returns unbiased sample variance in O(1) time.
+ absl::optional<double> GetVariance() const {
+ if (size_ == 0) {
+ return absl::nullopt;
+ }
+ return cumul_ / size_;
+ }
+
+ // Returns unbiased standard deviation in O(1) time.
+ absl::optional<double> GetStandardDeviation() const {
+ if (size_ == 0) {
+ return absl::nullopt;
+ }
+ return std::sqrt(*GetVariance());
+ }
+
+ private:
+ int64_t size_ = 0; // Samples seen.
+ T min_ = infinity_or_max<T>();
+ T max_ = minus_infinity_or_min<T>();
+ double mean_ = 0;
+ double cumul_ = 0; // Variance * size_, sometimes noted m2.
+};
+
+} // namespace webrtc
+
+#endif // RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
diff --git a/rtc_base/numerics/running_statistics_unittest.cc b/rtc_base/numerics/running_statistics_unittest.cc
new file mode 100644
index 0000000..806b1e3
--- /dev/null
+++ b/rtc_base/numerics/running_statistics_unittest.cc
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "rtc_base/numerics/running_statistics.h"
+
+#include <math.h>
+#include <random>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "test/gtest.h"
+
+// Tests were copied from samples_stats_counter_unittest.cc.
+
+namespace webrtc {
+namespace {
+
+RunningStatistics<double> CreateStatsFilledWithIntsFrom1ToN(int n) {
+ std::vector<double> data;
+ for (int i = 1; i <= n; i++) {
+ data.push_back(i);
+ }
+ absl::c_shuffle(data, std::mt19937(std::random_device()()));
+
+ RunningStatistics<double> stats;
+ for (double v : data) {
+ stats.AddSample(v);
+ }
+ return stats;
+}
+
+// Add n samples drawn from uniform distribution in [a;b].
+RunningStatistics<double> CreateStatsFromUniformDistribution(int n,
+ double a,
+ double b) {
+ std::mt19937 gen{std::random_device()()};
+ std::uniform_real_distribution<> dis(a, b);
+
+ RunningStatistics<double> stats;
+ for (int i = 1; i <= n; i++) {
+ stats.AddSample(dis(gen));
+ }
+ return stats;
+}
+
+class RunningStatisticsTest : public ::testing::TestWithParam<int> {};
+
+constexpr int SIZE_FOR_MERGE = 5;
+
+} // namespace
+
+TEST(RunningStatisticsTest, FullSimpleTest) {
+ auto stats = CreateStatsFilledWithIntsFrom1ToN(100);
+
+ EXPECT_DOUBLE_EQ(*stats.GetMin(), 1.0);
+ EXPECT_DOUBLE_EQ(*stats.GetMax(), 100.0);
+ EXPECT_DOUBLE_EQ(*stats.GetMean(), 50.5);
+}
+
+TEST(RunningStatistics, VarianceAndDeviation) {
+ RunningStatistics<int> stats;
+ stats.AddSample(2);
+ stats.AddSample(2);
+ stats.AddSample(-1);
+ stats.AddSample(5);
+
+ EXPECT_DOUBLE_EQ(*stats.GetMean(), 2.0);
+ EXPECT_DOUBLE_EQ(*stats.GetVariance(), 4.5);
+ EXPECT_DOUBLE_EQ(*stats.GetStandardDeviation(), sqrt(4.5));
+}
+
+TEST(RunningStatisticsTest, VarianceFromUniformDistribution) {
+ // Check variance converge to 1/12 for [0;1) uniform distribution.
+ // Acts as a sanity check for NumericStabilityForVariance test.
+ auto stats = CreateStatsFromUniformDistribution(1e6, 0, 1);
+
+ EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
+}
+
+TEST(RunningStatisticsTest, NumericStabilityForVariance) {
+ // Same test as VarianceFromUniformDistribution,
+ // except the range is shifted to [1e9;1e9+1).
+ // Variance should also converge to 1/12.
+ // NB: Although we lose precision for the samples themselves, the fractional
+ // part still enjoys 22 bits of mantissa and errors should even out,
+ // so that couldn't explain a mismatch.
+ auto stats = CreateStatsFromUniformDistribution(1e6, 1e9, 1e9 + 1);
+
+ EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
+}
+
+TEST_P(RunningStatisticsTest, MergeStatistics) {
+ int data[SIZE_FOR_MERGE] = {2, 2, -1, 5, 10};
+ // Split the data in different partitions.
+ // We have 6 distinct tests:
+ // * Empty merged with full sequence.
+ // * 1 sample merged with 4 last.
+ // * 2 samples merged with 3 last.
+ // [...]
+ // * Full merged with empty sequence.
+ // All must lead to the same result.
+ // I miss QuickCheck so much.
+ RunningStatistics<int> stats0, stats1;
+ for (int i = 0; i < GetParam(); ++i) {
+ stats0.AddSample(data[i]);
+ }
+ for (int i = GetParam(); i < SIZE_FOR_MERGE; ++i) {
+ stats1.AddSample(data[i]);
+ }
+ stats0.MergeStatistics(stats1);
+
+ EXPECT_EQ(stats0.Size(), SIZE_FOR_MERGE);
+ EXPECT_DOUBLE_EQ(*stats0.GetMin(), -1);
+ EXPECT_DOUBLE_EQ(*stats0.GetMax(), 10);
+ EXPECT_DOUBLE_EQ(*stats0.GetMean(), 3.6);
+ EXPECT_DOUBLE_EQ(*stats0.GetVariance(), 13.84);
+ EXPECT_DOUBLE_EQ(*stats0.GetStandardDeviation(), sqrt(13.84));
+}
+
+INSTANTIATE_TEST_SUITE_P(RunningStatisticsTests,
+ RunningStatisticsTest,
+ ::testing::Range(0, SIZE_FOR_MERGE + 1));
+
+} // namespace webrtc
diff --git a/rtc_base/numerics/samples_stats_counter.cc b/rtc_base/numerics/samples_stats_counter.cc
index 134a65d..655f4c1 100644
--- a/rtc_base/numerics/samples_stats_counter.cc
+++ b/rtc_base/numerics/samples_stats_counter.cc
@@ -26,26 +26,15 @@
default;
void SamplesStatsCounter::AddSample(double value) {
+ stats_.AddSample(value);
samples_.push_back(value);
sorted_ = false;
- if (value > max_) {
- max_ = value;
- }
- if (value < min_) {
- min_ = value;
- }
- sum_ += value;
- sum_squared_ += value * value;
}
void SamplesStatsCounter::AddSamples(const SamplesStatsCounter& other) {
- for (double sample : other.samples_)
- samples_.push_back(sample);
+ stats_.MergeStatistics(other.stats_);
+ samples_.insert(samples_.end(), other.samples_.begin(), other.samples_.end());
sorted_ = false;
- max_ = std::max(max_, other.max_);
- min_ = std::min(min_, other.min_);
- sum_ += other.sum_;
- sum_squared_ += other.sum_squared_;
}
double SamplesStatsCounter::GetPercentile(double percentile) {
diff --git a/rtc_base/numerics/samples_stats_counter.h b/rtc_base/numerics/samples_stats_counter.h
index 05a8c14..ac5f12c 100644
--- a/rtc_base/numerics/samples_stats_counter.h
+++ b/rtc_base/numerics/samples_stats_counter.h
@@ -11,14 +11,15 @@
#ifndef RTC_BASE_NUMERICS_SAMPLES_STATS_COUNTER_H_
#define RTC_BASE_NUMERICS_SAMPLES_STATS_COUNTER_H_
-#include <math.h>
-#include <limits>
#include <vector>
#include "rtc_base/checks.h"
+#include "rtc_base/numerics/running_statistics.h"
namespace webrtc {
+// This class extends RunningStatistics by providing GetPercentile() method,
+// while slightly adapting the interface.
class SamplesStatsCounter {
public:
SamplesStatsCounter();
@@ -41,31 +42,31 @@
// samples.
double GetMin() const {
RTC_DCHECK(!IsEmpty());
- return min_;
+ return *stats_.GetMin();
}
// Returns max in O(1) time. This function may not be called if there are no
// samples.
double GetMax() const {
RTC_DCHECK(!IsEmpty());
- return max_;
+ return *stats_.GetMax();
}
// Returns average in O(1) time. This function may not be called if there are
// no samples.
double GetAverage() const {
RTC_DCHECK(!IsEmpty());
- return sum_ / samples_.size();
+ return *stats_.GetMean();
}
// Returns variance in O(1) time. This function may not be called if there are
// no samples.
double GetVariance() const {
RTC_DCHECK(!IsEmpty());
- return sum_squared_ / samples_.size() - GetAverage() * GetAverage();
+ return *stats_.GetVariance();
}
// Returns standard deviation in O(1) time. This function may not be called if
// there are no samples.
double GetStandardDeviation() const {
RTC_DCHECK(!IsEmpty());
- return sqrt(GetVariance());
+ return *stats_.GetStandardDeviation();
}
// Returns percentile in O(nlogn) on first call and in O(1) after, if no
// additions were done. This function may not be called if there are no
@@ -76,11 +77,8 @@
double GetPercentile(double percentile);
private:
+ RunningStatistics<double> stats_;
std::vector<double> samples_;
- double min_ = std::numeric_limits<double>::max();
- double max_ = std::numeric_limits<double>::min();
- double sum_ = 0;
- double sum_squared_ = 0;
bool sorted_ = false;
};
diff --git a/rtc_base/numerics/samples_stats_counter_unittest.cc b/rtc_base/numerics/samples_stats_counter_unittest.cc
index 8634295..590bf8c 100644
--- a/rtc_base/numerics/samples_stats_counter_unittest.cc
+++ b/rtc_base/numerics/samples_stats_counter_unittest.cc
@@ -34,6 +34,24 @@
return stats;
}
+// Add n samples drawn from uniform distribution in [a;b].
+SamplesStatsCounter CreateStatsFromUniformDistribution(int n,
+ double a,
+ double b) {
+ std::mt19937 gen{std::random_device()()};
+ std::uniform_real_distribution<> dis(a, b);
+
+ SamplesStatsCounter stats;
+ for (int i = 1; i <= n; i++) {
+ stats.AddSample(dis(gen));
+ }
+ return stats;
+}
+
+class SamplesStatsCounterTest : public ::testing::TestWithParam<int> {};
+
+constexpr int SIZE_FOR_MERGE = 10;
+
} // namespace
TEST(SamplesStatsCounter, FullSimpleTest) {
@@ -76,4 +94,58 @@
EXPECT_DOUBLE_EQ(stats.GetPercentile(1.0), 5);
}
+TEST(SamplesStatsCounter, VarianceFromUniformDistribution) {
+ // Check variance converge to 1/12 for [0;1) uniform distribution.
+ // Acts as a sanity check for NumericStabilityForVariance test.
+ SamplesStatsCounter stats = CreateStatsFromUniformDistribution(1e6, 0, 1);
+
+ EXPECT_NEAR(stats.GetVariance(), 1. / 12, 1e-3);
+}
+
+TEST(SamplesStatsCounter, NumericStabilityForVariance) {
+ // Same test as VarianceFromUniformDistribution,
+ // except the range is shifted to [1e9;1e9+1).
+ // Variance should also converge to 1/12.
+ // NB: Although we lose precision for the samples themselves, the fractional
+ // part still enjoys 22 bits of mantissa and errors should even out,
+ // so that couldn't explain a mismatch.
+ SamplesStatsCounter stats =
+ CreateStatsFromUniformDistribution(1e6, 1e9, 1e9 + 1);
+
+ EXPECT_NEAR(stats.GetVariance(), 1. / 12, 1e-3);
+}
+
+TEST_P(SamplesStatsCounterTest, AddSamples) {
+ int data[SIZE_FOR_MERGE] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ // Split the data in different partitions.
+ // We have 11 distinct tests:
+ // * Empty merged with full sequence.
+ // * 1 sample merged with 9 last.
+ // * 2 samples merged with 8 last.
+ // [...]
+ // * Full merged with empty sequence.
+ // All must lead to the same result.
+ SamplesStatsCounter stats0, stats1;
+ for (int i = 0; i < GetParam(); ++i) {
+ stats0.AddSample(data[i]);
+ }
+ for (int i = GetParam(); i < SIZE_FOR_MERGE; ++i) {
+ stats1.AddSample(data[i]);
+ }
+ stats0.AddSamples(stats1);
+
+ EXPECT_EQ(stats0.GetMin(), 0);
+ EXPECT_EQ(stats0.GetMax(), 9);
+ EXPECT_DOUBLE_EQ(stats0.GetAverage(), 4.5);
+ EXPECT_DOUBLE_EQ(stats0.GetVariance(), 8.25);
+ EXPECT_DOUBLE_EQ(stats0.GetStandardDeviation(), sqrt(8.25));
+ EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.1), 0.9);
+ EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.5), 4.5);
+ EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.9), 8.1);
+}
+
+INSTANTIATE_TEST_SUITE_P(SamplesStatsCounterTests,
+ SamplesStatsCounterTest,
+ ::testing::Range(0, SIZE_FOR_MERGE + 1));
+
} // namespace webrtc