blob: 686eb753a4d706ff33cdf481bc2463de63fc8560 [file] [log] [blame]
/*
* Copyright 2024 The WebRTC project authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "video/corruption_detection/corruption_classifier.h"
#include <algorithm>
#include <cmath>
#include <variant>
#include "api/array_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "video/corruption_detection/halton_frame_sampler.h"
namespace webrtc {
CorruptionClassifier::CorruptionClassifier(float scale_factor)
: config_(ScalarConfig{.scale_factor = scale_factor}) {
RTC_CHECK_GT(scale_factor, 0) << "The scale factor must be positive.";
RTC_LOG(LS_INFO) << "Calculating corruption probability using scale factor.";
}
CorruptionClassifier::CorruptionClassifier(float growth_rate, float midpoint)
: config_(LogisticFunctionConfig{.growth_rate = growth_rate,
.midpoint = midpoint}) {
RTC_CHECK_GT(growth_rate, 0)
<< "As the `score` is defined now (low score means probably not "
"corrupted and vice versa), the growth rate must be positive to have "
"a logistic function that is monotonically increasing.";
RTC_LOG(LS_INFO)
<< "Calculating corruption probability using logistic function.";
}
double CorruptionClassifier::CalculateCorruptionProbability(
rtc::ArrayView<const FilteredSample> filtered_original_samples,
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
int luma_threshold,
int chroma_threshold) const {
double loss = GetScore(filtered_original_samples, filtered_compressed_samples,
luma_threshold, chroma_threshold);
if (const auto* scalar_config = std::get_if<ScalarConfig>(&config_)) {
// Fitting the unbounded loss to the interval of [0, 1] using a simple scale
// factor and capping the loss to 1.
return std::min(loss / scalar_config->scale_factor, 1.0);
}
const auto config = std::get_if<LogisticFunctionConfig>(&config_);
RTC_DCHECK(config);
// Fitting the unbounded loss to the interval of [0, 1] using the logistic
// function.
return 1 / (1 + std::exp(-config->growth_rate * (loss - config->midpoint)));
}
// The score is calculated according to the following formula :
//
// score = (sum_i max{(|original_i - compressed_i| - threshold, 0)^2}) / N
//
// where N is the number of samples, i in [0, N), and the threshold is
// either `luma_threshold` or `chroma_threshold` depending on whether the
// sample is luma or chroma.
double CorruptionClassifier::GetScore(
rtc::ArrayView<const FilteredSample> filtered_original_samples,
rtc::ArrayView<const FilteredSample> filtered_compressed_samples,
int luma_threshold,
int chroma_threshold) const {
RTC_DCHECK_GE(luma_threshold, 0);
RTC_DCHECK_GE(chroma_threshold, 0);
RTC_CHECK_EQ(filtered_original_samples.size(),
filtered_compressed_samples.size())
<< "The original and compressed frame have different amounts of "
"filtered samples.";
RTC_CHECK_GT(filtered_original_samples.size(), 0);
const int num_samples = filtered_original_samples.size();
double sum = 0.0;
for (int i = 0; i < num_samples; ++i) {
RTC_CHECK_EQ(filtered_original_samples[i].plane,
filtered_compressed_samples[i].plane);
double abs_diff = std::abs(filtered_original_samples[i].value -
filtered_compressed_samples[i].value);
switch (filtered_original_samples[i].plane) {
case ImagePlane::kLuma:
if (abs_diff > luma_threshold) {
sum += std::pow(abs_diff - luma_threshold, 2);
}
break;
case ImagePlane::kChroma:
if (abs_diff > chroma_threshold) {
sum += std::pow(abs_diff - chroma_threshold, 2);
}
break;
}
}
return sum / num_samples;
}
} // namespace webrtc