Add SetKeyFrameRequestCallback to MediaTransportInterface
And implemented in LoopbackMediaTransport.
Bug: webrtc:9719
Change-Id: I68b16c2b6ed5583ffe9a5266e3d4cb1d94afbb97
Reviewed-on: https://webrtc-review.googlesource.com/c/113523
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Commit-Queue: Niels Moller <nisse@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#25948}
diff --git a/api/media_transport_interface.cc b/api/media_transport_interface.cc
index 6b00f2c..b201eae 100644
--- a/api/media_transport_interface.cc
+++ b/api/media_transport_interface.cc
@@ -158,6 +158,9 @@
return std::unique_ptr<MediaTransportInterface>(nullptr);
}
+void MediaTransportInterface::SetKeyFrameRequestCallback(
+ MediaTransportKeyFrameRequestCallback* callback) {}
+
absl::optional<TargetTransferRate>
MediaTransportInterface::GetLatestTargetTransferRate() {
return absl::nullopt;
diff --git a/api/media_transport_interface.h b/api/media_transport_interface.h
index 6b461cf..6f2ec60 100644
--- a/api/media_transport_interface.h
+++ b/api/media_transport_interface.h
@@ -239,6 +239,15 @@
RTC_DEPRECATED virtual void OnKeyFrameRequested(uint64_t channel_id) {}
};
+// Interface for video sender to be notified of received key frame request.
+class MediaTransportKeyFrameRequestCallback {
+ public:
+ virtual ~MediaTransportKeyFrameRequestCallback() = default;
+
+ // Called when a key frame request is received on the transport.
+ virtual void OnKeyFrameRequested(uint64_t channel_id) = 0;
+};
+
// State of the media transport. Media transport begins in the pending state.
// It transitions to writable when it is ready to send media. It may transition
// back to pending if the connection is blocked. It may transition to closed at
@@ -339,6 +348,10 @@
uint64_t channel_id,
const MediaTransportEncodedVideoFrame& frame) = 0;
+ // Used by video sender to be notified on key frame requests.
+ virtual void SetKeyFrameRequestCallback(
+ MediaTransportKeyFrameRequestCallback* callback);
+
// Requests a keyframe for the particular channel (stream). The caller should
// check that the keyframe is not present in a jitter buffer already (i.e.
// don't request a keyframe if there is one that you will get from the jitter
diff --git a/api/test/loopback_media_transport.cc b/api/test/loopback_media_transport.cc
index 6094813..8fbaea5 100644
--- a/api/test/loopback_media_transport.cc
+++ b/api/test/loopback_media_transport.cc
@@ -34,6 +34,11 @@
return wrapped_->SendVideoFrame(channel_id, frame);
}
+ void SetKeyFrameRequestCallback(
+ MediaTransportKeyFrameRequestCallback* callback) override {
+ wrapped_->SetKeyFrameRequestCallback(callback);
+ }
+
RTCError RequestKeyFrame(uint64_t channel_id) override {
return wrapped_->RequestKeyFrame(channel_id);
}
@@ -125,8 +130,20 @@
return RTCError::OK();
}
+void MediaTransportPair::LoopbackMediaTransport::SetKeyFrameRequestCallback(
+ MediaTransportKeyFrameRequestCallback* callback) {
+ rtc::CritScope lock(&sink_lock_);
+ if (callback) {
+ RTC_CHECK(key_frame_callback_ == nullptr);
+ }
+ key_frame_callback_ = callback;
+}
+
RTCError MediaTransportPair::LoopbackMediaTransport::RequestKeyFrame(
uint64_t channel_id) {
+ invoker_.AsyncInvoke<void>(RTC_FROM_HERE, thread_, [this, channel_id] {
+ other_->OnKeyFrameRequested(channel_id);
+ });
return RTCError::OK();
}
@@ -245,6 +262,14 @@
}
}
+void MediaTransportPair::LoopbackMediaTransport::OnKeyFrameRequested(
+ int channel_id) {
+ rtc::CritScope lock(&sink_lock_);
+ if (key_frame_callback_) {
+ key_frame_callback_->OnKeyFrameRequested(channel_id);
+ }
+}
+
void MediaTransportPair::LoopbackMediaTransport::OnRemoteCloseChannel(
int channel_id) {
rtc::CritScope lock(&sink_lock_);
diff --git a/api/test/loopback_media_transport.h b/api/test/loopback_media_transport.h
index d520c20..e0f784e 100644
--- a/api/test/loopback_media_transport.h
+++ b/api/test/loopback_media_transport.h
@@ -90,6 +90,9 @@
uint64_t channel_id,
const MediaTransportEncodedVideoFrame& frame) override;
+ void SetKeyFrameRequestCallback(
+ MediaTransportKeyFrameRequestCallback* callback) override;
+
RTCError RequestKeyFrame(uint64_t channel_id) override;
void SetReceiveAudioSink(MediaTransportAudioSinkInterface* sink) override;
@@ -122,6 +125,8 @@
DataMessageType type,
const rtc::CopyOnWriteBuffer& buffer);
+ void OnKeyFrameRequested(int channel_id);
+
void OnRemoteCloseChannel(int channel_id);
void OnStateChanged() RTC_RUN_ON(thread_);
@@ -135,6 +140,10 @@
MediaTransportVideoSinkInterface* video_sink_ RTC_GUARDED_BY(sink_lock_) =
nullptr;
DataChannelSink* data_sink_ RTC_GUARDED_BY(sink_lock_) = nullptr;
+
+ MediaTransportKeyFrameRequestCallback* key_frame_callback_
+ RTC_GUARDED_BY(sink_lock_) = nullptr;
+
MediaTransportStateCallback* state_callback_ RTC_GUARDED_BY(sink_lock_) =
nullptr;
diff --git a/api/test/loopback_media_transport_unittest.cc b/api/test/loopback_media_transport_unittest.cc
index a1b13ec..c67a9d5 100644
--- a/api/test/loopback_media_transport_unittest.cc
+++ b/api/test/loopback_media_transport_unittest.cc
@@ -29,6 +29,11 @@
: public MediaTransportVideoSinkInterface {
public:
MOCK_METHOD2(OnData, void(uint64_t, MediaTransportEncodedVideoFrame));
+};
+
+class MockMediaTransportKeyFrameRequestCallback
+ : public MediaTransportKeyFrameRequestCallback {
+ public:
MOCK_METHOD1(OnKeyFrameRequested, void(uint64_t));
};
@@ -125,6 +130,26 @@
transport_pair.second()->SetReceiveVideoSink(nullptr);
}
+TEST(LoopbackMediaTransport, VideoKeyFrameRequestDeliveredToCallback) {
+ std::unique_ptr<rtc::Thread> thread = rtc::Thread::Create();
+ thread->Start();
+ MediaTransportPair transport_pair(thread.get());
+ testing::StrictMock<MockMediaTransportKeyFrameRequestCallback> callback1;
+ testing::StrictMock<MockMediaTransportKeyFrameRequestCallback> callback2;
+ const uint64_t kFirstChannelId = 1111;
+ const uint64_t kSecondChannelId = 2222;
+
+ EXPECT_CALL(callback1, OnKeyFrameRequested(kSecondChannelId));
+ EXPECT_CALL(callback2, OnKeyFrameRequested(kFirstChannelId));
+ transport_pair.first()->SetKeyFrameRequestCallback(&callback1);
+ transport_pair.second()->SetKeyFrameRequestCallback(&callback2);
+
+ transport_pair.first()->RequestKeyFrame(kFirstChannelId);
+ transport_pair.second()->RequestKeyFrame(kSecondChannelId);
+
+ transport_pair.FlushAsyncInvokes();
+}
+
TEST(LoopbackMediaTransport, DataDeliveredToSink) {
std::unique_ptr<rtc::Thread> thread = rtc::Thread::Create();
thread->Start();