Move frame instrumentation evaluation to using callback.
Using a callback instead of a return value is in preparation for making
the corruption score calculation truly asynchronous in some cases.
Bug: webrtc:358039777
Change-Id: I8fc166a3236c5ccfb1e2cbacc0631091a2b3ad06
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/398820
Reviewed-by: Fanny Linderborg <linderborg@webrtc.org>
Commit-Queue: Erik Språng <sprang@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#45117}
diff --git a/common_video/BUILD.gn b/common_video/BUILD.gn
index d8bc46c..c6e8a07 100644
--- a/common_video/BUILD.gn
+++ b/common_video/BUILD.gn
@@ -29,6 +29,7 @@
deps = [
":frame_instrumentation_data",
"../api/video:video_frame",
+ "../api/video:video_rtp_headers",
]
}
diff --git a/common_video/include/corruption_score_calculator.h b/common_video/include/corruption_score_calculator.h
index 77ea533..cb87dbb5 100644
--- a/common_video/include/corruption_score_calculator.h
+++ b/common_video/include/corruption_score_calculator.h
@@ -11,8 +11,7 @@
#ifndef COMMON_VIDEO_INCLUDE_CORRUPTION_SCORE_CALCULATOR_H_
#define COMMON_VIDEO_INCLUDE_CORRUPTION_SCORE_CALCULATOR_H_
-#include <optional>
-
+#include "api/video/video_content_type.h"
#include "api/video/video_frame.h"
#include "common_video/frame_instrumentation_data.h"
@@ -24,9 +23,10 @@
public:
virtual ~CorruptionScoreCalculator() = default;
- virtual std::optional<double> CalculateCorruptionScore(
+ virtual void CalculateCorruptionScore(
const VideoFrame& frame,
- const FrameInstrumentationData& frame_instrumentation_data) = 0;
+ const FrameInstrumentationData& frame_instrumentation_data,
+ VideoContentType content_type) = 0;
};
} // namespace webrtc
diff --git a/modules/video_coding/generic_decoder.cc b/modules/video_coding/generic_decoder.cc
index c605cf7..14b5a92 100644
--- a/modules/video_coding/generic_decoder.cc
+++ b/modules/video_coding/generic_decoder.cc
@@ -147,17 +147,6 @@
return;
}
- std::optional<double> corruption_score;
- if (corruption_score_calculator_ &&
- frame_info->frame_instrumentation_data.has_value()) {
- if (const FrameInstrumentationData* data =
- std::get_if<FrameInstrumentationData>(
- &*frame_info->frame_instrumentation_data)) {
- corruption_score = corruption_score_calculator_->CalculateCorruptionScore(
- decodedImage, *data);
- }
- }
-
decodedImage.set_ntp_time_ms(frame_info->ntp_time_ms);
decodedImage.set_packet_infos(frame_info->packet_infos);
decodedImage.set_rotation(frame_info->rotation);
@@ -255,8 +244,17 @@
.qp = qp,
.decode_time = decode_time,
.content_type = frame_info->content_type,
- .frame_type = frame_info->frame_type,
- .corruption_score = corruption_score});
+ .frame_type = frame_info->frame_type});
+
+ if (corruption_score_calculator_ &&
+ frame_info->frame_instrumentation_data.has_value()) {
+ if (const FrameInstrumentationData* data =
+ std::get_if<FrameInstrumentationData>(
+ &*frame_info->frame_instrumentation_data)) {
+ corruption_score_calculator_->CalculateCorruptionScore(
+ decodedImage, *data, frame_info->content_type);
+ }
+ }
}
void VCMDecodedFrameCallback::OnDecoderInfoChanged(
diff --git a/modules/video_coding/generic_decoder_unittest.cc b/modules/video_coding/generic_decoder_unittest.cc
index e75c040..944480c 100644
--- a/modules/video_coding/generic_decoder_unittest.cc
+++ b/modules/video_coding/generic_decoder_unittest.cc
@@ -41,6 +41,9 @@
#include "test/gtest.h"
#include "test/time_controller/simulated_time_controller.h"
+using ::testing::Eq;
+using ::testing::Field;
+using ::testing::Property;
using ::testing::Return;
namespace webrtc {
@@ -48,10 +51,11 @@
class MockCorruptionScoreCalculator : public CorruptionScoreCalculator {
public:
- MOCK_METHOD(std::optional<double>,
+ MOCK_METHOD(void,
CalculateCorruptionScore,
(const VideoFrame& frame,
- const FrameInstrumentationData& frame_instrumentation_data),
+ const FrameInstrumentationData& frame_instrumentation_data,
+ VideoContentType content_type),
(override));
};
@@ -59,7 +63,6 @@
public:
int32_t OnFrameToRender(const FrameToRender& arguments) override {
frames_.push_back(arguments.video_frame);
- last_corruption_score_ = arguments.corruption_score;
return 0;
}
@@ -79,14 +82,9 @@
uint32_t frames_dropped() const { return frames_dropped_; }
- std::optional<double> last_corruption_score() const {
- return last_corruption_score_;
- }
-
private:
std::vector<VideoFrame> frames_;
uint32_t frames_dropped_ = 0;
- std::optional<double> last_corruption_score_;
};
class GenericDecoderTest : public ::testing::Test {
@@ -220,17 +218,13 @@
}
TEST_F(GenericDecoderTest, CallCalculateCorruptionScoreInDecoded) {
- constexpr double kCorruptionScore = 0.76;
-
- EXPECT_CALL(corruption_score_calculator_, CalculateCorruptionScore)
- .WillOnce(Return(kCorruptionScore));
-
constexpr uint32_t kRtpTimestamp = 1;
FrameInfo frame_info;
- frame_info.frame_instrumentation_data = FrameInstrumentationData{};
+ frame_info.frame_instrumentation_data =
+ FrameInstrumentationData{.sequence_index = 1};
frame_info.rtp_timestamp = kRtpTimestamp;
frame_info.decode_start = Timestamp::Zero();
- frame_info.content_type = VideoContentType::UNSPECIFIED;
+ frame_info.content_type = VideoContentType::SCREENSHARE;
frame_info.frame_type = VideoFrameType::kVideoFrameDelta;
VideoFrame video_frame = VideoFrame::Builder()
.set_video_frame_buffer(I420Buffer::Create(5, 5))
@@ -238,9 +232,12 @@
.build();
vcm_callback_.Map(std::move(frame_info));
+ EXPECT_CALL(corruption_score_calculator_,
+ CalculateCorruptionScore(
+ Property(&VideoFrame::rtp_timestamp, Eq(kRtpTimestamp)),
+ Field(&FrameInstrumentationData::sequence_index, Eq(1)),
+ VideoContentType::SCREENSHARE));
vcm_callback_.Decoded(video_frame);
-
- EXPECT_EQ(user_callback_.last_corruption_score(), kCorruptionScore);
}
} // namespace video_coding
diff --git a/modules/video_coding/include/video_coding_defines.h b/modules/video_coding/include/video_coding_defines.h
index 1a81d29..88c16ca 100644
--- a/modules/video_coding/include/video_coding_defines.h
+++ b/modules/video_coding/include/video_coding_defines.h
@@ -58,7 +58,6 @@
TimeDelta decode_time;
VideoContentType content_type;
VideoFrameType frame_type;
- std::optional<double> corruption_score;
};
virtual int32_t OnFrameToRender(const FrameToRender& arguments) = 0;
diff --git a/video/corruption_detection/BUILD.gn b/video/corruption_detection/BUILD.gn
index 6a07547..490f3ca 100644
--- a/video/corruption_detection/BUILD.gn
+++ b/video/corruption_detection/BUILD.gn
@@ -32,6 +32,7 @@
"../../api:array_view",
"../../api:scoped_refptr",
"../../api/video:video_frame",
+ "../../api/video:video_rtp_headers",
"../../common_video:frame_instrumentation_data",
"../../rtc_base:checks",
"../../rtc_base:logging",
@@ -162,6 +163,7 @@
":frame_instrumentation_evaluation",
"../../api:scoped_refptr",
"../../api/video:video_frame",
+ "../../api/video:video_rtp_headers",
"../../common_video:frame_instrumentation_data",
"../../test:test_support",
]
diff --git a/video/corruption_detection/frame_instrumentation_evaluation.cc b/video/corruption_detection/frame_instrumentation_evaluation.cc
index 8029a52..61a14fe 100644
--- a/video/corruption_detection/frame_instrumentation_evaluation.cc
+++ b/video/corruption_detection/frame_instrumentation_evaluation.cc
@@ -11,15 +11,14 @@
#include "video/corruption_detection/frame_instrumentation_evaluation.h"
#include <cstddef>
-#include <optional>
#include <vector>
#include "api/array_view.h"
+#include "api/video/video_content_type.h"
#include "api/video/video_frame.h"
#include "common_video/frame_instrumentation_data.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
-#include "video/corruption_detection/corruption_classifier.h"
#include "video/corruption_detection/halton_frame_sampler.h"
namespace webrtc {
@@ -41,42 +40,49 @@
} // namespace
-std::optional<double> GetCorruptionScore(const FrameInstrumentationData& data,
- const VideoFrame& frame) {
+FrameInstrumentationEvaluation::FrameInstrumentationEvaluation(
+ CorruptionScoreObserver* observer)
+ : observer_(observer), classifier_(/*scale_factor=*/3) {
+ RTC_CHECK(observer);
+}
+
+void FrameInstrumentationEvaluation::OnInstrumentedFrame(
+ const FrameInstrumentationData& data,
+ const VideoFrame& frame,
+ VideoContentType content_type) {
if (data.sample_values.empty()) {
RTC_LOG(LS_WARNING)
<< "Samples are needed to calculate a corruption score.";
- return std::nullopt;
+ return;
}
- HaltonFrameSampler frame_sampler;
- frame_sampler.SetCurrentIndex(data.sequence_index);
+ frame_sampler_.SetCurrentIndex(data.sequence_index);
std::vector<HaltonFrameSampler::Coordinates> sample_coordinates =
- frame_sampler.GetSampleCoordinatesForFrame(data.sample_values.size());
+ frame_sampler_.GetSampleCoordinatesForFrame(data.sample_values.size());
if (sample_coordinates.empty()) {
RTC_LOG(LS_ERROR) << "Failed to get sample coordinates for frame.";
- return std::nullopt;
+ return;
}
std::vector<FilteredSample> samples = GetSampleValuesForFrame(
frame, sample_coordinates, frame.width(), frame.height(), data.std_dev);
if (samples.empty()) {
RTC_LOG(LS_ERROR) << "Failed to get sample values for frame";
- return std::nullopt;
+ return;
}
std::vector<FilteredSample> data_samples =
ConvertSampleValuesToFilteredSamples(data.sample_values, samples);
if (data_samples.empty()) {
RTC_LOG(LS_ERROR) << "Failed to convert sample values to filtered samples";
- return std::nullopt;
+ return;
}
- CorruptionClassifier classifier(3);
+ double score = classifier_.CalculateCorruptionProbability(
+ data_samples, samples, data.luma_error_threshold,
+ data.chroma_error_threshold);
- return classifier.CalculateCorruptionProbability(data_samples, samples,
- data.luma_error_threshold,
- data.chroma_error_threshold);
+ observer_->OnCorruptionScore(score, content_type);
}
} // namespace webrtc
diff --git a/video/corruption_detection/frame_instrumentation_evaluation.h b/video/corruption_detection/frame_instrumentation_evaluation.h
index 8bd3e1c..767eeff 100644
--- a/video/corruption_detection/frame_instrumentation_evaluation.h
+++ b/video/corruption_detection/frame_instrumentation_evaluation.h
@@ -11,15 +11,39 @@
#ifndef VIDEO_CORRUPTION_DETECTION_FRAME_INSTRUMENTATION_EVALUATION_H_
#define VIDEO_CORRUPTION_DETECTION_FRAME_INSTRUMENTATION_EVALUATION_H_
-#include <optional>
-
+#include "api/video/video_content_type.h"
#include "api/video/video_frame.h"
#include "common_video/frame_instrumentation_data.h"
+#include "video/corruption_detection/corruption_classifier.h"
+#include "video/corruption_detection/halton_frame_sampler.h"
namespace webrtc {
-std::optional<double> GetCorruptionScore(const FrameInstrumentationData& data,
- const VideoFrame& frame);
+class CorruptionScoreObserver {
+ public:
+ CorruptionScoreObserver() = default;
+ virtual ~CorruptionScoreObserver() = default;
+
+ // Results of corruption detection for a single frame, with a likelihood score
+ // in the range [0.0, 1.0].
+ virtual void OnCorruptionScore(double corruption_score,
+ VideoContentType content_type) = 0;
+};
+
+class FrameInstrumentationEvaluation {
+ public:
+ explicit FrameInstrumentationEvaluation(CorruptionScoreObserver* observer);
+
+ void OnInstrumentedFrame(const FrameInstrumentationData& data,
+ const VideoFrame& frame,
+ VideoContentType frame_type);
+
+ private:
+ CorruptionScoreObserver* const observer_;
+
+ HaltonFrameSampler frame_sampler_;
+ CorruptionClassifier classifier_;
+};
} // namespace webrtc
diff --git a/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc b/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc
index d82068d..983a757 100644
--- a/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc
+++ b/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc
@@ -11,18 +11,29 @@
#include "video/corruption_detection/frame_instrumentation_evaluation.h"
#include <cstdint>
-#include <optional>
#include <vector>
#include "api/scoped_refptr.h"
#include "api/video/i420_buffer.h"
+#include "api/video/video_content_type.h"
#include "api/video/video_frame.h"
#include "common_video/frame_instrumentation_data.h"
+#include "test/gmock.h"
#include "test/gtest.h"
namespace webrtc {
namespace {
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::Ge;
+using ::testing::Le;
+
+class MockCorruptionScoreObserver : public CorruptionScoreObserver {
+ public:
+ MOCK_METHOD(void, OnCorruptionScore, (double, VideoContentType), (override));
+};
+
scoped_refptr<I420Buffer> MakeI420FrameBufferWithDifferentPixelValues() {
// Create an I420 frame of size 4x4.
const int kDefaultLumaWidth = 4;
@@ -52,9 +63,10 @@
.set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
.build();
- std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
- EXPECT_FALSE(corruption_score.has_value());
+ MockCorruptionScoreObserver observer;
+ FrameInstrumentationEvaluation evaluator(&observer);
+ EXPECT_CALL(observer, OnCorruptionScore).Times(0);
+ evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
}
TEST(FrameInstrumentationEvaluationTest,
@@ -71,10 +83,10 @@
.set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
.build();
- std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
- ASSERT_TRUE(corruption_score.has_value());
- EXPECT_DOUBLE_EQ(*corruption_score, 1.0);
+ MockCorruptionScoreObserver observer;
+ FrameInstrumentationEvaluation evaluator(&observer);
+ EXPECT_CALL(observer, OnCorruptionScore(1.0, VideoContentType::SCREENSHARE));
+ evaluator.OnInstrumentedFrame(data, frame, VideoContentType::SCREENSHARE);
}
TEST(FrameInstrumentationEvaluationTest,
@@ -91,11 +103,10 @@
.set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
.build();
- std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
- ASSERT_TRUE(corruption_score.has_value());
- EXPECT_LE(*corruption_score, 1);
- EXPECT_GE(*corruption_score, 0);
+ MockCorruptionScoreObserver observer;
+ FrameInstrumentationEvaluation evaluator(&observer);
+ EXPECT_CALL(observer, OnCorruptionScore(AllOf(Ge(0.0), Le(1.0)), _));
+ evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
}
TEST(FrameInstrumentationEvaluationTest,
@@ -114,11 +125,10 @@
.set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
.build();
- std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
- ASSERT_TRUE(corruption_score.has_value());
- EXPECT_LE(*corruption_score, 1);
- EXPECT_GE(*corruption_score, 0);
+ MockCorruptionScoreObserver observer;
+ FrameInstrumentationEvaluation evaluator(&observer);
+ EXPECT_CALL(observer, OnCorruptionScore(AllOf(Ge(0.0), Le(1.0)), _));
+ evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
}
TEST(FrameInstrumentationEvaluationTest, ApplySequenceIndexWhenProvided) {
@@ -136,11 +146,10 @@
.set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
.build();
- std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
- ASSERT_TRUE(corruption_score.has_value());
- EXPECT_LE(*corruption_score, 1);
- EXPECT_GE(*corruption_score, 0);
+ MockCorruptionScoreObserver observer;
+ FrameInstrumentationEvaluation evaluator(&observer);
+ EXPECT_CALL(observer, OnCorruptionScore(AllOf(Ge(0.0), Le(1.0)), _));
+ evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
}
} // namespace
diff --git a/video/receive_statistics_proxy.h b/video/receive_statistics_proxy.h
index 023f4fa..2fb819a 100644
--- a/video/receive_statistics_proxy.h
+++ b/video/receive_statistics_proxy.h
@@ -40,6 +40,7 @@
#include "rtc_base/rate_tracker.h"
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"
+#include "video/corruption_detection/frame_instrumentation_evaluation.h"
#include "video/stats_counter.h"
#include "video/video_quality_observer2.h"
#include "video/video_stream_buffer_controller.h"
@@ -55,7 +56,8 @@
class ReceiveStatisticsProxy : public VideoStreamBufferControllerStatsObserver,
public RtcpCnameCallback,
- public RtcpPacketTypeCounterObserver {
+ public RtcpPacketTypeCounterObserver,
+ public CorruptionScoreObserver {
public:
ReceiveStatisticsProxy(uint32_t remote_ssrc,
Clock* clock,
@@ -115,7 +117,7 @@
void OnCname(uint32_t ssrc, absl::string_view cname) override;
void OnCorruptionScore(double corruption_score,
- VideoContentType content_type);
+ VideoContentType content_type) override;
// Implements RtcpPacketTypeCounterObserver.
void RtcpPacketTypesCounterUpdated(
diff --git a/video/video_receive_stream2.cc b/video/video_receive_stream2.cc
index 2e8bc0b..4660da8 100644
--- a/video/video_receive_stream2.cc
+++ b/video/video_receive_stream2.cc
@@ -48,6 +48,7 @@
#include "api/video/recordable_encoded_frame.h"
#include "api/video/render_resolution.h"
#include "api/video/video_codec_type.h"
+#include "api/video/video_content_type.h"
#include "api/video/video_frame.h"
#include "api/video/video_frame_type.h"
#include "api/video/video_rotation.h"
@@ -266,6 +267,7 @@
max_wait_for_frame_(DetermineMaxWaitForFrame(
TimeDelta::Millis(config_.rtp.nack.rtp_history_ms),
false)),
+ frame_evaluator_(&stats_proxy_),
decode_queue_(env_.task_queue_factory().CreateTaskQueue(
"DecodingQueue",
TaskQueueFactory::Priority::HIGH)) {
@@ -651,10 +653,13 @@
stats_proxy_.UpdateHistograms(fraction_lost, rtp_stats, nullptr);
}
-std::optional<double> VideoReceiveStream2::CalculateCorruptionScore(
+void VideoReceiveStream2::CalculateCorruptionScore(
const VideoFrame& frame,
- const FrameInstrumentationData& frame_instrumentation_data) {
- return GetCorruptionScore(frame_instrumentation_data, frame);
+ const FrameInstrumentationData& frame_instrumentation_data,
+ VideoContentType content_type) {
+ RTC_DCHECK_RUN_ON(&decode_sequence_checker_);
+ frame_evaluator_.OnInstrumentedFrame(frame_instrumentation_data, frame,
+ content_type);
}
bool VideoReceiveStream2::SetBaseMinimumPlayoutDelayMs(int delay_ms) {
diff --git a/video/video_receive_stream2.h b/video/video_receive_stream2.h
index d6d27a5..135e4c8 100644
--- a/video/video_receive_stream2.h
+++ b/video/video_receive_stream2.h
@@ -34,6 +34,7 @@
#include "api/units/timestamp.h"
#include "api/video/encoded_frame.h"
#include "api/video/recordable_encoded_frame.h"
+#include "api/video/video_content_type.h"
#include "api/video/video_frame.h"
#include "api/video/video_sink_interface.h"
#include "call/call.h"
@@ -50,6 +51,7 @@
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"
#include "rtc_base/time_utils.h"
+#include "video/corruption_detection/frame_instrumentation_evaluation.h"
#include "video/decode_synchronizer.h"
#include "video/receive_statistics_proxy.h"
#include "video/rtp_streams_synchronizer2.h"
@@ -258,9 +260,10 @@
RTC_RUN_ON(decode_sequence_checker_);
void UpdateHistograms();
- std::optional<double> CalculateCorruptionScore(
+ void CalculateCorruptionScore(
const VideoFrame& frame,
- const FrameInstrumentationData& frame_instrumentation_data) override;
+ const FrameInstrumentationData& frame_instrumentation_data,
+ VideoContentType content_type) override;
const Environment env_;
@@ -364,6 +367,9 @@
std::vector<std::unique_ptr<EncodedFrame>> buffered_encoded_frames_
RTC_GUARDED_BY(decode_sequence_checker_);
+ FrameInstrumentationEvaluation frame_evaluator_
+ RTC_GUARDED_BY(decode_sequence_checker_);
+
// Used to signal destruction to potentially pending tasks.
ScopedTaskSafety task_safety_;
diff --git a/video/video_stream_decoder2.cc b/video/video_stream_decoder2.cc
index 02b3db8..45379a0 100644
--- a/video/video_stream_decoder2.cc
+++ b/video/video_stream_decoder2.cc
@@ -53,10 +53,6 @@
receive_stats_callback_->OnDecodedFrame(
arguments.video_frame, arguments.qp, arguments.decode_time,
arguments.content_type, arguments.frame_type);
- if (arguments.corruption_score.has_value()) {
- receive_stats_callback_->OnCorruptionScore(*arguments.corruption_score,
- arguments.content_type);
- }
incoming_video_stream_->OnFrame(arguments.video_frame);
return 0;
}