Add video support to LoopbackMediaTransport

Bug: webrtc:9719
Change-Id: I568da8720377342cf44ee8caa316e14b4cd8beba
Reviewed-on: https://webrtc-review.googlesource.com/c/111960
Commit-Queue: Niels Moller <nisse@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#25826}
diff --git a/api/media_transport_interface.cc b/api/media_transport_interface.cc
index ef223aa..8999e08 100644
--- a/api/media_transport_interface.cc
+++ b/api/media_transport_interface.cc
@@ -59,7 +59,9 @@
 MediaTransportEncodedAudioFrame::MediaTransportEncodedAudioFrame(
     MediaTransportEncodedAudioFrame&&) = default;
 
-MediaTransportEncodedVideoFrame::~MediaTransportEncodedVideoFrame() {}
+MediaTransportEncodedVideoFrame::MediaTransportEncodedVideoFrame() = default;
+
+MediaTransportEncodedVideoFrame::~MediaTransportEncodedVideoFrame() = default;
 
 MediaTransportEncodedVideoFrame::MediaTransportEncodedVideoFrame(
     int64_t frame_id,
@@ -72,16 +74,54 @@
       referenced_frame_ids_(std::move(referenced_frame_ids)) {}
 
 MediaTransportEncodedVideoFrame& MediaTransportEncodedVideoFrame::operator=(
-    const MediaTransportEncodedVideoFrame&) = default;
+    const MediaTransportEncodedVideoFrame& o) {
+  codec_type_ = o.codec_type_;
+  encoded_image_ = o.encoded_image_;
+  encoded_data_ = o.encoded_data_;
+  frame_id_ = o.frame_id_;
+  referenced_frame_ids_ = o.referenced_frame_ids_;
+  if (!encoded_data_.empty()) {
+    // We own the underlying data.
+    encoded_image_._buffer = encoded_data_.data();
+  }
+  return *this;
+}
 
 MediaTransportEncodedVideoFrame& MediaTransportEncodedVideoFrame::operator=(
-    MediaTransportEncodedVideoFrame&&) = default;
+    MediaTransportEncodedVideoFrame&& o) {
+  codec_type_ = o.codec_type_;
+  encoded_image_ = o.encoded_image_;
+  encoded_data_ = std::move(o.encoded_data_);
+  frame_id_ = o.frame_id_;
+  referenced_frame_ids_ = std::move(o.referenced_frame_ids_);
+  if (!encoded_data_.empty()) {
+    // We take over ownership of the underlying data.
+    encoded_image_._buffer = encoded_data_.data();
+    o.encoded_image_._buffer = nullptr;
+  }
+  return *this;
+}
 
 MediaTransportEncodedVideoFrame::MediaTransportEncodedVideoFrame(
-    const MediaTransportEncodedVideoFrame&) = default;
+    const MediaTransportEncodedVideoFrame& o)
+    : MediaTransportEncodedVideoFrame() {
+  *this = o;
+}
 
 MediaTransportEncodedVideoFrame::MediaTransportEncodedVideoFrame(
-    MediaTransportEncodedVideoFrame&&) = default;
+    MediaTransportEncodedVideoFrame&& o)
+    : MediaTransportEncodedVideoFrame() {
+  *this = std::move(o);
+}
+
+void MediaTransportEncodedVideoFrame::Retain() {
+  if (encoded_image_._buffer && encoded_data_.empty()) {
+    encoded_data_ =
+        std::vector<uint8_t>(encoded_image_._buffer,
+                             encoded_image_._buffer + encoded_image_._length);
+    encoded_image_._buffer = encoded_data_.data();
+  }
+}
 
 SendDataParams::SendDataParams() = default;
 
diff --git a/api/media_transport_interface.h b/api/media_transport_interface.h
index b10dd63..9dee4f9 100644
--- a/api/media_transport_interface.h
+++ b/api/media_transport_interface.h
@@ -183,14 +183,23 @@
     return referenced_frame_ids_;
   }
 
+  // Hack to workaround lack of ownership of the encoded_image_._buffer. If we
+  // don't already own the underlying data, make a copy.
+  void Retain();
+
  private:
+  MediaTransportEncodedVideoFrame();
+
   VideoCodecType codec_type_;
 
-  // The buffer is not owned by the encoded image by default. On the sender it
-  // means that it will need to make a copy of it if it wants to deliver it
-  // asynchronously.
+  // The buffer is not owned by the encoded image. On the sender it means that
+  // it will need to make a copy using the Retain() method, if it wants to
+  // deliver it asynchronously.
   webrtc::EncodedImage encoded_image_;
 
+  // If non-empty, this is the data for the encoded image.
+  std::vector<uint8_t> encoded_data_;
+
   // Frame id uniquely identifies a frame in a stream. It needs to be unique in
   // a given time window (i.e. technically unique identifier for the lifetime of
   // the connection is not needed, but you need to guarantee that remote side
diff --git a/api/test/loopback_media_transport.h b/api/test/loopback_media_transport.h
index 48255b1..2620789 100644
--- a/api/test/loopback_media_transport.h
+++ b/api/test/loopback_media_transport.h
@@ -98,6 +98,8 @@
   struct Stats {
     int sent_audio_frames = 0;
     int received_audio_frames = 0;
+    int sent_video_frames = 0;
+    int received_video_frames = 0;
   };
 
   explicit MediaTransportPair(rtc::Thread* thread)
@@ -136,7 +138,8 @@
 
     ~LoopbackMediaTransport() {
       rtc::CritScope lock(&sink_lock_);
-      RTC_CHECK(sink_ == nullptr);
+      RTC_CHECK(audio_sink_ == nullptr);
+      RTC_CHECK(video_sink_ == nullptr);
       RTC_CHECK(data_sink_ == nullptr);
     }
 
@@ -156,6 +159,17 @@
     RTCError SendVideoFrame(
         uint64_t channel_id,
         const MediaTransportEncodedVideoFrame& frame) override {
+      {
+        rtc::CritScope lock(&stats_lock_);
+        ++stats_.sent_video_frames;
+      }
+      // Ensure that we own the referenced data.
+      MediaTransportEncodedVideoFrame frame_copy = frame;
+      frame_copy.Retain();
+      invoker_.AsyncInvoke<void>(
+          RTC_FROM_HERE, thread_, [this, channel_id, frame_copy] {
+            other_->OnData(channel_id, std::move(frame_copy));
+          });
       return RTCError::OK();
     }
 
@@ -166,12 +180,18 @@
     void SetReceiveAudioSink(MediaTransportAudioSinkInterface* sink) override {
       rtc::CritScope lock(&sink_lock_);
       if (sink) {
-        RTC_CHECK(sink_ == nullptr);
+        RTC_CHECK(audio_sink_ == nullptr);
       }
-      sink_ = sink;
+      audio_sink_ = sink;
     }
 
-    void SetReceiveVideoSink(MediaTransportVideoSinkInterface* sink) override {}
+    void SetReceiveVideoSink(MediaTransportVideoSinkInterface* sink) override {
+      rtc::CritScope lock(&sink_lock_);
+      if (sink) {
+        RTC_CHECK(video_sink_ == nullptr);
+      }
+      video_sink_ = sink;
+    }
 
     void SetMediaTransportStateCallback(
         MediaTransportStateCallback* callback) override {
@@ -228,8 +248,8 @@
     void OnData(uint64_t channel_id, MediaTransportEncodedAudioFrame frame) {
       {
         rtc::CritScope lock(&sink_lock_);
-        if (sink_) {
-          sink_->OnData(channel_id, frame);
+        if (audio_sink_) {
+          audio_sink_->OnData(channel_id, frame);
         }
       }
       {
@@ -238,6 +258,19 @@
       }
     }
 
+    void OnData(uint64_t channel_id, MediaTransportEncodedVideoFrame frame) {
+      {
+        rtc::CritScope lock(&sink_lock_);
+        if (video_sink_) {
+          video_sink_->OnData(channel_id, frame);
+        }
+      }
+      {
+        rtc::CritScope lock(&stats_lock_);
+        ++stats_.received_video_frames;
+      }
+    }
+
     void OnData(int channel_id,
                 DataMessageType type,
                 const rtc::CopyOnWriteBuffer& buffer) {
@@ -266,7 +299,9 @@
     rtc::CriticalSection sink_lock_;
     rtc::CriticalSection stats_lock_;
 
-    MediaTransportAudioSinkInterface* sink_ RTC_GUARDED_BY(sink_lock_) =
+    MediaTransportAudioSinkInterface* audio_sink_ RTC_GUARDED_BY(sink_lock_) =
+        nullptr;
+    MediaTransportVideoSinkInterface* video_sink_ RTC_GUARDED_BY(sink_lock_) =
         nullptr;
     DataChannelSink* data_sink_ RTC_GUARDED_BY(sink_lock_) = nullptr;
     MediaTransportStateCallback* state_callback_ RTC_GUARDED_BY(sink_lock_) =
diff --git a/api/test/loopback_media_transport_unittest.cc b/api/test/loopback_media_transport_unittest.cc
index ba741a0..ef4bbf4 100644
--- a/api/test/loopback_media_transport_unittest.cc
+++ b/api/test/loopback_media_transport_unittest.cc
@@ -8,6 +8,7 @@
  *  be found in the AUTHORS file in the root of the source tree.
  */
 
+#include <algorithm>
 #include <memory>
 #include <vector>
 
@@ -24,6 +25,13 @@
   MOCK_METHOD2(OnData, void(uint64_t, MediaTransportEncodedAudioFrame));
 };
 
+class MockMediaTransportVideoSinkInterface
+    : public MediaTransportVideoSinkInterface {
+ public:
+  MOCK_METHOD2(OnData, void(uint64_t, MediaTransportEncodedVideoFrame));
+  MOCK_METHOD1(OnKeyFrameRequested, void(uint64_t));
+};
+
 class MockDataChannelSink : public DataChannelSink {
  public:
   MOCK_METHOD3(OnDataReceived,
@@ -50,6 +58,13 @@
       kPayloadType, std::vector<uint8_t>(kSamplesPerChannel));
 }
 
+MediaTransportEncodedVideoFrame CreateVideoFrame(
+    int frame_id,
+    const webrtc::EncodedImage& encoded_image) {
+  return MediaTransportEncodedVideoFrame(frame_id, /*referenced_frame_ids=*/{},
+                                         kVideoCodecVP8, encoded_image);
+}
+
 }  // namespace
 
 TEST(LoopbackMediaTransport, AudioWithNoSinkSilentlyIgnored) {
@@ -77,6 +92,38 @@
   transport_pair.second()->SetReceiveAudioSink(nullptr);
 }
 
+TEST(LoopbackMediaTransport, VideoDeliveredToSink) {
+  std::unique_ptr<rtc::Thread> thread = rtc::Thread::Create();
+  thread->Start();
+  MediaTransportPair transport_pair(thread.get());
+  testing::StrictMock<MockMediaTransportVideoSinkInterface> sink;
+  uint8_t encoded_data[] = {1, 2, 3};
+  EncodedImage encoded_image;
+  encoded_image._buffer = encoded_data;
+  encoded_image._length = sizeof(encoded_data);
+
+  EXPECT_CALL(sink, OnData(1, testing::Property(
+                                  &MediaTransportEncodedVideoFrame::frame_id,
+                                  testing::Eq(10))))
+      .WillOnce(testing::Invoke(
+          [&encoded_image](int frame_id,
+                           const MediaTransportEncodedVideoFrame& frame) {
+            EXPECT_NE(frame.encoded_image()._buffer, encoded_image._buffer);
+            EXPECT_EQ(frame.encoded_image()._length, encoded_image._length);
+            EXPECT_EQ(
+                0, memcmp(frame.encoded_image()._buffer, encoded_image._buffer,
+                          std::min(frame.encoded_image()._length,
+                                   encoded_image._length)));
+          }));
+
+  transport_pair.second()->SetReceiveVideoSink(&sink);
+  transport_pair.first()->SendVideoFrame(1,
+                                         CreateVideoFrame(10, encoded_image));
+
+  transport_pair.FlushAsyncInvokes();
+  transport_pair.second()->SetReceiveVideoSink(nullptr);
+}
+
 TEST(LoopbackMediaTransport, DataDeliveredToSink) {
   std::unique_ptr<rtc::Thread> thread = rtc::Thread::Create();
   thread->Start();