Allow RTX ssrc to be updated on receive streams

This is used when an unsignaled stream with a known payload type is received and later an RTX packet is received.

Bug: webrtc:14817
Change-Id: I29f43281cec17553e1ec2483e21b8847714d2931
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/291328
Reviewed-by: Henrik Boström <hbos@webrtc.org>
Reviewed-by: Erik Språng <sprang@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39243}
diff --git a/call/video_receive_stream.h b/call/video_receive_stream.h
index 0fa257e..bde8c8b 100644
--- a/call/video_receive_stream.h
+++ b/call/video_receive_stream.h
@@ -11,6 +11,7 @@
 #ifndef CALL_VIDEO_RECEIVE_STREAM_H_
 #define CALL_VIDEO_RECEIVE_STREAM_H_
 
+#include <cstdint>
 #include <limits>
 #include <map>
 #include <set>
@@ -33,6 +34,7 @@
 #include "common_video/frame_counts.h"
 #include "modules/rtp_rtcp/include/rtcp_statistics.h"
 #include "modules/rtp_rtcp/include/rtp_rtcp_defines.h"
+#include "rtc_base/checks.h"
 
 namespace webrtc {
 
@@ -310,6 +312,8 @@
   virtual void SetAssociatedPayloadTypes(
       std::map<int, int> associated_payload_types) = 0;
 
+  virtual void UpdateRtxSsrc(uint32_t ssrc) = 0;
+
  protected:
   virtual ~VideoReceiveStreamInterface() {}
 };
diff --git a/media/engine/fake_webrtc_call.cc b/media/engine/fake_webrtc_call.cc
index f3f6803..99a8d7b 100644
--- a/media/engine/fake_webrtc_call.cc
+++ b/media/engine/fake_webrtc_call.cc
@@ -660,7 +660,8 @@
 
   if (media_type == webrtc::MediaType::VIDEO) {
     for (auto receiver : video_receive_streams_) {
-      if (receiver->GetConfig().rtp.remote_ssrc == ssrc) {
+      if (receiver->GetConfig().rtp.remote_ssrc == ssrc ||
+          receiver->GetConfig().rtp.rtx_ssrc == ssrc) {
         ++delivered_packets_by_ssrc_[ssrc];
         return true;
       }
diff --git a/media/engine/fake_webrtc_call.h b/media/engine/fake_webrtc_call.h
index fc1458d..7c8b93d 100644
--- a/media/engine/fake_webrtc_call.h
+++ b/media/engine/fake_webrtc_call.h
@@ -259,6 +259,8 @@
     config_.rtp.local_ssrc = local_ssrc;
   }
 
+  void UpdateRtxSsrc(uint32_t ssrc) override { config_.rtp.rtx_ssrc = ssrc; }
+
   void SetFrameDecryptor(rtc::scoped_refptr<webrtc::FrameDecryptorInterface>
                              frame_decryptor) override {}
 
diff --git a/media/engine/webrtc_video_engine.cc b/media/engine/webrtc_video_engine.cc
index e1329b8..3791e10 100644
--- a/media/engine/webrtc_video_engine.cc
+++ b/media/engine/webrtc_video_engine.cc
@@ -1792,10 +1792,7 @@
     // stream, which will be associated with unsignaled media stream.
     absl::optional<uint32_t> current_default_ssrc = GetUnsignaledSsrc();
     if (current_default_ssrc) {
-      // TODO(bug.webrtc.org/14817): Consider associating the existing default
-      // stream with this RTX stream instead of recreating.
-      ReCreateDefaulReceiveStream(/*ssrc =*/*current_default_ssrc,
-                                  packet.Ssrc());
+      FindReceiveStream(*current_default_ssrc)->UpdateRtxSsrc(packet.Ssrc());
     } else {
       // Received unsignaled RTX packet before a media packet. Create a default
       // stream with a "random" SSRC and the RTX SSRC from the packet.  The
@@ -1822,10 +1819,7 @@
       }
     }
   }
-
-  // TODO(bug.webrtc.org/14817): Consider creating a default stream with a fake
-  // RTX ssrc that can be updated when the real SSRC is known if rtx has been
-  // negotiated.
+  // RTX SSRC not yet known.
   ReCreateDefaulReceiveStream(packet.Ssrc(), absl::nullopt);
   last_unsignalled_ssrc_creation_time_ms_ = rtc::TimeMillis();
   return true;
@@ -3356,6 +3350,11 @@
     call_->OnLocalSsrcUpdated(*flexfec_stream_, ssrc);
 }
 
+void WebRtcVideoChannel::WebRtcVideoReceiveStream::UpdateRtxSsrc(
+    uint32_t ssrc) {
+  stream_->UpdateRtxSsrc(ssrc);
+}
+
 WebRtcVideoChannel::VideoCodecSettings::VideoCodecSettings()
     : flexfec_payload_type(-1), rtx_payload_type(-1) {}
 
@@ -3605,8 +3604,8 @@
   RTC_DCHECK(frame_transformer);
   RTC_DCHECK_RUN_ON(&thread_checker_);
   if (ssrc == 0) {
-    // If the receiver is unsignaled, save the frame transformer and set it when
-    // the stream is associated with an ssrc.
+    // If the receiver is unsignaled, save the frame transformer and set it
+    // when the stream is associated with an ssrc.
     unsignaled_frame_transformer_ = std::move(frame_transformer);
     return;
   }
diff --git a/media/engine/webrtc_video_engine.h b/media/engine/webrtc_video_engine.h
index 7ce1655..ee22a7e 100644
--- a/media/engine/webrtc_video_engine.h
+++ b/media/engine/webrtc_video_engine.h
@@ -11,6 +11,7 @@
 #ifndef MEDIA_ENGINE_WEBRTC_VIDEO_ENGINE_H_
 #define MEDIA_ENGINE_WEBRTC_VIDEO_ENGINE_H_
 
+#include <cstdint>
 #include <map>
 #include <memory>
 #include <set>
@@ -493,6 +494,7 @@
             frame_transformer);
 
     void SetLocalSsrc(uint32_t local_ssrc);
+    void UpdateRtxSsrc(uint32_t ssrc);
 
    private:
     // Attempts to reconfigure an already existing `flexfec_stream_`, create
diff --git a/media/engine/webrtc_video_engine_unittest.cc b/media/engine/webrtc_video_engine_unittest.cc
index 938f1a6..3cc1b3c 100644
--- a/media/engine/webrtc_video_engine_unittest.cc
+++ b/media/engine/webrtc_video_engine_unittest.cc
@@ -281,6 +281,47 @@
   return res;
 }
 
+RtpPacketReceived BuildVp8KeyFrame(uint32_t ssrc, uint8_t payload_type) {
+  RtpPacketReceived packet;
+  packet.SetMarker(true);
+  packet.SetPayloadType(payload_type);
+  packet.SetSsrc(ssrc);
+
+  // VP8 Keyframe + 1 byte payload
+  uint8_t* buf_ptr = packet.AllocatePayload(11);
+  memset(buf_ptr, 0, 11);  // Pass MSAN (don't care about bytes 1-9)
+  buf_ptr[0] = 0x10;       // Partition ID 0 + beginning of partition.
+  constexpr unsigned width = 1080;
+  constexpr unsigned height = 720;
+  buf_ptr[6] = width & 255;
+  buf_ptr[7] = width >> 8;
+  buf_ptr[8] = height & 255;
+  buf_ptr[9] = height >> 8;
+  return packet;
+}
+
+RtpPacketReceived BuildRtxPacket(uint32_t rtx_ssrc,
+                                 uint8_t rtx_payload_type,
+                                 const RtpPacketReceived& original_packet) {
+  constexpr size_t kRtxHeaderSize = 2;
+  RtpPacketReceived packet(original_packet);
+  packet.SetPayloadType(rtx_payload_type);
+  packet.SetSsrc(rtx_ssrc);
+
+  uint8_t* rtx_payload =
+      packet.AllocatePayload(original_packet.payload_size() + kRtxHeaderSize);
+  // Add OSN (original sequence number).
+  rtx_payload[0] = packet.SequenceNumber() >> 8;
+  rtx_payload[1] = packet.SequenceNumber();
+
+  // Add original payload data.
+  if (!original_packet.payload().empty()) {
+    memcpy(rtx_payload + kRtxHeaderSize, original_packet.payload().data(),
+           original_packet.payload().size());
+  }
+  return packet;
+}
+
 }  // namespace
 
 #define EXPECT_FRAME_WAIT(c, w, h, t)                        \
@@ -900,6 +941,50 @@
 
   channel->SetInterface(nullptr);
 }
+
+TEST_F(WebRtcVideoEngineTest, UpdatesUnsignaledRtxSsrcAndRecoversPayload) {
+  // Set up a channel with VP8, RTX and transport sequence number header
+  // extension. Receive stream is not explicitly configured.
+  AddSupportedVideoCodecType("VP8");
+  std::vector<VideoCodec> supported_codecs =
+      engine_.recv_codecs(/*include_rtx=*/true);
+  ASSERT_EQ(supported_codecs[1].name, "rtx");
+  int rtx_payload_type = supported_codecs[1].id;
+
+  std::unique_ptr<VideoMediaChannel> channel(engine_.CreateMediaChannel(
+      call_.get(), GetMediaConfig(), VideoOptions(), webrtc::CryptoOptions(),
+      video_bitrate_allocator_factory_.get()));
+  cricket::VideoRecvParameters parameters;
+  parameters.codecs = supported_codecs;
+  ASSERT_TRUE(channel->SetRecvParameters(parameters));
+
+  // Receive a normal payload packet. It is not a complete frame since the
+  // marker bit is not set.
+  RtpPacketReceived packet_1 =
+      BuildVp8KeyFrame(/*ssrc*/ 123, supported_codecs[0].id);
+  packet_1.SetMarker(false);
+  channel->AsVideoReceiveChannel()->OnPacketReceived(packet_1);
+
+  time_controller_.AdvanceTime(webrtc::TimeDelta::Millis(100));
+  // No complete frame received. No decoder created yet.
+  EXPECT_THAT(decoder_factory_->decoders(), IsEmpty());
+
+  RtpPacketReceived packet_2;
+  packet_2.SetSsrc(123);
+  packet_2.SetPayloadType(supported_codecs[0].id);
+  packet_2.SetSequenceNumber(packet_1.SequenceNumber() + 1);
+  memset(packet_2.AllocatePayload(500), 0, 1);
+  packet_2.SetMarker(true);  //  Frame is complete.
+  RtpPacketReceived rtx_packet =
+      BuildRtxPacket(345, rtx_payload_type, packet_2);
+
+  channel->AsVideoReceiveChannel()->OnPacketReceived(rtx_packet);
+
+  time_controller_.AdvanceTime(webrtc::TimeDelta::Millis(0));
+  ASSERT_THAT(decoder_factory_->decoders(), Not(IsEmpty()));
+  EXPECT_EQ(decoder_factory_->decoders()[0]->GetNumFramesReceived(), 1);
+}
+
 TEST_F(WebRtcVideoEngineTest, UsesSimulcastAdapterForVp8Factories) {
   AddSupportedVideoCodecType("VP8");
 
@@ -1528,23 +1613,7 @@
   }
 
   void DeliverKeyFrame(uint32_t ssrc) {
-    RtpPacketReceived packet;
-    packet.SetMarker(true);
-    packet.SetPayloadType(96);  // VP8
-    packet.SetSsrc(ssrc);
-
-    // VP8 Keyframe + 1 byte payload
-    uint8_t* buf_ptr = packet.AllocatePayload(11);
-    memset(buf_ptr, 0, 11);  // Pass MSAN (don't care about bytes 1-9)
-    buf_ptr[0] = 0x10;       // Partition ID 0 + beginning of partition.
-    constexpr unsigned width = 1080;
-    constexpr unsigned height = 720;
-    buf_ptr[6] = width & 255;
-    buf_ptr[7] = width >> 8;
-    buf_ptr[8] = height & 255;
-    buf_ptr[9] = height >> 8;
-
-    channel_->OnPacketReceived(packet);
+    channel_->OnPacketReceived(BuildVp8KeyFrame(ssrc, 96));
   }
 
   void DeliverKeyFrameAndWait(uint32_t ssrc) {
@@ -7227,8 +7296,7 @@
                                   false /* expect_created_receive_stream */);
 }
 
-TEST_F(WebRtcVideoChannelTest,
-       RtxAfterMediaPacketRecreatesUnsignalledStream) {
+TEST_F(WebRtcVideoChannelTest, RtxAfterMediaPacketUpdatesUnsignalledRtxSsrc) {
   AssignDefaultAptRtxTypes();
   const cricket::VideoCodec vp8 = GetEngineCodec("VP8");
   const int payload_type = vp8.id;
@@ -7253,13 +7321,14 @@
   EXPECT_EQ(1u, fake_call_->GetVideoReceiveStreams().size())
       << "RTX packet should not have added or removed a receive stream";
 
-  // Check receive stream has been recreated with correct ssrcs.
   auto recv_stream = fake_call_->GetVideoReceiveStreams().front();
   auto& config = recv_stream->GetConfig();
   EXPECT_EQ(config.rtp.remote_ssrc, ssrc)
       << "Receive stream should have correct media ssrc";
   EXPECT_EQ(config.rtp.rtx_ssrc, rtx_ssrc)
       << "Receive stream should have correct rtx ssrc";
+  EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(ssrc), 1u);
+  EXPECT_EQ(fake_call_->GetDeliveredPacketsForSsrc(rtx_ssrc), 1u);
 }
 
 TEST_F(WebRtcVideoChannelTest,
diff --git a/video/video_receive_stream2.cc b/video/video_receive_stream2.cc
index 15273b3..9cc78c7 100644
--- a/video/video_receive_stream2.cc
+++ b/video/video_receive_stream2.cc
@@ -255,7 +255,7 @@
       max_wait_for_keyframe_, max_wait_for_frame_, std::move(scheduler),
       call_->trials());
 
-  if (rtx_ssrc()) {
+  if (!config_.rtp.rtx_associated_payload_types.empty()) {
     rtx_receive_stream_ = std::make_unique<RtxReceiveStream>(
         &rtp_video_stream_receiver_,
         std::move(config_.rtp.rtx_associated_payload_types), remote_ssrc(),
@@ -278,6 +278,7 @@
   RTC_DCHECK_RUN_ON(&packet_sequence_checker_);
   RTC_DCHECK(!media_receiver_);
   RTC_DCHECK(!rtx_receiver_);
+  receiver_controller_ = receiver_controller;
 
   // Register with RtpStreamReceiverController.
   media_receiver_ = receiver_controller->CreateReceiver(
@@ -293,6 +294,7 @@
   RTC_DCHECK_RUN_ON(&packet_sequence_checker_);
   media_receiver_.reset();
   rtx_receiver_.reset();
+  receiver_controller_ = nullptr;
 }
 
 const std::string& VideoReceiveStream2::sync_group() const {
@@ -508,14 +510,7 @@
 void VideoReceiveStream2::SetAssociatedPayloadTypes(
     std::map<int, int> associated_payload_types) {
   RTC_DCHECK_RUN_ON(&packet_sequence_checker_);
-
-  // For setting the associated payload types after construction, we currently
-  // assume that the rtx_ssrc cannot change. In such a case we can know that
-  // if the ssrc is non-0, a `rtx_receive_stream_` instance has previously been
-  // created and configured (and is referenced by `rtx_receiver_`) and we can
-  // simply reconfigure it.
-  // If rtx_ssrc is 0 however, we ignore this call.
-  if (!rtx_ssrc())
+  if (!rtx_receive_stream_)
     return;
 
   rtx_receive_stream_->SetAssociatedPayloadTypes(
@@ -1075,5 +1070,15 @@
   keyframe_generation_requested_ = true;
 }
 
+void VideoReceiveStream2::UpdateRtxSsrc(uint32_t ssrc) {
+  RTC_DCHECK_RUN_ON(&packet_sequence_checker_);
+  RTC_DCHECK(rtx_receive_stream_);
+
+  rtx_receiver_.reset();
+  updated_rtx_ssrc_ = ssrc;
+  rtx_receiver_ = receiver_controller_->CreateReceiver(
+      rtx_ssrc(), rtx_receive_stream_.get());
+}
+
 }  // namespace internal
 }  // namespace webrtc
diff --git a/video/video_receive_stream2.h b/video/video_receive_stream2.h
index 5c3572d..ef4f900 100644
--- a/video/video_receive_stream2.h
+++ b/video/video_receive_stream2.h
@@ -16,6 +16,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/types/optional.h"
 #include "api/sequence_checker.h"
 #include "api/task_queue/pending_task_safety_flag.h"
 #include "api/task_queue/task_queue_factory.h"
@@ -127,7 +128,11 @@
   // Getters for const remote SSRC values that won't change throughout the
   // object's lifetime.
   uint32_t remote_ssrc() const { return config_.rtp.remote_ssrc; }
-  uint32_t rtx_ssrc() const { return config_.rtp.rtx_ssrc; }
+  // The RTX SSRC may change after construction; see UpdateRtxSsrc().
+  uint32_t rtx_ssrc() const {
+    RTC_DCHECK_RUN_ON(&packet_sequence_checker_);
+    return updated_rtx_ssrc_.value_or(config_.rtp.rtx_ssrc);
+  }
 
   void SignalNetworkState(NetworkState state);
   bool DeliverRtcp(const uint8_t* packet, size_t length);
@@ -191,6 +196,8 @@
                                          bool generate_key_frame) override;
   void GenerateKeyFrame() override;
 
+  void UpdateRtxSsrc(uint32_t ssrc) override;
+
  private:
   // FrameSchedulingReceiver implementation.
   // Called on packet sequence.
@@ -274,10 +281,17 @@
 
   std::unique_ptr<VideoStreamBufferController> buffer_;
 
+  // `receiver_controller_` is valid from when RegisterWithTransport is invoked
+  // until UnregisterFromTransport is invoked.
+  RtpStreamReceiverControllerInterface* receiver_controller_
+      RTC_GUARDED_BY(packet_sequence_checker_) = nullptr;
+
   std::unique_ptr<RtpStreamReceiverInterface> media_receiver_
       RTC_GUARDED_BY(packet_sequence_checker_);
   std::unique_ptr<RtxReceiveStream> rtx_receive_stream_
       RTC_GUARDED_BY(packet_sequence_checker_);
+  absl::optional<uint32_t> updated_rtx_ssrc_
+      RTC_GUARDED_BY(packet_sequence_checker_);
   std::unique_ptr<RtpStreamReceiverInterface> rtx_receiver_
       RTC_GUARDED_BY(packet_sequence_checker_);