Propagate negotiated header extension to mediachannel on answer

Propagates the subset of header extensions to use for both audio and
video media send/receive channels when an answer has been accepted.

Bug: webrtc:383078466
Change-Id: I791ba828114a480bac7bb324a56dd93f59316c39
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/405241
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#45391}
diff --git a/pc/channel.cc b/pc/channel.cc
index 5513539..173a3c3 100644
--- a/pc/channel.cc
+++ b/pc/channel.cc
@@ -938,6 +938,7 @@
   RtpHeaderExtensions header_extensions =
       GetDeduplicatedRtpHeaderExtensions(content->rtp_header_extensions());
   bool update_header_extensions = true;
+  // TODO: issues.webrtc.org/383078466 - remove if pushdown on answer is enough.
   media_send_channel()->SetExtmapAllowMixed(content->extmap_allow_mixed());
 
   AudioReceiverParameters recv_params = last_recv_params_;
@@ -965,6 +966,19 @@
 
   last_recv_params_ = recv_params;
 
+  if (type == SdpType::kAnswer || type == SdpType::kPrAnswer) {
+    AudioSenderParameter send_params = last_send_params_;
+    send_params.extensions = header_extensions;
+    send_params.extmap_allow_mixed = content->extmap_allow_mixed();
+    if (!media_send_channel()->SetSenderParameters(send_params)) {
+      error_desc = StringFormat(
+          "Failed to set send parameters for m-section with mid='%s'.",
+          mid().c_str());
+      return false;
+    }
+    last_send_params_ = send_params;
+  }
+
   if (!UpdateLocalStreams_w(content->streams(), type, error_desc)) {
     RTC_DCHECK(!error_desc.empty());
     return false;
@@ -1009,6 +1023,18 @@
         mid().c_str());
     return false;
   }
+
+  if (type == SdpType::kAnswer || type == SdpType::kPrAnswer) {
+    AudioReceiverParameters recv_params = last_recv_params_;
+    recv_params.extensions = send_params.extensions;
+    if (!media_receive_channel()->SetReceiverParameters(recv_params)) {
+      error_desc = StringFormat(
+          "Failed to set recv parameters for m-section with mid='%s'.",
+          mid().c_str());
+      return false;
+    }
+    last_recv_params_ = recv_params;
+  }
   // The receive channel can send RTCP packets in the reverse direction. It
   // should use the reduced size mode if a peer has requested it through the
   // remote content.
@@ -1079,14 +1105,13 @@
                                      SdpType type,
                                      std::string& error_desc) {
   TRACE_EVENT0("webrtc", "VideoChannel::SetLocalContent_w");
-  RTC_DLOG(LS_INFO) << "Setting local video description for " << ToString();
 
   RTC_LOG_THREAD_BLOCK_COUNT();
 
   RtpHeaderExtensions header_extensions =
       GetDeduplicatedRtpHeaderExtensions(content->rtp_header_extensions());
   bool update_header_extensions = true;
-  // TODO: issues.webrtc.org/396640 - remove if pushdown on answer is enough.
+  // TODO: issues.webrtc.org/383078466 - remove if pushdown on answer is enough.
   media_send_channel()->SetExtmapAllowMixed(content->extmap_allow_mixed());
 
   VideoReceiverParameters recv_params = last_recv_params_;
@@ -1096,6 +1121,7 @@
       RtpTransceiverDirectionHasRecv(content->direction()), &recv_params);
 
   VideoSenderParameters send_params = last_send_params_;
+  send_params.extensions = header_extensions;
   send_params.extmap_allow_mixed = content->extmap_allow_mixed();
 
   // Ensure that there is a matching packetization for each send codec. If the
diff --git a/pc/channel_unittest.cc b/pc/channel_unittest.cc
index f4cf34f..7056b9f 100644
--- a/pc/channel_unittest.cc
+++ b/pc/channel_unittest.cc
@@ -76,6 +76,7 @@
 using ::webrtc::FieldTrials;
 using ::webrtc::RidDescription;
 using ::webrtc::RidDirection;
+using ::webrtc::RtpExtension;
 using ::webrtc::RtpTransceiverDirection;
 using ::webrtc::SdpType;
 using ::webrtc::StreamParams;
@@ -645,6 +646,65 @@
               media_send_channel1_impl()->send_codecs()[0]);
   }
 
+  void TestRemovesExtensionNotPresentInRemoteAnswer() {
+    typename T::Content local;
+    typename T::Content remote;
+    CreateContent(/*flags=*/0, kPcmuCodec, kH264Codec, &local);
+    CreateContent(/*flags=*/0, kPcmuCodec, kH264Codec, &remote);
+    local.set_rtp_header_extensions({
+        RtpExtension(RtpExtension::kTransportSequenceNumberUri, 0),
+        RtpExtension(RtpExtension::kVideoRotationUri, 1),
+    });
+    remote.set_rtp_header_extensions({
+        RtpExtension(RtpExtension::kVideoRotationUri, 1),
+    });
+
+    CreateChannels(0, 0);
+    std::string err;
+    ASSERT_TRUE(channel1_->SetLocalContent(&local, SdpType::kOffer, err))
+        << err;
+    ASSERT_TRUE(channel1_->SetRemoteContent(&remote, SdpType::kAnswer, err))
+        << err;
+
+    EXPECT_THAT(media_receive_channel1_impl()->recv_extensions(),
+                ElementsAre(AllOf(Field("id", &RtpExtension::id, 1),
+                                  Field("uri", &RtpExtension::uri,
+                                        RtpExtension::kVideoRotationUri))));
+    EXPECT_THAT(media_send_channel1_impl()->send_extensions(),
+                ElementsAre(AllOf(Field("id", &RtpExtension::id, 1),
+                                  Field("uri", &RtpExtension::uri,
+                                        RtpExtension::kVideoRotationUri))));
+  }
+  void TestRemovesExtensionNotPresentInLocalAnswer() {
+    typename T::Content local;
+    typename T::Content remote;
+    CreateContent(/*flags=*/0, kPcmuCodec, kH264Codec, &local);
+    CreateContent(/*flags=*/0, kPcmuCodec, kH264Codec, &remote);
+    local.set_rtp_header_extensions({
+        RtpExtension(RtpExtension::kVideoRotationUri, 1),
+    });
+    remote.set_rtp_header_extensions({
+        RtpExtension(RtpExtension::kTransportSequenceNumberUri, 0),
+        RtpExtension(RtpExtension::kVideoRotationUri, 1),
+    });
+
+    CreateChannels(0, 0);
+    std::string err;
+    ASSERT_TRUE(channel1_->SetRemoteContent(&remote, SdpType::kOffer, err))
+        << err;
+    ASSERT_TRUE(channel1_->SetLocalContent(&local, SdpType::kAnswer, err))
+        << err;
+
+    EXPECT_THAT(media_receive_channel1_impl()->recv_extensions(),
+                ElementsAre(AllOf(Field("id", &RtpExtension::id, 1),
+                                  Field("uri", &RtpExtension::uri,
+                                        RtpExtension::kVideoRotationUri))));
+    EXPECT_THAT(media_send_channel1_impl()->send_extensions(),
+                ElementsAre(AllOf(Field("id", &RtpExtension::id, 1),
+                                  Field("uri", &RtpExtension::uri,
+                                        RtpExtension::kVideoRotationUri))));
+  }
+
   // Test that SetLocalContent and SetRemoteContent properly configure
   // extmap-allow-mixed.
   void TestSetContentsExtmapAllowMixedCaller(bool offer, bool answer) {
@@ -1877,6 +1937,14 @@
   Base::SocketOptionsMergedOnSetTransport();
 }
 
+TEST_F(VoiceChannelSingleThreadTest, RemovesExtensionNotPresentInRemoteAnswer) {
+  Base::TestRemovesExtensionNotPresentInRemoteAnswer();
+}
+
+TEST_F(VoiceChannelSingleThreadTest, RemovesExtensionNotPresentInLocalAnswer) {
+  Base::TestRemovesExtensionNotPresentInLocalAnswer();
+}
+
 // VoiceChannelDoubleThreadTest
 TEST_F(VoiceChannelDoubleThreadTest, TestInit) {
   Base::TestInit();
@@ -2157,6 +2225,14 @@
   Base::TestUpdateLocalStreamsWithSimulcast();
 }
 
+TEST_F(VideoChannelSingleThreadTest, RemovesExtensionNotPresentInRemoteAnswer) {
+  Base::TestRemovesExtensionNotPresentInRemoteAnswer();
+}
+
+TEST_F(VideoChannelSingleThreadTest, RemovesExtensionNotPresentInLocalAnswer) {
+  Base::TestRemovesExtensionNotPresentInLocalAnswer();
+}
+
 TEST_F(VideoChannelSingleThreadTest, TestSetLocalOfferWithPacketization) {
   const webrtc::Codec kVp8Codec = webrtc::CreateVideoCodec(97, "VP8");
   webrtc::Codec vp9_codec = webrtc::CreateVideoCodec(98, "VP9");
diff --git a/pc/congestion_control_integrationtest.cc b/pc/congestion_control_integrationtest.cc
index 437186b..2888270 100644
--- a/pc/congestion_control_integrationtest.cc
+++ b/pc/congestion_control_integrationtest.cc
@@ -178,17 +178,26 @@
   ASSERT_TRUE(CreatePeerConnectionWrappers());
   ConnectFakeSignalingForSdpOnly();
   callee()->AddVideoTrack();
+  callee()->AddAudioTrack();
   // Add transceivers to caller in order to accomodate reception
   caller()->pc()->AddTransceiver(MediaType::VIDEO);
-  auto parameters = caller()->pc()->GetSenders()[0]->GetParameters();
+  caller()->pc()->AddTransceiver(MediaType::AUDIO);
+
   caller()->CreateAndSetAndSignalOffer();
   ASSERT_THAT(WaitUntil([&] { return SignalingStateStable(); }, IsTrue()),
               IsRtcOk());
 
-  std::vector<RtpHeaderExtensionCapability> negotiated_header_extensions =
-      caller()->pc()->GetTransceivers()[0]->GetNegotiatedHeaderExtensions();
+  ASSERT_THAT(caller()->pc()->GetTransceivers().size(), Eq(2));
   EXPECT_THAT(
-      negotiated_header_extensions,
+      caller()->pc()->GetTransceivers()[0]->GetNegotiatedHeaderExtensions(),
+      Not(Contains(
+          AllOf(Field("uri", &RtpHeaderExtensionCapability::uri,
+                      RtpExtension::kTransportSequenceNumberUri),
+                Not(Field("direction", &RtpHeaderExtensionCapability::direction,
+                          RtpTransceiverDirection::kStopped))))))
+      << " in caller negotiated header extensions";
+  EXPECT_THAT(
+      caller()->pc()->GetTransceivers()[1]->GetNegotiatedHeaderExtensions(),
       Not(Contains(
           AllOf(Field("uri", &RtpHeaderExtensionCapability::uri,
                       RtpExtension::kTransportSequenceNumberUri),
@@ -196,28 +205,51 @@
                           RtpTransceiverDirection::kStopped))))))
       << " in caller negotiated header extensions";
 
-  parameters = caller()->pc()->GetSenders()[0]->GetParameters();
-  EXPECT_THAT(parameters.header_extensions,
-              Not(Contains(Field("uri", &RtpExtension::uri,
-                                 RtpExtension::kTransportSequenceNumberUri))))
+  ASSERT_THAT(caller()->pc()->GetSenders().size(), Eq(2));
+  EXPECT_THAT(
+      caller()->pc()->GetSenders()[0]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
       << " in caller sender parameters";
-  parameters = caller()->pc()->GetReceivers()[0]->GetParameters();
-  EXPECT_THAT(parameters.header_extensions,
-              Not(Contains(Field("uri", &RtpExtension::uri,
-                                 RtpExtension::kTransportSequenceNumberUri))))
+  EXPECT_THAT(
+      caller()->pc()->GetSenders()[1]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
+      << " in caller sender parameters";
+  EXPECT_THAT(
+      caller()->pc()->GetReceivers()[0]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
       << " in caller receiver parameters";
-  /* Callee senders are not fixed yet.
-     TODO: issues.webrtc.org/383078466 - enable
-  parameters = callee()->pc()->GetSenders()[0]->GetParameters();
-  EXPECT_THAT(parameters.header_extensions,
-              Not(Contains(Field("uri", &RtpExtension::uri,
-                                 RtpExtension::kTransportSequenceNumberUri))))
+  EXPECT_THAT(caller()->pc()->GetReceivers()[1]->media_type(),
+              Eq(MediaType::AUDIO));
+  EXPECT_THAT(
+      caller()->pc()->GetReceivers()[1]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
+      << " in caller receiver parameters";
+
+  EXPECT_THAT(
+      callee()->pc()->GetSenders()[0]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
       << " in callee sender parameters";
-  */
-  parameters = callee()->pc()->GetReceivers()[0]->GetParameters();
-  EXPECT_THAT(parameters.header_extensions,
-              Not(Contains(Field("uri", &RtpExtension::uri,
-                                 RtpExtension::kTransportSequenceNumberUri))))
+  EXPECT_THAT(
+      callee()->pc()->GetSenders()[1]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
+      << " in callee sender parameters";
+
+  ASSERT_THAT(callee()->pc()->GetReceivers().size(), Eq(2));
+  EXPECT_THAT(
+      callee()->pc()->GetReceivers()[0]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
+      << " in callee receiver parameters";
+  EXPECT_THAT(
+      callee()->pc()->GetReceivers()[1]->GetParameters().header_extensions,
+      Not(Contains(Field("uri", &RtpExtension::uri,
+                         RtpExtension::kTransportSequenceNumberUri))))
       << " in callee receiver parameters";
 }
 
diff --git a/test/peer_scenario/tests/l4s_test.cc b/test/peer_scenario/tests/l4s_test.cc
index 5f562ef..e81dfae 100644
--- a/test/peer_scenario/tests/l4s_test.cc
+++ b/test/peer_scenario/tests/l4s_test.cc
@@ -14,6 +14,7 @@
 #include <string>
 
 #include "absl/strings/str_cat.h"
+#include "api/audio_options.h"
 #include "api/jsep.h"
 #include "api/make_ref_counted.h"
 #include "api/rtc_error.h"
@@ -180,7 +181,9 @@
   PeerScenarioClient::VideoSendTrackConfig video_conf;
   video_conf.generator.squares_video->framerate = 15;
 
+  caller->CreateAudio("AUDIO_1", AudioOptions());
   caller->CreateVideo("VIDEO_1", video_conf);
+  callee->CreateAudio("AUDIO_2", AudioOptions());
   callee->CreateVideo("VIDEO_2", video_conf);
 
   signaling.StartIceSignaling();