Add SetKeyFrameRequestCallback to MediaTransportInterface

And implemented in LoopbackMediaTransport.

Bug: webrtc:9719
Change-Id: I68b16c2b6ed5583ffe9a5266e3d4cb1d94afbb97
Reviewed-on: https://webrtc-review.googlesource.com/c/113523
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Commit-Queue: Niels Moller <nisse@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#25948}
diff --git a/api/media_transport_interface.cc b/api/media_transport_interface.cc
index 6b00f2c..b201eae 100644
--- a/api/media_transport_interface.cc
+++ b/api/media_transport_interface.cc
@@ -158,6 +158,9 @@
   return std::unique_ptr<MediaTransportInterface>(nullptr);
 }
 
+void MediaTransportInterface::SetKeyFrameRequestCallback(
+    MediaTransportKeyFrameRequestCallback* callback) {}
+
 absl::optional<TargetTransferRate>
 MediaTransportInterface::GetLatestTargetTransferRate() {
   return absl::nullopt;
diff --git a/api/media_transport_interface.h b/api/media_transport_interface.h
index 6b461cf..6f2ec60 100644
--- a/api/media_transport_interface.h
+++ b/api/media_transport_interface.h
@@ -239,6 +239,15 @@
   RTC_DEPRECATED virtual void OnKeyFrameRequested(uint64_t channel_id) {}
 };
 
+// Interface for video sender to be notified of received key frame request.
+class MediaTransportKeyFrameRequestCallback {
+ public:
+  virtual ~MediaTransportKeyFrameRequestCallback() = default;
+
+  // Called when a key frame request is received on the transport.
+  virtual void OnKeyFrameRequested(uint64_t channel_id) = 0;
+};
+
 // State of the media transport.  Media transport begins in the pending state.
 // It transitions to writable when it is ready to send media.  It may transition
 // back to pending if the connection is blocked.  It may transition to closed at
@@ -339,6 +348,10 @@
       uint64_t channel_id,
       const MediaTransportEncodedVideoFrame& frame) = 0;
 
+  // Used by video sender to be notified on key frame requests.
+  virtual void SetKeyFrameRequestCallback(
+      MediaTransportKeyFrameRequestCallback* callback);
+
   // Requests a keyframe for the particular channel (stream). The caller should
   // check that the keyframe is not present in a jitter buffer already (i.e.
   // don't request a keyframe if there is one that you will get from the jitter
diff --git a/api/test/loopback_media_transport.cc b/api/test/loopback_media_transport.cc
index 6094813..8fbaea5 100644
--- a/api/test/loopback_media_transport.cc
+++ b/api/test/loopback_media_transport.cc
@@ -34,6 +34,11 @@
     return wrapped_->SendVideoFrame(channel_id, frame);
   }
 
+  void SetKeyFrameRequestCallback(
+      MediaTransportKeyFrameRequestCallback* callback) override {
+    wrapped_->SetKeyFrameRequestCallback(callback);
+  }
+
   RTCError RequestKeyFrame(uint64_t channel_id) override {
     return wrapped_->RequestKeyFrame(channel_id);
   }
@@ -125,8 +130,20 @@
   return RTCError::OK();
 }
 
+void MediaTransportPair::LoopbackMediaTransport::SetKeyFrameRequestCallback(
+    MediaTransportKeyFrameRequestCallback* callback) {
+  rtc::CritScope lock(&sink_lock_);
+  if (callback) {
+    RTC_CHECK(key_frame_callback_ == nullptr);
+  }
+  key_frame_callback_ = callback;
+}
+
 RTCError MediaTransportPair::LoopbackMediaTransport::RequestKeyFrame(
     uint64_t channel_id) {
+  invoker_.AsyncInvoke<void>(RTC_FROM_HERE, thread_, [this, channel_id] {
+    other_->OnKeyFrameRequested(channel_id);
+  });
   return RTCError::OK();
 }
 
@@ -245,6 +262,14 @@
   }
 }
 
+void MediaTransportPair::LoopbackMediaTransport::OnKeyFrameRequested(
+    int channel_id) {
+  rtc::CritScope lock(&sink_lock_);
+  if (key_frame_callback_) {
+    key_frame_callback_->OnKeyFrameRequested(channel_id);
+  }
+}
+
 void MediaTransportPair::LoopbackMediaTransport::OnRemoteCloseChannel(
     int channel_id) {
   rtc::CritScope lock(&sink_lock_);
diff --git a/api/test/loopback_media_transport.h b/api/test/loopback_media_transport.h
index d520c20..e0f784e 100644
--- a/api/test/loopback_media_transport.h
+++ b/api/test/loopback_media_transport.h
@@ -90,6 +90,9 @@
         uint64_t channel_id,
         const MediaTransportEncodedVideoFrame& frame) override;
 
+    void SetKeyFrameRequestCallback(
+        MediaTransportKeyFrameRequestCallback* callback) override;
+
     RTCError RequestKeyFrame(uint64_t channel_id) override;
 
     void SetReceiveAudioSink(MediaTransportAudioSinkInterface* sink) override;
@@ -122,6 +125,8 @@
                 DataMessageType type,
                 const rtc::CopyOnWriteBuffer& buffer);
 
+    void OnKeyFrameRequested(int channel_id);
+
     void OnRemoteCloseChannel(int channel_id);
 
     void OnStateChanged() RTC_RUN_ON(thread_);
@@ -135,6 +140,10 @@
     MediaTransportVideoSinkInterface* video_sink_ RTC_GUARDED_BY(sink_lock_) =
         nullptr;
     DataChannelSink* data_sink_ RTC_GUARDED_BY(sink_lock_) = nullptr;
+
+    MediaTransportKeyFrameRequestCallback* key_frame_callback_
+        RTC_GUARDED_BY(sink_lock_) = nullptr;
+
     MediaTransportStateCallback* state_callback_ RTC_GUARDED_BY(sink_lock_) =
         nullptr;
 
diff --git a/api/test/loopback_media_transport_unittest.cc b/api/test/loopback_media_transport_unittest.cc
index a1b13ec..c67a9d5 100644
--- a/api/test/loopback_media_transport_unittest.cc
+++ b/api/test/loopback_media_transport_unittest.cc
@@ -29,6 +29,11 @@
     : public MediaTransportVideoSinkInterface {
  public:
   MOCK_METHOD2(OnData, void(uint64_t, MediaTransportEncodedVideoFrame));
+};
+
+class MockMediaTransportKeyFrameRequestCallback
+    : public MediaTransportKeyFrameRequestCallback {
+ public:
   MOCK_METHOD1(OnKeyFrameRequested, void(uint64_t));
 };
 
@@ -125,6 +130,26 @@
   transport_pair.second()->SetReceiveVideoSink(nullptr);
 }
 
+TEST(LoopbackMediaTransport, VideoKeyFrameRequestDeliveredToCallback) {
+  std::unique_ptr<rtc::Thread> thread = rtc::Thread::Create();
+  thread->Start();
+  MediaTransportPair transport_pair(thread.get());
+  testing::StrictMock<MockMediaTransportKeyFrameRequestCallback> callback1;
+  testing::StrictMock<MockMediaTransportKeyFrameRequestCallback> callback2;
+  const uint64_t kFirstChannelId = 1111;
+  const uint64_t kSecondChannelId = 2222;
+
+  EXPECT_CALL(callback1, OnKeyFrameRequested(kSecondChannelId));
+  EXPECT_CALL(callback2, OnKeyFrameRequested(kFirstChannelId));
+  transport_pair.first()->SetKeyFrameRequestCallback(&callback1);
+  transport_pair.second()->SetKeyFrameRequestCallback(&callback2);
+
+  transport_pair.first()->RequestKeyFrame(kFirstChannelId);
+  transport_pair.second()->RequestKeyFrame(kSecondChannelId);
+
+  transport_pair.FlushAsyncInvokes();
+}
+
 TEST(LoopbackMediaTransport, DataDeliveredToSink) {
   std::unique_ptr<rtc::Thread> thread = rtc::Thread::Create();
   thread->Start();