Add Sender and Receiver interfaces for MediaTransport audio

Implement in LoopbackMediaTransport.

Bug: webrtc:9719
Change-Id: I429ac3f78d99b8ea4f9ac85b9a3600b215b61a55
Reviewed-on: https://webrtc-review.googlesource.com/c/121957
Commit-Queue: Niels Moller <nisse@webrtc.org>
Reviewed-by: Fredrik Solenberg <solenberg@webrtc.org>
Reviewed-by: Bjorn Mellem <mellem@webrtc.org>
Reviewed-by: Anton Sukhanov <sukhanov@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#26731}
diff --git a/api/media_transport_interface.cc b/api/media_transport_interface.cc
index de04e19..0dfec76 100644
--- a/api/media_transport_interface.cc
+++ b/api/media_transport_interface.cc
@@ -54,6 +54,18 @@
 MediaTransportInterface::MediaTransportInterface() = default;
 MediaTransportInterface::~MediaTransportInterface() = default;
 
+std::unique_ptr<MediaTransportAudioSender>
+MediaTransportInterface::CreateAudioSender(uint64_t channel_id) {
+  return nullptr;
+}
+
+std::unique_ptr<MediaTransportAudioReceiver>
+MediaTransportInterface::CreateAudioReceiver(
+    uint64_t channel_id,
+    MediaTransportAudioSinkInterface* sink) {
+  return nullptr;
+}
+
 void MediaTransportInterface::SetKeyFrameRequestCallback(
     MediaTransportKeyFrameRequestCallback* callback) {}
 
diff --git a/api/media_transport_interface.h b/api/media_transport_interface.h
index 2f5431f..e753ddc 100644
--- a/api/media_transport_interface.h
+++ b/api/media_transport_interface.h
@@ -187,10 +187,28 @@
   MediaTransportInterface();
   virtual ~MediaTransportInterface();
 
+  // Creates an object representing the send end-point of a audio stream using
+  // this transport.
+  // TODO(bugs.webrtc.org/9719): Make pure virtual after downstream
+  // implementations are updated.
+  virtual std::unique_ptr<MediaTransportAudioSender> CreateAudioSender(
+      uint64_t channel_id);
+
+  // Creates an object representing the receive end-point of a audio stream
+  // using this transport.
+  // TODO(bugs.webrtc.org/9719): Make pure virtual after downstream
+  // implementations are updated.
+  virtual std::unique_ptr<MediaTransportAudioReceiver> CreateAudioReceiver(
+      uint64_t channel_id,
+      // TODO(nisse): Add Rtt observer, or route that via Call to the receive
+      // stream instead?
+      MediaTransportAudioSinkInterface* sink);
+
   // Start asynchronous send of audio frame. The status returned by this method
   // only pertains to the synchronous operations (e.g.
   // serialization/packetization), not to the asynchronous operation.
-
+  // TODO(nisse): Deprecated, should be deleted when implementations are updated
+  // to use CreateAudioSender.
   virtual RTCError SendAudioFrame(uint64_t channel_id,
                                   MediaTransportEncodedAudioFrame frame) = 0;
 
diff --git a/api/test/loopback_media_transport.cc b/api/test/loopback_media_transport.cc
index c466170..fde732e 100644
--- a/api/test/loopback_media_transport.cc
+++ b/api/test/loopback_media_transport.cc
@@ -109,6 +109,7 @@
 
 MediaTransportPair::LoopbackMediaTransport::~LoopbackMediaTransport() {
   rtc::CritScope lock(&sink_lock_);
+  RTC_CHECK(audio_sinks_.empty());
   RTC_CHECK(audio_sink_ == nullptr);
   RTC_CHECK(video_sink_ == nullptr);
   RTC_CHECK(data_sink_ == nullptr);
@@ -116,6 +117,58 @@
   RTC_CHECK(rtt_observers_.empty());
 }
 
+class MediaTransportPair::LoopbackMediaTransport::AudioSender
+    : public MediaTransportAudioSender {
+ public:
+  AudioSender(LoopbackMediaTransport* transport, uint64_t channel_id)
+      : transport_(transport), channel_id_(channel_id) {}
+  void SendAudioFrame(MediaTransportEncodedAudioFrame frame) override {
+    transport_->SendAudioFrame(channel_id_, std::move(frame));
+  }
+
+ private:
+  LoopbackMediaTransport* transport_;
+  uint64_t channel_id_;
+};
+
+class MediaTransportPair::LoopbackMediaTransport::AudioReceiver
+    : public MediaTransportAudioReceiver {
+ public:
+  AudioReceiver(LoopbackMediaTransport* transport, uint64_t channel_id)
+      : transport_(transport), channel_id_(channel_id) {}
+  ~AudioReceiver() override {
+    transport_->UnregisterAudioReceiver(channel_id_);
+  }
+
+ private:
+  LoopbackMediaTransport* transport_;
+  uint64_t channel_id_;
+};
+
+std::unique_ptr<MediaTransportAudioSender>
+MediaTransportPair::LoopbackMediaTransport::CreateAudioSender(
+    uint64_t channel_id) {
+  return absl::make_unique<AudioSender>(this, channel_id);
+}
+
+std::unique_ptr<MediaTransportAudioReceiver>
+MediaTransportPair::LoopbackMediaTransport::CreateAudioReceiver(
+    uint64_t channel_id,
+    MediaTransportAudioSinkInterface* sink) {
+  rtc::CritScope cs(&sink_lock_);
+  auto res = audio_sinks_.emplace(channel_id, sink);
+  RTC_DCHECK(res.second);
+  return absl::make_unique<AudioReceiver>(this, channel_id);
+}
+
+void MediaTransportPair::LoopbackMediaTransport::UnregisterAudioReceiver(
+    uint64_t channel_id) {
+  rtc::CritScope cs(&sink_lock_);
+  auto it = audio_sinks_.find(channel_id);
+  RTC_DCHECK(it != audio_sinks_.end());
+  audio_sinks_.erase(it);
+}
+
 RTCError MediaTransportPair::LoopbackMediaTransport::SendAudioFrame(
     uint64_t channel_id,
     MediaTransportEncodedAudioFrame frame) {
@@ -317,7 +370,10 @@
     MediaTransportEncodedAudioFrame frame) {
   {
     rtc::CritScope lock(&sink_lock_);
-    if (audio_sink_) {
+    const auto it = audio_sinks_.find(channel_id);
+    if (it != audio_sinks_.end()) {
+      it->second->OnData(frame);
+    } else if (audio_sink_) {
       audio_sink_->OnData(channel_id, frame);
     }
   }
diff --git a/api/test/loopback_media_transport.h b/api/test/loopback_media_transport.h
index bcfdb63..d2c503b 100644
--- a/api/test/loopback_media_transport.h
+++ b/api/test/loopback_media_transport.h
@@ -11,6 +11,7 @@
 #ifndef API_TEST_LOOPBACK_MEDIA_TRANSPORT_H_
 #define API_TEST_LOOPBACK_MEDIA_TRANSPORT_H_
 
+#include <map>
 #include <memory>
 #include <utility>
 #include <vector>
@@ -85,6 +86,13 @@
 
     ~LoopbackMediaTransport() override;
 
+    std::unique_ptr<MediaTransportAudioSender> CreateAudioSender(
+        uint64_t channel_id) override;
+
+    std::unique_ptr<MediaTransportAudioReceiver> CreateAudioReceiver(
+        uint64_t channel_id,
+        MediaTransportAudioSinkInterface* sink) override;
+
     RTCError SendAudioFrame(uint64_t channel_id,
                             MediaTransportEncodedAudioFrame frame) override;
 
@@ -131,6 +139,9 @@
         const MediaTransportAllocatedBitrateLimits& limits) override;
 
    private:
+    class AudioReceiver;
+    class AudioSender;
+
     void OnData(uint64_t channel_id, MediaTransportEncodedAudioFrame frame);
 
     void OnData(uint64_t channel_id, MediaTransportEncodedVideoFrame frame);
@@ -144,11 +155,17 @@
     void OnRemoteCloseChannel(int channel_id);
 
     void OnStateChanged() RTC_RUN_ON(thread_);
+    void UnregisterAudioReceiver(uint64_t channel_id);
 
     rtc::Thread* const thread_;
     rtc::CriticalSection sink_lock_;
     rtc::CriticalSection stats_lock_;
 
+    std::map<uint64_t, MediaTransportAudioSinkInterface*> audio_sinks_
+        RTC_GUARDED_BY(sink_lock_);
+
+    // TODO(bugs.webrtc.org/9719): Delete when everything is converted to
+    // CreateAudioReceiver.
     MediaTransportAudioSinkInterface* audio_sink_ RTC_GUARDED_BY(sink_lock_) =
         nullptr;
     MediaTransportVideoSinkInterface* video_sink_ RTC_GUARDED_BY(sink_lock_) =
diff --git a/api/test/loopback_media_transport_unittest.cc b/api/test/loopback_media_transport_unittest.cc
index b827405..8fe432d 100644
--- a/api/test/loopback_media_transport_unittest.cc
+++ b/api/test/loopback_media_transport_unittest.cc
@@ -22,6 +22,8 @@
 class MockMediaTransportAudioSinkInterface
     : public MediaTransportAudioSinkInterface {
  public:
+  MOCK_METHOD1(OnData, void(MediaTransportEncodedAudioFrame));
+  // TODO(nisse): Deprecated version, delete.
   MOCK_METHOD2(OnData, void(uint64_t, MediaTransportEncodedAudioFrame));
 };
 
diff --git a/api/transport/media/audio_transport.cc b/api/transport/media/audio_transport.cc
index 7285ad4..5dae4d3 100644
--- a/api/transport/media/audio_transport.cc
+++ b/api/transport/media/audio_transport.cc
@@ -51,4 +51,10 @@
 MediaTransportEncodedAudioFrame::MediaTransportEncodedAudioFrame(
     MediaTransportEncodedAudioFrame&&) = default;
 
+void MediaTransportAudioSinkInterface::OnData(
+    uint64_t channel_id,
+    MediaTransportEncodedAudioFrame frame) {
+  OnData(frame);
+}
+
 }  // namespace webrtc
diff --git a/api/transport/media/audio_transport.h b/api/transport/media/audio_transport.h
index dcbdcd7..d3afbf3 100644
--- a/api/transport/media/audio_transport.h
+++ b/api/transport/media/audio_transport.h
@@ -111,9 +111,29 @@
  public:
   virtual ~MediaTransportAudioSinkInterface() = default;
 
-  // Called when new encoded audio frame is received.
+  // Called when new encoded audio frame is received, and no receiver is
+  // registered. Deprecated.
   virtual void OnData(uint64_t channel_id,
-                      MediaTransportEncodedAudioFrame frame) = 0;
+                      MediaTransportEncodedAudioFrame frame);
+
+  // Called when new encoded audio frame is received.
+  // TODO(bugs.webrtc.org/9719): Make pure virtual after downstream
+  // implementations are updated.
+  virtual void OnData(MediaTransportEncodedAudioFrame frame) {}
+};
+
+class MediaTransportAudioSender {
+ public:
+  virtual ~MediaTransportAudioSender() = default;
+
+  virtual void SendAudioFrame(MediaTransportEncodedAudioFrame frame) = 0;
+};
+
+// Similar to RtpStreamReceiverInterface, only owns the association with the
+// demuxer.
+class MediaTransportAudioReceiver {
+ public:
+  virtual ~MediaTransportAudioReceiver() = default;
 };
 
 }  // namespace webrtc
diff --git a/audio/channel_receive.cc b/audio/channel_receive.cc
index 0e218ed..40dc2c1 100644
--- a/audio/channel_receive.cc
+++ b/audio/channel_receive.cc
@@ -58,15 +58,14 @@
 constexpr int kVoiceEngineMaxMinPlayoutDelayMs = 10000;
 
 RTPHeader CreateRTPHeaderForMediaTransportFrame(
-    const MediaTransportEncodedAudioFrame& frame,
-    uint64_t channel_id) {
+    const MediaTransportEncodedAudioFrame& frame) {
   webrtc::RTPHeader rtp_header;
   rtp_header.payloadType = frame.payload_type();
   rtp_header.payload_type_frequency = frame.sampling_rate_hz();
   rtp_header.timestamp = frame.starting_sample_index();
   rtp_header.sequenceNumber = frame.sequence_number();
 
-  rtp_header.ssrc = static_cast<uint32_t>(channel_id);
+  // Note: SSRC is no longer used by NetEq, so not set.
 
   // The rest are initialized by the RTPHeader constructor.
   return rtp_header;
@@ -167,8 +166,12 @@
   int64_t GetRTT() const;
 
   // MediaTransportAudioSinkInterface override;
-  void OnData(uint64_t channel_id,
-              MediaTransportEncodedAudioFrame frame) override;
+  void OnData(MediaTransportEncodedAudioFrame frame) override;
+  // TODO(nisse): Deprecated variant. Delete.
+  void OnData(uint64_t /* channel_id */,
+              MediaTransportEncodedAudioFrame frame) override {
+    OnData(std::move(frame));
+  }
 
   int32_t OnReceivedPayloadData(const uint8_t* payloadData,
                                 size_t payloadSize,
@@ -293,8 +296,7 @@
 }
 
 // MediaTransportAudioSinkInterface override.
-void ChannelReceive::OnData(uint64_t channel_id,
-                            MediaTransportEncodedAudioFrame frame) {
+void ChannelReceive::OnData(MediaTransportEncodedAudioFrame frame) {
   RTC_CHECK(media_transport_);
 
   if (!Playing()) {
@@ -306,7 +308,7 @@
   // Send encoded audio frame to Decoder / NetEq.
   if (audio_coding_->IncomingPacket(
           frame.encoded_data().data(), frame.encoded_data().size(),
-          CreateRTPHeaderForMediaTransportFrame(frame, channel_id)) != 0) {
+          CreateRTPHeaderForMediaTransportFrame(frame)) != 0) {
     RTC_DLOG(LS_ERROR) << "ChannelReceive::OnData: unable to "
                           "push data to the ACM";
   }