Add support for corruption classification.
This class calculates the corruption score based on the given samples from two frames.
Bug: webrtc:358039777
Change-Id: Ib036f91ec16609e827137cc35d342a2c49764737
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/362801
Reviewed-by: Erik Språng <sprang@webrtc.org>
Reviewed-by: Fanny Linderborg <linderborg@webrtc.org>
Commit-Queue: Emil Vardar (xWF) <vardar@google.com>
Cr-Commit-Position: refs/heads/main@{#43043}
diff --git a/video/corruption_detection/BUILD.gn b/video/corruption_detection/BUILD.gn
index bbcc122..d37e359 100644
--- a/video/corruption_detection/BUILD.gn
+++ b/video/corruption_detection/BUILD.gn
@@ -8,6 +8,19 @@
import("../../webrtc.gni")
+rtc_library("corruption_classifier") {
+ sources = [
+ "corruption_classifier.cc",
+ "corruption_classifier.h",
+ ]
+ deps = [
+ ":halton_frame_sampler",
+ "../../api:array_view",
+ "../../rtc_base:checks",
+ "../../rtc_base:logging",
+ ]
+}
+
rtc_library("frame_instrumentation_generator") {
sources = [
"frame_instrumentation_generator.cc",
@@ -66,6 +79,16 @@
}
if (rtc_include_tests) {
+ rtc_library("corruption_classifier_unittest") {
+ testonly = true
+ sources = [ "corruption_classifier_unittest.cc" ]
+ deps = [
+ ":corruption_classifier",
+ ":halton_frame_sampler",
+ "../../test:test_support",
+ ]
+ }
+
rtc_library("frame_instrumentation_generator_unittest") {
testonly = true
sources = [ "frame_instrumentation_generator_unittest.cc" ]
@@ -115,6 +138,7 @@
testonly = true
sources = []
deps = [
+ ":corruption_classifier_unittest",
":frame_instrumentation_generator_unittest",
":generic_mapping_functions_unittest",
":halton_frame_sampler_unittest",
diff --git a/video/corruption_detection/corruption_classifier.cc b/video/corruption_detection/corruption_classifier.cc
new file mode 100644
index 0000000..a4fc167
--- /dev/null
+++ b/video/corruption_detection/corruption_classifier.cc
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2024 The WebRTC project authors. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "video/corruption_detection/corruption_classifier.h"
+
+#include <algorithm>
+#include <cmath>
+#include <variant>
+
+#include "api/array_view.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
+#include "video/corruption_detection/halton_frame_sampler.h"
+
+namespace webrtc {
+
+CorruptionClassifier::CorruptionClassifier(float scale_factor)
+ : config_(ScalarConfig{.scale_factor = scale_factor}) {
+ RTC_CHECK_GT(scale_factor, 0) << "The scale factor must be positive.";
+ RTC_LOG(LS_INFO) << "Calculating corruption probability using scale factor.";
+}
+
+CorruptionClassifier::CorruptionClassifier(float growth_rate, float midpoint)
+ : config_(LogisticFunctionConfig{.growth_rate = growth_rate,
+ .midpoint = midpoint}) {
+ RTC_CHECK_GT(growth_rate, 0)
+ << "As the `score` is defined now (low score means probably not "
+ "corrupted and vice versa), the growth rate must be positive to have "
+ "a logistic function that is monotonically increasing.";
+ RTC_LOG(LS_INFO)
+ << "Calculating corruption probability using logistic function.";
+}
+
+double CorruptionClassifier::CalculateCorruptionProbablility(
+ rtc::ArrayView<const FilteredSample> filtered_original_samples,
+ rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
+ int luma_threshold,
+ int chroma_threshold) const {
+ RTC_DCHECK_GT(luma_threshold, 0) << "Luma threshold must be positive.";
+ RTC_DCHECK_GT(chroma_threshold, 0) << "Chroma threshold must be positive.";
+ RTC_DCHECK_EQ(filtered_original_samples.size(),
+ filtered_compressed_samples.size())
+ << "The original and compressed frame have a different amount of "
+ "filtered samples.";
+
+ double loss = GetScore(filtered_original_samples, filtered_compressed_samples,
+ luma_threshold, chroma_threshold);
+
+ if (const auto* scalar_config = std::get_if<ScalarConfig>(&config_)) {
+ // Fitting the unbounded loss to the interval of [0, 1] using a simple scale
+ // factor and capping the loss to 1.
+ return std::min(loss / scalar_config->scale_factor, 1.0);
+ }
+
+ const auto config = std::get_if<LogisticFunctionConfig>(&config_);
+ RTC_DCHECK(config);
+ // Fitting the unbounded loss to the interval of [0, 1] using the logistic
+ // function.
+ return 1 / (1 + std::exp(-config->growth_rate * (loss - config->midpoint)));
+}
+
+// Returns the mean squared above-threshold sample difference (0 if empty):
+//
+//   score = (sum_i max{|original_i - compressed_i| - threshold_i, 0}^2) / N
+//
+// where N is the number of samples, i in [0, N), and threshold_i is
+// `luma_threshold` or `chroma_threshold` depending on the plane of sample i.
+double CorruptionClassifier::GetScore(
+    rtc::ArrayView<const FilteredSample> filtered_original_samples,
+    rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
+    int luma_threshold,
+    int chroma_threshold) const {
+  RTC_CHECK_EQ(filtered_original_samples.size(),
+               filtered_compressed_samples.size());
+  const int num_samples = filtered_original_samples.size();
+  if (num_samples == 0) {
+    return 0.0;  // Avoids 0.0 / 0, which would hand NaN to the caller.
+  }
+  double sum = 0.0;
+  for (int i = 0; i < num_samples; ++i) {
+    RTC_CHECK_EQ(filtered_original_samples[i].plane,
+                 filtered_compressed_samples[i].plane);
+    const double abs_diff = std::abs(filtered_original_samples[i].value -
+                                     filtered_compressed_samples[i].value);
+    switch (filtered_original_samples[i].plane) {
+      case ImagePlane::kLuma:
+        if (abs_diff > luma_threshold) {
+          sum += (abs_diff - luma_threshold) * (abs_diff - luma_threshold);
+        }
+        break;
+      case ImagePlane::kChroma:
+        if (abs_diff > chroma_threshold) {
+          sum += (abs_diff - chroma_threshold) * (abs_diff - chroma_threshold);
+        }
+        break;
+    }
+  }
+  return sum / num_samples;
+}
+
+} // namespace webrtc
diff --git a/video/corruption_detection/corruption_classifier.h b/video/corruption_detection/corruption_classifier.h
new file mode 100644
index 0000000..8e0c061
--- /dev/null
+++ b/video/corruption_detection/corruption_classifier.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2024 The WebRTC project authors. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
+#define VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
+
+#include <variant>
+
+#include "api/array_view.h"
+#include "video/corruption_detection/halton_frame_sampler.h"
+
+namespace webrtc {
+
+// Based on the filtered samples given to `CalculateCorruptionProbablility`,
+// this class calculates the probability that the frame is corrupted. First a
+// score (loss) between the original and the compressed frames' samples is
+// calculated using `GetScore`. The score is then mapped to the interval
+// [0, 1], either by dividing it by a simple `scale_factor` and capping the
+// result at 1, or by applying a logistic function constructed from
+// `growth_rate` and `midpoint`.
+//
+// TODO: bugs.webrtc.org/358039777 - Remove one of the constructors based on
+// which mapping function works best in practice.
+class CorruptionClassifier {
+ public:
+ // Calculates the corruption probability using a simple scale factor.
+ explicit CorruptionClassifier(float scale_factor);
+ // Calculates the corruption probability using a logistic function.
+ CorruptionClassifier(float growth_rate, float midpoint);
+ ~CorruptionClassifier() = default;
+
+ // This function calculates and returns the probability (in the interval [0,
+ // 1]) that a frame is corrupted. The probability is determined either by
+ // scaling the loss to the interval of [0, 1] using a simple `scale_factor`
+ // or by applying a logistic function to the loss. The method is chosen
+ // depending on the used constructor.
+ double CalculateCorruptionProbablility(
+ rtc::ArrayView<const FilteredSample> filtered_original_samples,
+ rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
+ int luma_threshold,
+ int chroma_threshold) const;
+
+ private:
+ struct ScalarConfig {
+ float scale_factor;
+ };
+
+ // Logistic function parameters. See
+ // https://en.wikipedia.org/wiki/Logistic_function.
+ struct LogisticFunctionConfig {
+ float growth_rate;
+ float midpoint;
+ };
+
+ // Returns the non-normalized score between the original and the compressed
+ // frames' samples.
+ double GetScore(
+ rtc::ArrayView<const FilteredSample> filtered_original_samples,
+ rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
+ int luma_threshold,
+ int chroma_threshold) const;
+
+ const std::variant<ScalarConfig, LogisticFunctionConfig> config_;
+};
+
+} // namespace webrtc
+
+#endif // VIDEO_CORRUPTION_DETECTION_CORRUPTION_CLASSIFIER_H_
diff --git a/video/corruption_detection/corruption_classifier_unittest.cc b/video/corruption_detection/corruption_classifier_unittest.cc
new file mode 100644
index 0000000..1fbdb29
--- /dev/null
+++ b/video/corruption_detection/corruption_classifier_unittest.cc
@@ -0,0 +1,269 @@
+/*
+ * Copyright 2024 The WebRTC project authors. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "video/corruption_detection/corruption_classifier.h"
+
+#include <vector>
+
+#include "test/gmock.h"
+#include "test/gtest.h"
+#include "video/corruption_detection/halton_frame_sampler.h"
+
+namespace webrtc {
+namespace {
+
+using ::testing::DoubleNear;
+
+constexpr int kLumaThreshold = 3;
+constexpr int kChromaThreshold = 2;
+
+constexpr double kMaxAbsoluteError = 1e-4;
+
+// Arbitrary values for testing.
+constexpr double kBaseOriginalLumaSampleValue1 = 1.0;
+constexpr double kBaseOriginalLumaSampleValue2 = 2.5;
+constexpr double kBaseOriginalChromaSampleValue1 = 0.5;
+
+constexpr FilteredSample kFilteredOriginalSampleValues[] = {
+ {.value = kBaseOriginalLumaSampleValue1, .plane = ImagePlane::kLuma},
+ {.value = kBaseOriginalLumaSampleValue2, .plane = ImagePlane::kLuma},
+ {.value = kBaseOriginalChromaSampleValue1, .plane = ImagePlane::kChroma}};
+
+// Chosen as 2 * `kMidpoint` so that the `ScalarConfig` and
+// `LogisticFunctionConfig` probabilities fall on the same side of 0.5.
+constexpr float kScaleFactor = 14.0;
+
+constexpr float kGrowthRate = 1.0;
+constexpr float kMidpoint = 7.0;
+
+// Helper function to create fake compressed sample values.
+std::vector<FilteredSample> GetCompressedSampleValues(
+ double increase_value_luma,
+ double increase_value_chroma) {
+ return std::vector<FilteredSample>{
+ {.value = kBaseOriginalLumaSampleValue1 + increase_value_luma,
+ .plane = ImagePlane::kLuma},
+ {.value = kBaseOriginalLumaSampleValue2 + increase_value_luma,
+ .plane = ImagePlane::kLuma},
+ {.value = kBaseOriginalChromaSampleValue1 + increase_value_chroma,
+ .plane = ImagePlane::kChroma}};
+}
+
+TEST(CorruptionClassifierTest,
+     SameSampleValuesShouldResultInNoCorruptionScalarConfig) {
+  const double kIncreaseValue = 0.0;
+  const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+      GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+  CorruptionClassifier corruption_classifier(kScaleFactor);
+
+  // Expected: score = 0.
+  // Note that the `score` above corresponds to the value returned by the
+  // `GetScore` function. This value is then passed through the scalar or
+  // logistic mapping, giving the expected result inside DoubleNear. This
+  // applies to all the following tests.
+  EXPECT_THAT(
+      corruption_classifier.CalculateCorruptionProbablility(
+          kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+          kLumaThreshold, kChromaThreshold),
+      DoubleNear(0.0, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+     SameSampleValuesShouldResultInNoCorruptionLogisticFunctionConfig) {
+  const double kIncreaseValue = 0.0;
+  const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+      GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+  CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
+
+  // Expected: score = 0, mapped to 1 / (1 + exp(kMidpoint)) ~= 0.0009.
+  EXPECT_THAT(
+      corruption_classifier.CalculateCorruptionProbablility(
+          kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+          kLumaThreshold, kChromaThreshold),
+      DoubleNear(0.0009, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenAllSampleDifferencesBelowThresholdScalarConfig) {
+ // Following value should be < `kLumaThreshold` and `kChromaThreshold`.
+ const double kIncreaseValue = 1;
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+ CorruptionClassifier corruption_classifier(kScaleFactor);
+
+ // Expected: score = 0.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.0, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenAllSampleDifferencesBelowThresholdLogisticFunctionConfig) {
+ // Following value should be < `kLumaThreshold` and `kChromaThreshold`.
+ const double kIncreaseValue = 1;
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+ CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
+
+ // Expected: score = 0.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.0009, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenSmallPartOfSamplesAboveThresholdScalarConfig) {
+ const double kIncreaseValueLuma = 1;
+ const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
+
+ CorruptionClassifier corruption_classifier(kScaleFactor);
+
+ // Expected: score = (0.5)^2 / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.0060, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenSmallPartOfSamplesAboveThresholdLogisticFunctionConfig) {
+ const double kIncreaseValueLuma = 1;
+ const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
+
+ CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
+
+ // Expected: score = (0.5)^2 / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.001, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenAllSamplesSlightlyAboveThresholdScalarConfig) {
+ const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`.
+ const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
+
+ CorruptionClassifier corruption_classifier(kScaleFactor);
+
+ // Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.07452, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenAllSamplesSlightlyAboveThresholdLogisticFunctionConfig) {
+ const double kIncreaseValueLuma = 4.2; // Above `kLumaThreshold`.
+ const double kIncreaseValueChroma = 2.5; // Above `kChromaThreshold`.
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValueLuma, kIncreaseValueChroma);
+
+ CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
+
+ // Expected: score = ((0.5)^2 + 2*(1.2)^2) / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.0026, kMaxAbsoluteError));
+}
+
+// Observe that the following 2 tests in practice could be classified as
+// corrupted, if so wanted. However, with the `kGrowthRate`, `kMidpoint` and
+// `kScaleFactor` values chosen in these tests, the score is not high enough to
+// be classified as corrupted.
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenAllSamplesSomewhatAboveThresholdScalarConfig) {
+ const double kIncreaseValue = 5.0;
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+ CorruptionClassifier corruption_classifier(kScaleFactor);
+
+ // Expected: score = ((3)^2 + 2*(2)^2) / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.4048, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ NoCorruptionWhenAllSamplesSomewhatAboveThresholdLogisticFunctionConfig) {
+ // Somewhat above `kLumaThreshold` and `kChromaThreshold`.
+ const double kIncreaseValue = 5.0;
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+ CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
+
+ // Expected: score = ((3)^2 + 2*(2)^2) / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(0.2086, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ CorruptionWhenAllSamplesWellAboveThresholdScalarConfig) {
+ // Well above `kLumaThreshold` and `kChromaThreshold`.
+ const double kIncreaseValue = 7.0;
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+ CorruptionClassifier corruption_classifier(kScaleFactor);
+
+ // Expected: score = ((5)^2 + 2*(4)^2) / 3. Expected 1 because of capping.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(1, kMaxAbsoluteError));
+}
+
+TEST(CorruptionClassifierTest,
+ CorruptionWhenAllSamplesWellAboveThresholdLogisticFunctionConfig) {
+ // Well above `kLumaThreshold` and `kChromaThreshold`.
+ const double kIncreaseValue = 7.0;
+ const std::vector<FilteredSample> kFilteredCompressedSampleValues =
+ GetCompressedSampleValues(kIncreaseValue, kIncreaseValue);
+
+ CorruptionClassifier corruption_classifier(kGrowthRate, kMidpoint);
+
+ // Expected: score = ((5)^2 + 2*(4)^2) / 3.
+ EXPECT_THAT(
+ corruption_classifier.CalculateCorruptionProbablility(
+ kFilteredOriginalSampleValues, kFilteredCompressedSampleValues,
+ kLumaThreshold, kChromaThreshold),
+ DoubleNear(1, kMaxAbsoluteError));
+}
+
+} // namespace
+} // namespace webrtc