Move frame instrumentation evaluation to using callback.

Using a callback instead of a return value is in preparation for making
the corruption score calculation truly asynchronous in some cases.

Bug: webrtc:358039777
Change-Id: I8fc166a3236c5ccfb1e2cbacc0631091a2b3ad06
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/398820
Reviewed-by: Fanny Linderborg <linderborg@webrtc.org>
Commit-Queue: Erik Språng <sprang@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#45117}
diff --git a/common_video/BUILD.gn b/common_video/BUILD.gn
index d8bc46c..c6e8a07 100644
--- a/common_video/BUILD.gn
+++ b/common_video/BUILD.gn
@@ -29,6 +29,7 @@
   deps = [
     ":frame_instrumentation_data",
     "../api/video:video_frame",
+    "../api/video:video_rtp_headers",
   ]
 }
 
diff --git a/common_video/include/corruption_score_calculator.h b/common_video/include/corruption_score_calculator.h
index 77ea533..cb87dbb5 100644
--- a/common_video/include/corruption_score_calculator.h
+++ b/common_video/include/corruption_score_calculator.h
@@ -11,8 +11,7 @@
 #ifndef COMMON_VIDEO_INCLUDE_CORRUPTION_SCORE_CALCULATOR_H_
 #define COMMON_VIDEO_INCLUDE_CORRUPTION_SCORE_CALCULATOR_H_
 
-#include <optional>
-
+#include "api/video/video_content_type.h"
 #include "api/video/video_frame.h"
 #include "common_video/frame_instrumentation_data.h"
 
@@ -24,9 +23,10 @@
  public:
   virtual ~CorruptionScoreCalculator() = default;
 
-  virtual std::optional<double> CalculateCorruptionScore(
+  virtual void CalculateCorruptionScore(
       const VideoFrame& frame,
-      const FrameInstrumentationData& frame_instrumentation_data) = 0;
+      const FrameInstrumentationData& frame_instrumentation_data,
+      VideoContentType content_type) = 0;
 };
 
 }  // namespace webrtc
diff --git a/modules/video_coding/generic_decoder.cc b/modules/video_coding/generic_decoder.cc
index c605cf7..14b5a92 100644
--- a/modules/video_coding/generic_decoder.cc
+++ b/modules/video_coding/generic_decoder.cc
@@ -147,17 +147,6 @@
     return;
   }
 
-  std::optional<double> corruption_score;
-  if (corruption_score_calculator_ &&
-      frame_info->frame_instrumentation_data.has_value()) {
-    if (const FrameInstrumentationData* data =
-            std::get_if<FrameInstrumentationData>(
-                &*frame_info->frame_instrumentation_data)) {
-      corruption_score = corruption_score_calculator_->CalculateCorruptionScore(
-          decodedImage, *data);
-    }
-  }
-
   decodedImage.set_ntp_time_ms(frame_info->ntp_time_ms);
   decodedImage.set_packet_infos(frame_info->packet_infos);
   decodedImage.set_rotation(frame_info->rotation);
@@ -255,8 +244,17 @@
                                      .qp = qp,
                                      .decode_time = decode_time,
                                      .content_type = frame_info->content_type,
-                                     .frame_type = frame_info->frame_type,
-                                     .corruption_score = corruption_score});
+                                     .frame_type = frame_info->frame_type});
+
+  if (corruption_score_calculator_ &&
+      frame_info->frame_instrumentation_data.has_value()) {
+    if (const FrameInstrumentationData* data =
+            std::get_if<FrameInstrumentationData>(
+                &*frame_info->frame_instrumentation_data)) {
+      corruption_score_calculator_->CalculateCorruptionScore(
+          decodedImage, *data, frame_info->content_type);
+    }
+  }
 }
 
 void VCMDecodedFrameCallback::OnDecoderInfoChanged(
diff --git a/modules/video_coding/generic_decoder_unittest.cc b/modules/video_coding/generic_decoder_unittest.cc
index e75c040..944480c 100644
--- a/modules/video_coding/generic_decoder_unittest.cc
+++ b/modules/video_coding/generic_decoder_unittest.cc
@@ -41,6 +41,9 @@
 #include "test/gtest.h"
 #include "test/time_controller/simulated_time_controller.h"
 
+using ::testing::Eq;
+using ::testing::Field;
+using ::testing::Property;
 using ::testing::Return;
 
 namespace webrtc {
@@ -48,10 +51,11 @@
 
 class MockCorruptionScoreCalculator : public CorruptionScoreCalculator {
  public:
-  MOCK_METHOD(std::optional<double>,
+  MOCK_METHOD(void,
               CalculateCorruptionScore,
               (const VideoFrame& frame,
-               const FrameInstrumentationData& frame_instrumentation_data),
+               const FrameInstrumentationData& frame_instrumentation_data,
+               VideoContentType content_type),
               (override));
 };
 
@@ -59,7 +63,6 @@
  public:
   int32_t OnFrameToRender(const FrameToRender& arguments) override {
     frames_.push_back(arguments.video_frame);
-    last_corruption_score_ = arguments.corruption_score;
     return 0;
   }
 
@@ -79,14 +82,9 @@
 
   uint32_t frames_dropped() const { return frames_dropped_; }
 
-  std::optional<double> last_corruption_score() const {
-    return last_corruption_score_;
-  }
-
  private:
   std::vector<VideoFrame> frames_;
   uint32_t frames_dropped_ = 0;
-  std::optional<double> last_corruption_score_;
 };
 
 class GenericDecoderTest : public ::testing::Test {
@@ -220,17 +218,13 @@
 }
 
 TEST_F(GenericDecoderTest, CallCalculateCorruptionScoreInDecoded) {
-  constexpr double kCorruptionScore = 0.76;
-
-  EXPECT_CALL(corruption_score_calculator_, CalculateCorruptionScore)
-      .WillOnce(Return(kCorruptionScore));
-
   constexpr uint32_t kRtpTimestamp = 1;
   FrameInfo frame_info;
-  frame_info.frame_instrumentation_data = FrameInstrumentationData{};
+  frame_info.frame_instrumentation_data =
+      FrameInstrumentationData{.sequence_index = 1};
   frame_info.rtp_timestamp = kRtpTimestamp;
   frame_info.decode_start = Timestamp::Zero();
-  frame_info.content_type = VideoContentType::UNSPECIFIED;
+  frame_info.content_type = VideoContentType::SCREENSHARE;
   frame_info.frame_type = VideoFrameType::kVideoFrameDelta;
   VideoFrame video_frame = VideoFrame::Builder()
                                .set_video_frame_buffer(I420Buffer::Create(5, 5))
@@ -238,9 +232,12 @@
                                .build();
   vcm_callback_.Map(std::move(frame_info));
 
+  EXPECT_CALL(corruption_score_calculator_,
+              CalculateCorruptionScore(
+                  Property(&VideoFrame::rtp_timestamp, Eq(kRtpTimestamp)),
+                  Field(&FrameInstrumentationData::sequence_index, Eq(1)),
+                  VideoContentType::SCREENSHARE));
   vcm_callback_.Decoded(video_frame);
-
-  EXPECT_EQ(user_callback_.last_corruption_score(), kCorruptionScore);
 }
 
 }  // namespace video_coding
diff --git a/modules/video_coding/include/video_coding_defines.h b/modules/video_coding/include/video_coding_defines.h
index 1a81d29..88c16ca 100644
--- a/modules/video_coding/include/video_coding_defines.h
+++ b/modules/video_coding/include/video_coding_defines.h
@@ -58,7 +58,6 @@
     TimeDelta decode_time;
     VideoContentType content_type;
     VideoFrameType frame_type;
-    std::optional<double> corruption_score;
   };
 
   virtual int32_t OnFrameToRender(const FrameToRender& arguments) = 0;
diff --git a/video/corruption_detection/BUILD.gn b/video/corruption_detection/BUILD.gn
index 6a07547..490f3ca 100644
--- a/video/corruption_detection/BUILD.gn
+++ b/video/corruption_detection/BUILD.gn
@@ -32,6 +32,7 @@
     "../../api:array_view",
     "../../api:scoped_refptr",
     "../../api/video:video_frame",
+    "../../api/video:video_rtp_headers",
     "../../common_video:frame_instrumentation_data",
     "../../rtc_base:checks",
     "../../rtc_base:logging",
@@ -162,6 +163,7 @@
       ":frame_instrumentation_evaluation",
       "../../api:scoped_refptr",
       "../../api/video:video_frame",
+      "../../api/video:video_rtp_headers",
       "../../common_video:frame_instrumentation_data",
       "../../test:test_support",
     ]
diff --git a/video/corruption_detection/frame_instrumentation_evaluation.cc b/video/corruption_detection/frame_instrumentation_evaluation.cc
index 8029a52..61a14fe 100644
--- a/video/corruption_detection/frame_instrumentation_evaluation.cc
+++ b/video/corruption_detection/frame_instrumentation_evaluation.cc
@@ -11,15 +11,14 @@
 #include "video/corruption_detection/frame_instrumentation_evaluation.h"
 
 #include <cstddef>
-#include <optional>
 #include <vector>
 
 #include "api/array_view.h"
+#include "api/video/video_content_type.h"
 #include "api/video/video_frame.h"
 #include "common_video/frame_instrumentation_data.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
-#include "video/corruption_detection/corruption_classifier.h"
 #include "video/corruption_detection/halton_frame_sampler.h"
 
 namespace webrtc {
@@ -41,42 +40,49 @@
 
 }  // namespace
 
-std::optional<double> GetCorruptionScore(const FrameInstrumentationData& data,
-                                         const VideoFrame& frame) {
+FrameInstrumentationEvaluation::FrameInstrumentationEvaluation(
+    CorruptionScoreObserver* observer)
+    : observer_(observer), classifier_(/*scale_factor=*/3) {
+  RTC_CHECK(observer);
+}
+
+void FrameInstrumentationEvaluation::OnInstrumentedFrame(
+    const FrameInstrumentationData& data,
+    const VideoFrame& frame,
+    VideoContentType content_type) {
   if (data.sample_values.empty()) {
     RTC_LOG(LS_WARNING)
         << "Samples are needed to calculate a corruption score.";
-    return std::nullopt;
+    return;
   }
 
-  HaltonFrameSampler frame_sampler;
-  frame_sampler.SetCurrentIndex(data.sequence_index);
+  frame_sampler_.SetCurrentIndex(data.sequence_index);
   std::vector<HaltonFrameSampler::Coordinates> sample_coordinates =
-      frame_sampler.GetSampleCoordinatesForFrame(data.sample_values.size());
+      frame_sampler_.GetSampleCoordinatesForFrame(data.sample_values.size());
   if (sample_coordinates.empty()) {
     RTC_LOG(LS_ERROR) << "Failed to get sample coordinates for frame.";
-    return std::nullopt;
+    return;
   }
 
   std::vector<FilteredSample> samples = GetSampleValuesForFrame(
       frame, sample_coordinates, frame.width(), frame.height(), data.std_dev);
   if (samples.empty()) {
     RTC_LOG(LS_ERROR) << "Failed to get sample values for frame";
-    return std::nullopt;
+    return;
   }
 
   std::vector<FilteredSample> data_samples =
       ConvertSampleValuesToFilteredSamples(data.sample_values, samples);
   if (data_samples.empty()) {
     RTC_LOG(LS_ERROR) << "Failed to convert sample values to filtered samples";
-    return std::nullopt;
+    return;
   }
 
-  CorruptionClassifier classifier(3);
+  double score = classifier_.CalculateCorruptionProbability(
+      data_samples, samples, data.luma_error_threshold,
+      data.chroma_error_threshold);
 
-  return classifier.CalculateCorruptionProbability(data_samples, samples,
-                                                   data.luma_error_threshold,
-                                                   data.chroma_error_threshold);
+  observer_->OnCorruptionScore(score, content_type);
 }
 
 }  // namespace webrtc
diff --git a/video/corruption_detection/frame_instrumentation_evaluation.h b/video/corruption_detection/frame_instrumentation_evaluation.h
index 8bd3e1c..767eeff 100644
--- a/video/corruption_detection/frame_instrumentation_evaluation.h
+++ b/video/corruption_detection/frame_instrumentation_evaluation.h
@@ -11,15 +11,39 @@
 #ifndef VIDEO_CORRUPTION_DETECTION_FRAME_INSTRUMENTATION_EVALUATION_H_
 #define VIDEO_CORRUPTION_DETECTION_FRAME_INSTRUMENTATION_EVALUATION_H_
 
-#include <optional>
-
+#include "api/video/video_content_type.h"
 #include "api/video/video_frame.h"
 #include "common_video/frame_instrumentation_data.h"
+#include "video/corruption_detection/corruption_classifier.h"
+#include "video/corruption_detection/halton_frame_sampler.h"
 
 namespace webrtc {
 
-std::optional<double> GetCorruptionScore(const FrameInstrumentationData& data,
-                                         const VideoFrame& frame);
+class CorruptionScoreObserver {
+ public:
+  CorruptionScoreObserver() = default;
+  virtual ~CorruptionScoreObserver() = default;
+
+  // Results of corruption detection for a single frame, with a likelihood score
+  // in the range [0.0, 1.0].
+  virtual void OnCorruptionScore(double corruption_score,
+                                 VideoContentType content_type) = 0;
+};
+
+class FrameInstrumentationEvaluation {
+ public:
+  explicit FrameInstrumentationEvaluation(CorruptionScoreObserver* observer);
+
+  void OnInstrumentedFrame(const FrameInstrumentationData& data,
+                           const VideoFrame& frame,
+                           VideoContentType frame_type);
+
+ private:
+  CorruptionScoreObserver* const observer_;
+
+  HaltonFrameSampler frame_sampler_;
+  CorruptionClassifier classifier_;
+};
 
 }  // namespace webrtc
 
diff --git a/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc b/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc
index d82068d..983a757 100644
--- a/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc
+++ b/video/corruption_detection/frame_instrumentation_evaluation_unittest.cc
@@ -11,18 +11,29 @@
 #include "video/corruption_detection/frame_instrumentation_evaluation.h"
 
 #include <cstdint>
-#include <optional>
 #include <vector>
 
 #include "api/scoped_refptr.h"
 #include "api/video/i420_buffer.h"
+#include "api/video/video_content_type.h"
 #include "api/video/video_frame.h"
 #include "common_video/frame_instrumentation_data.h"
+#include "test/gmock.h"
 #include "test/gtest.h"
 
 namespace webrtc {
 namespace {
 
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::Ge;
+using ::testing::Le;
+
+class MockCorruptionScoreObserver : public CorruptionScoreObserver {
+ public:
+  MOCK_METHOD(void, OnCorruptionScore, (double, VideoContentType), (override));
+};
+
 scoped_refptr<I420Buffer> MakeI420FrameBufferWithDifferentPixelValues() {
   // Create an I420 frame of size 4x4.
   const int kDefaultLumaWidth = 4;
@@ -52,9 +63,10 @@
           .set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
           .build();
 
-  std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
-  EXPECT_FALSE(corruption_score.has_value());
+  MockCorruptionScoreObserver observer;
+  FrameInstrumentationEvaluation evaluator(&observer);
+  EXPECT_CALL(observer, OnCorruptionScore).Times(0);
+  evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
 }
 
 TEST(FrameInstrumentationEvaluationTest,
@@ -71,10 +83,10 @@
           .set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
           .build();
 
-  std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
-  ASSERT_TRUE(corruption_score.has_value());
-  EXPECT_DOUBLE_EQ(*corruption_score, 1.0);
+  MockCorruptionScoreObserver observer;
+  FrameInstrumentationEvaluation evaluator(&observer);
+  EXPECT_CALL(observer, OnCorruptionScore(1.0, VideoContentType::SCREENSHARE));
+  evaluator.OnInstrumentedFrame(data, frame, VideoContentType::SCREENSHARE);
 }
 
 TEST(FrameInstrumentationEvaluationTest,
@@ -91,11 +103,10 @@
           .set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
           .build();
 
-  std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
-  ASSERT_TRUE(corruption_score.has_value());
-  EXPECT_LE(*corruption_score, 1);
-  EXPECT_GE(*corruption_score, 0);
+  MockCorruptionScoreObserver observer;
+  FrameInstrumentationEvaluation evaluator(&observer);
+  EXPECT_CALL(observer, OnCorruptionScore(AllOf(Ge(0.0), Le(1.0)), _));
+  evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
 }
 
 TEST(FrameInstrumentationEvaluationTest,
@@ -114,11 +125,10 @@
           .set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
           .build();
 
-  std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
-  ASSERT_TRUE(corruption_score.has_value());
-  EXPECT_LE(*corruption_score, 1);
-  EXPECT_GE(*corruption_score, 0);
+  MockCorruptionScoreObserver observer;
+  FrameInstrumentationEvaluation evaluator(&observer);
+  EXPECT_CALL(observer, OnCorruptionScore(AllOf(Ge(0.0), Le(1.0)), _));
+  evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
 }
 
 TEST(FrameInstrumentationEvaluationTest, ApplySequenceIndexWhenProvided) {
@@ -136,11 +146,10 @@
           .set_video_frame_buffer(MakeI420FrameBufferWithDifferentPixelValues())
           .build();
 
-  std::optional<double> corruption_score = GetCorruptionScore(data, frame);
-
-  ASSERT_TRUE(corruption_score.has_value());
-  EXPECT_LE(*corruption_score, 1);
-  EXPECT_GE(*corruption_score, 0);
+  MockCorruptionScoreObserver observer;
+  FrameInstrumentationEvaluation evaluator(&observer);
+  EXPECT_CALL(observer, OnCorruptionScore(AllOf(Ge(0.0), Le(1.0)), _));
+  evaluator.OnInstrumentedFrame(data, frame, VideoContentType::UNSPECIFIED);
 }
 
 }  // namespace
diff --git a/video/receive_statistics_proxy.h b/video/receive_statistics_proxy.h
index 023f4fa..2fb819a 100644
--- a/video/receive_statistics_proxy.h
+++ b/video/receive_statistics_proxy.h
@@ -40,6 +40,7 @@
 #include "rtc_base/rate_tracker.h"
 #include "rtc_base/system/no_unique_address.h"
 #include "rtc_base/thread_annotations.h"
+#include "video/corruption_detection/frame_instrumentation_evaluation.h"
 #include "video/stats_counter.h"
 #include "video/video_quality_observer2.h"
 #include "video/video_stream_buffer_controller.h"
@@ -55,7 +56,8 @@
 
 class ReceiveStatisticsProxy : public VideoStreamBufferControllerStatsObserver,
                                public RtcpCnameCallback,
-                               public RtcpPacketTypeCounterObserver {
+                               public RtcpPacketTypeCounterObserver,
+                               public CorruptionScoreObserver {
  public:
   ReceiveStatisticsProxy(uint32_t remote_ssrc,
                          Clock* clock,
@@ -115,7 +117,7 @@
   void OnCname(uint32_t ssrc, absl::string_view cname) override;
 
   void OnCorruptionScore(double corruption_score,
-                         VideoContentType content_type);
+                         VideoContentType content_type) override;
 
   // Implements RtcpPacketTypeCounterObserver.
   void RtcpPacketTypesCounterUpdated(
diff --git a/video/video_receive_stream2.cc b/video/video_receive_stream2.cc
index 2e8bc0b..4660da8 100644
--- a/video/video_receive_stream2.cc
+++ b/video/video_receive_stream2.cc
@@ -48,6 +48,7 @@
 #include "api/video/recordable_encoded_frame.h"
 #include "api/video/render_resolution.h"
 #include "api/video/video_codec_type.h"
+#include "api/video/video_content_type.h"
 #include "api/video/video_frame.h"
 #include "api/video/video_frame_type.h"
 #include "api/video/video_rotation.h"
@@ -266,6 +267,7 @@
       max_wait_for_frame_(DetermineMaxWaitForFrame(
           TimeDelta::Millis(config_.rtp.nack.rtp_history_ms),
           false)),
+      frame_evaluator_(&stats_proxy_),
       decode_queue_(env_.task_queue_factory().CreateTaskQueue(
           "DecodingQueue",
           TaskQueueFactory::Priority::HIGH)) {
@@ -651,10 +653,13 @@
   stats_proxy_.UpdateHistograms(fraction_lost, rtp_stats, nullptr);
 }
 
-std::optional<double> VideoReceiveStream2::CalculateCorruptionScore(
+void VideoReceiveStream2::CalculateCorruptionScore(
     const VideoFrame& frame,
-    const FrameInstrumentationData& frame_instrumentation_data) {
-  return GetCorruptionScore(frame_instrumentation_data, frame);
+    const FrameInstrumentationData& frame_instrumentation_data,
+    VideoContentType content_type) {
+  RTC_DCHECK_RUN_ON(&decode_sequence_checker_);
+  frame_evaluator_.OnInstrumentedFrame(frame_instrumentation_data, frame,
+                                       content_type);
 }
 
 bool VideoReceiveStream2::SetBaseMinimumPlayoutDelayMs(int delay_ms) {
diff --git a/video/video_receive_stream2.h b/video/video_receive_stream2.h
index d6d27a5..135e4c8 100644
--- a/video/video_receive_stream2.h
+++ b/video/video_receive_stream2.h
@@ -34,6 +34,7 @@
 #include "api/units/timestamp.h"
 #include "api/video/encoded_frame.h"
 #include "api/video/recordable_encoded_frame.h"
+#include "api/video/video_content_type.h"
 #include "api/video/video_frame.h"
 #include "api/video/video_sink_interface.h"
 #include "call/call.h"
@@ -50,6 +51,7 @@
 #include "rtc_base/system/no_unique_address.h"
 #include "rtc_base/thread_annotations.h"
 #include "rtc_base/time_utils.h"
+#include "video/corruption_detection/frame_instrumentation_evaluation.h"
 #include "video/decode_synchronizer.h"
 #include "video/receive_statistics_proxy.h"
 #include "video/rtp_streams_synchronizer2.h"
@@ -258,9 +260,10 @@
       RTC_RUN_ON(decode_sequence_checker_);
 
   void UpdateHistograms();
-  std::optional<double> CalculateCorruptionScore(
+  void CalculateCorruptionScore(
       const VideoFrame& frame,
-      const FrameInstrumentationData& frame_instrumentation_data) override;
+      const FrameInstrumentationData& frame_instrumentation_data,
+      VideoContentType content_type) override;
 
   const Environment env_;
 
@@ -364,6 +367,9 @@
   std::vector<std::unique_ptr<EncodedFrame>> buffered_encoded_frames_
       RTC_GUARDED_BY(decode_sequence_checker_);
 
+  FrameInstrumentationEvaluation frame_evaluator_
+      RTC_GUARDED_BY(decode_sequence_checker_);
+
   // Used to signal destruction to potentially pending tasks.
   ScopedTaskSafety task_safety_;
 
diff --git a/video/video_stream_decoder2.cc b/video/video_stream_decoder2.cc
index 02b3db8..45379a0 100644
--- a/video/video_stream_decoder2.cc
+++ b/video/video_stream_decoder2.cc
@@ -53,10 +53,6 @@
   receive_stats_callback_->OnDecodedFrame(
       arguments.video_frame, arguments.qp, arguments.decode_time,
       arguments.content_type, arguments.frame_type);
-  if (arguments.corruption_score.has_value()) {
-    receive_stats_callback_->OnCorruptionScore(*arguments.corruption_score,
-                                               arguments.content_type);
-  }
   incoming_video_stream_->OnFrame(arguments.video_frame);
   return 0;
 }