Refactor RTP Header Extension Management into RtpTransport

Moves the caching, validation, and history tracking of RTP header
extensions from BaseChannel to RtpTransport. This aligns with the BUNDLE
model and simplifies the concurrency model by handling this on the
network thread.

Bug: webrtc:42222117
Change-Id: Iae4c65c938ee085d242cc3a27b5abaff5d7471de
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/455460
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#47275}
diff --git a/pc/BUILD.gn b/pc/BUILD.gn
index ca8dcd8..1f19fdc 100644
--- a/pc/BUILD.gn
+++ b/pc/BUILD.gn
@@ -103,6 +103,7 @@
     "../rtc_base/containers:flat_set",
     "../rtc_base/network:sent_packet",
     "//third_party/abseil-cpp/absl/algorithm:container",
+    "//third_party/abseil-cpp/absl/cleanup",
     "//third_party/abseil-cpp/absl/functional:any_invocable",
     "//third_party/abseil-cpp/absl/strings:string_view",
   ]
@@ -588,6 +589,8 @@
     ":rtp_transport_internal",
     ":session_description",
     "../api:field_trials_view",
+    "../api:rtc_error",
+    "../api:rtp_parameters",
     "../api:sequence_checker",
     "../api/task_queue",
     "../api/task_queue:pending_task_safety_flag",
@@ -606,11 +609,13 @@
     "../rtc_base:macromagic",
     "../rtc_base:network_route",
     "../rtc_base:socket",
+    "../rtc_base:stringutils",
     "../rtc_base/containers:flat_map",
     "../rtc_base/containers:flat_set",
     "../rtc_base/network:received_packet",
     "../rtc_base/network:sent_packet",
     "../rtc_base/system:no_unique_address",
+    "//third_party/abseil-cpp/absl/algorithm:container",
     "//third_party/abseil-cpp/absl/strings:string_view",
   ]
 }
@@ -623,6 +628,7 @@
   sources = [ "rtp_transport_internal.h" ]
   deps = [
     ":session_description",
+    "../api:rtc_error",
     "../api/task_queue:pending_task_safety_flag",
     "../api/transport:ecn_marking",
     "../api/units:timestamp",
diff --git a/pc/channel.cc b/pc/channel.cc
index 5761f2a..1774778 100644
--- a/pc/channel.cc
+++ b/pc/channel.cc
@@ -11,7 +11,6 @@
 #include "pc/channel.h"
 
 #include <algorithm>
-#include <bitset>
 #include <cstdint>
 #include <memory>
 #include <optional>
@@ -20,6 +19,7 @@
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/cleanup/cleanup.h"
 #include "absl/functional/any_invocable.h"
 #include "absl/strings/string_view.h"
 #include "api/crypto/crypto_options.h"
@@ -267,23 +267,6 @@
 
   if (rtp_transport_) {
     DisconnectFromRtpTransport_n();
-    // Clear the cached header extensions on the worker.
-    // If the network and worker thread pointers are configured to map to the
-    // same thread object, we'll do this synchronously. To start with, we're on
-    // the correct thread anyway, but an important second reason is that other
-    // parts of the code (SetLocalContent_w, SetRemoteContent_w) may execute a
-    // BlockingCall that touches `rtp_header_extensions` which, for the case
-    // where the threads are the same, will be executed before the lambda in the
-    // PostTask and not after, which may lead to unexpected behavior.
-    if (worker_thread_ == network_thread_) {
-      RTC_DCHECK_RUN_ON(worker_thread());
-      rtp_header_extensions_.clear();
-    } else {
-      worker_thread_->PostTask(SafeTask(alive_, [this] {
-        RTC_DCHECK_RUN_ON(worker_thread());
-        rtp_header_extensions_.clear();
-      }));
-    }
   }
 
   RTC_DCHECK(!rtp_transport_);
@@ -578,50 +561,27 @@
     std::optional<flat_set<uint8_t>> payload_types,
     const RtpHeaderExtensions& extensions,
     std::optional<flat_set<uint32_t>> ssrcs) {
-  bool update_extensions = true;
-  if (rtp_header_extensions_ == extensions) {
-    update_extensions = false;  // No need to update header extensions.
-  } else {
-    RTCError error = CheckRtpExtensionValidity(extensions);
-    if (!error.ok()) {
-      return error;
-    }
-    rtp_header_extensions_ = extensions;
-
-    for (const auto& extension : extensions) {
-      if (extension.id == 0)
-        continue;
-      if (absl::c_find_if(historical_rtp_header_extensions_,
-                          [&](const RtpExtension& ext) {
-                            return ext.id == extension.id;
-                          }) == historical_rtp_header_extensions_.end()) {
-        historical_rtp_header_extensions_.push_back(extension);
-      }
-    }
-  }
-
-  if (!update_demuxer && !update_extensions && !payload_types.has_value() &&
-      !ssrcs.has_value()) {
-    return RTCError::OK();  // No update needed.
-  }
-
-  // TODO(bugs.webrtc.org/13536): See if we can do this asynchronously.
-
-  if (update_demuxer || payload_types.has_value() || ssrcs.has_value()) {
+  const bool pending_update =
+      update_demuxer || payload_types.has_value() || ssrcs.has_value();
+  if (pending_update) {
     media_receive_channel()->OnDemuxerCriteriaUpdatePending();
   }
+  absl::Cleanup cleanup = [this, pending_update] {
+    if (pending_update) {
+      media_receive_channel()->OnDemuxerCriteriaUpdateComplete();
+    }
+  };
 
+  // TODO(bugs.webrtc.org/13536): See if we can do this asynchronously.
   RTCError error = network_thread()->BlockingCall([&]() -> RTCError {
     RTC_DCHECK_RUN_ON(network_thread());
-    if (!rtp_transport_) {
-      // To repro this situation, run the
-      // `ApplyDescriptionWithSameSsrcsBundledFails` test.
-      return LOG_ERROR(RTCError::InvalidState()
-                       << "No transport assigned for mid=" << mid());
-    }
-
-    if (update_extensions) {
-      rtp_transport_->RegisterRtpHeaderExtensionMap(mid(), extensions);
+    RTCError error =
+        rtp_transport_
+            ? rtp_transport_->RegisterRtpHeaderExtensionMap(mid(), extensions)
+            : (RTCError::InvalidState() << "No transport assigned.");
+    if (!error.ok()) {
+      error.string_builder() << " (mid=" << mid() << ")";
+      return LOG_ERROR(error);
     }
 
     if (payload_types) {
@@ -650,9 +610,6 @@
     return RTCError::OK();
   });
 
-  if (update_demuxer || payload_types.has_value() || ssrcs.has_value())
-    media_receive_channel()->OnDemuxerCriteriaUpdateComplete();
-
   return error;
 }
 
@@ -660,6 +617,9 @@
     bool clear_payload_types,
     std::optional<flat_set<uint32_t>> ssrcs) {
   media_receive_channel()->OnDemuxerCriteriaUpdatePending();
+  absl::Cleanup cleanup = [this] {
+    media_receive_channel()->OnDemuxerCriteriaUpdateComplete();
+  };
   bool ret = network_thread_->BlockingCall([&] {
     RTC_DCHECK_RUN_ON(network_thread());
     if (!rtp_transport_) {
@@ -694,8 +654,6 @@
     return rtp_transport_->RegisterRtpDemuxerSink(criteria, this);
   });
 
-  media_receive_channel()->OnDemuxerCriteriaUpdateComplete();
-
   return ret;
 }
 
@@ -928,41 +886,24 @@
                                                    extensions_filter_);
 }
 
+// TODO: webrtc:42222117 - Move header extension logic in the channel classes
+// to the network thread. At the moment, this function does a BlockingCall
+// to the network thread in order to delegate the check to the transport.
+// The worker and network threads are commonly configured to map to the same
+// actual thread, so a blocking call in those cases isn't expensive, although
+// not ideal.
 RTCError BaseChannel::CheckRtpExtensionValidity(
     const RtpHeaderExtensions& extensions) const {
-  std::bitset<1 + RtpExtension::kMaxId> id_used;
-  for (const auto& extension : extensions) {
-    if (extension.id == 0)
-      continue;
-    if (extension.id < RtpExtension::kMinId ||
-        extension.id > RtpExtension::kMaxId) {
-      return RTCError::InvalidParameter()
-             << "Bad RTP extension ID: " << extension.ToString();
+  return network_thread()->BlockingCall([&]() -> RTCError {
+    RTC_DCHECK_RUN_ON(network_thread());
+    RTCError error =
+        rtp_transport_ ? rtp_transport_->VerifyRtpHeaderExtensionMap(extensions)
+                       : (RTCError::InvalidState() << "No transport assigned.");
+    if (!error.ok()) {
+      error.string_builder() << " (mid=" << mid() << ")";
     }
-    if (id_used[extension.id]) {
-      return RTCError::InvalidParameter()
-             << "Duplicate RTP extension ID: " << extension.ToString();
-    }
-    id_used[extension.id] = true;
-  }
-
-  for (const auto& new_extension : extensions) {
-    if (new_extension.id == 0)
-      continue;
-    auto it = absl::c_find_if(
-        historical_rtp_header_extensions_,
-        [&](const RtpExtension& ext) { return ext.id == new_extension.id; });
-    if (it != historical_rtp_header_extensions_.end() &&
-        it->uri != new_extension.uri) {
-      return RTCError::InvalidParameter()
-             << "Failed to update RTP header extensions for m-section with "
-             << "mid='" << mid() << "'. RTP extension ID reassignment from "
-             << it->uri << " to " << new_extension.uri << " for ID "
-             << new_extension.id << ".";
-    }
-  }
-
-  return RTCError::OK();
+    return error;
+  });
 }
 
 void BaseChannel::SignalSentPacket_n(const SentPacketInfo& sent_packet) {
@@ -1186,8 +1127,6 @@
                                          SdpType type) {
   TRACE_EVENT0("webrtc", "VideoChannel::SetLocalContent_w");
 
-  RTC_LOG_THREAD_BLOCK_COUNT();
-
   RtpHeaderExtensions header_extensions =
       GetDeduplicatedRtpHeaderExtensions(content->rtp_header_extensions());
 
@@ -1196,6 +1135,8 @@
     return error;
   }
 
+  RTC_LOG_THREAD_BLOCK_COUNT();
+
   // TODO: issues.webrtc.org/383078466 - remove if pushdown on answer is enough.
   media_send_channel()->SetExtmapAllowMixed(content->extmap_allow_mixed());
 
@@ -1256,9 +1197,13 @@
 
   RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(0);
 
-  return MaybeUpdateDemuxerAndRtpExtensions_w(
+  error = MaybeUpdateDemuxerAndRtpExtensions_w(
       /*update_demuxer=*/false, std::move(payload_types),
       recv_params.extensions, /*ssrcs=*/std::nullopt);
+
+  RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(1);
+
+  return error;
 }
 
 RTCError VideoChannel::SetRemoteContent_w(
diff --git a/pc/channel.h b/pc/channel.h
index 6b4438a..fcb0126 100644
--- a/pc/channel.h
+++ b/pc/channel.h
@@ -354,13 +354,6 @@
 
   // Cached list of payload types, used if payload type demuxing is re-enabled.
   flat_set<uint8_t> payload_types_ RTC_GUARDED_BY(network_thread());
-  // A stored copy of the rtp header extensions as applied to the transport.
-  RtpHeaderExtensions rtp_header_extensions_ RTC_GUARDED_BY(worker_thread());
-
-  // Set of all historic RTP header extensions mapped, keyed by URI,
-  // to ensure no ID-URI reassignment occurs per RFC 8285.
-  RtpHeaderExtensions historical_rtp_header_extensions_
-      RTC_GUARDED_BY(worker_thread());
 
   const std::string mid_;
   flat_set<uint32_t> ssrcs_ RTC_GUARDED_BY(network_thread());
diff --git a/pc/channel_unittest.cc b/pc/channel_unittest.cc
index aae219f..1705685 100644
--- a/pc/channel_unittest.cc
+++ b/pc/channel_unittest.cc
@@ -69,6 +69,7 @@
 using ::testing::AllOf;
 using ::testing::ElementsAre;
 using ::testing::Field;
+using ::testing::HasSubstr;
 using ::webrtc::CreateTestFieldTrials;
 using ::webrtc::DtlsTransportInternal;
 using ::webrtc::FakeVoiceMediaReceiveChannel;
@@ -708,8 +709,7 @@
     CreateChannels(0, 0);
     RTCError error = channel1_->SetLocalContent(&local, SdpType::kOffer);
     EXPECT_FALSE(error.ok());
-    EXPECT_THAT(error.message(),
-                ::testing::HasSubstr("Duplicate RTP extension ID"));
+    EXPECT_THAT(error.message(), HasSubstr("Duplicate extension ID"));
   }
 
   void TestInvalidRtpHeaderExtensionIds() {
@@ -722,7 +722,7 @@
     CreateChannels(0, 0);
     RTCError error = channel1_->SetLocalContent(&local, SdpType::kOffer);
     EXPECT_FALSE(error.ok());
-    EXPECT_THAT(error.message(), ::testing::HasSubstr("Bad RTP extension ID"));
+    EXPECT_THAT(error.message(), HasSubstr("Bad extension ID"));
   }
 
   void TestRtpHeaderExtensionIdReassignment() {
@@ -743,8 +743,7 @@
     RTCError error =
         channel1_->SetLocalContent(&local_updated, SdpType::kOffer);
     EXPECT_FALSE(error.ok());
-    EXPECT_THAT(error.message(),
-                ::testing::HasSubstr("RTP extension ID reassignment"));
+    EXPECT_THAT(error.message(), HasSubstr("RTP extension ID reassignment"));
   }
 
   void TestRtpHeaderExtensionIdHistoryReassignment() {
@@ -770,8 +769,7 @@
     RTCError error =
         channel1_->SetLocalContent(&local_updated, SdpType::kOffer);
     EXPECT_FALSE(error.ok());
-    EXPECT_THAT(error.message(),
-                ::testing::HasSubstr("RTP extension ID reassignment"));
+    EXPECT_THAT(error.message(), HasSubstr("RTP extension ID reassignment"));
   }
 
   // Test that SetLocalContent and SetRemoteContent properly configure
diff --git a/pc/rtp_transport.cc b/pc/rtp_transport.cc
index c23046f..20db5d3 100644
--- a/pc/rtp_transport.cc
+++ b/pc/rtp_transport.cc
@@ -11,6 +11,8 @@
 #include "pc/rtp_transport.h"
 
 #include <algorithm>
+#include <bitset>
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <optional>
@@ -18,7 +20,10 @@
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/strings/string_view.h"
+#include "api/rtc_error.h"
+#include "api/rtp_parameters.h"
 #include "api/sequence_checker.h"
 #include "api/task_queue/pending_task_safety_flag.h"
 #include "api/task_queue/task_queue_base.h"
@@ -55,6 +60,28 @@
   }
 }
 
+RTCError VerifyExtensionIds(const RtpHeaderExtensions& extensions) {
+  using ExtensionsUsed = std::bitset<1 + RtpExtension::kMaxId>;
+  ExtensionsUsed id_used;
+  for (const auto& extension : extensions) {
+    if (extension.id == 0) {
+      continue;
+    }
+    if (extension.id < RtpExtension::kMinId ||
+        extension.id > RtpExtension::kMaxId) {
+      return RTCError::InvalidParameter()
+             << "Bad extension ID: " << extension.ToString();
+    }
+    ExtensionsUsed::reference entry = id_used[extension.id];
+    if (entry) {
+      return RTCError::InvalidParameter()
+             << "Duplicate extension ID: " << extension.ToString();
+    }
+    entry = true;
+  }
+  return RTCError::OK();
+}
+
 }  // namespace
 
 void RtpTransport::SetRtcpMuxEnabled(bool enable) {
@@ -202,14 +229,72 @@
   return true;
 }
 
-void RtpTransport::RegisterRtpHeaderExtensionMap(
-    absl::string_view mid,
-    const RtpHeaderExtensions& header_extensions) {
+RTCError RtpTransport::VerifyRtpHeaderExtensionMap(
+    const RtpHeaderExtensions& extensions) const {
   RTC_DCHECK_RUN_ON(&network_thread_checker_);
+
+  RTCError error = VerifyExtensionIds(extensions);
+  if (!error.ok()) {
+    return error;
+  }
+
+  for (const auto& new_extension : extensions) {
+    if (new_extension.id == 0) {
+      continue;
+    }
+    auto it = absl::c_find_if(
+        historical_rtp_header_extensions_,
+        [&](const RtpExtension& ext) { return ext.id == new_extension.id; });
+    if (it != historical_rtp_header_extensions_.end() &&
+        it->uri != new_extension.uri) {
+      return RTCError::InvalidParameter()
+             << "RTP extension ID reassignment not supported (id="
+             << new_extension.id << ", old_uri=\"" << it->uri
+             << "\", new_uri=\"" << new_extension.uri << "\").";
+    }
+  }
+
+  return RTCError::OK();
+}
+
+RTCError RtpTransport::RegisterRtpHeaderExtensionMap(
+    absl::string_view mid,
+    const RtpHeaderExtensions& extensions) {
+  RTC_DCHECK_RUN_ON(&network_thread_checker_);
+
+  RTCError error = VerifyRtpHeaderExtensionMap(extensions);
+  if (!error.ok()) {
+    return error;
+  }
+
+  auto existing_extensions =
+      absl::c_find_if(header_extensions_by_mid_,
+                      [mid](const auto& kv) { return kv.first == mid; });
+  if (existing_extensions != header_extensions_by_mid_.end() &&
+      existing_extensions->second == extensions) {
+    return RTCError::OK();
+  }
+
+  for (const RtpExtension& extension : extensions) {
+    if (extension.id == 0) {
+      continue;
+    }
+    auto it = absl::c_find_if(historical_rtp_header_extensions_,
+                              [&extension](const RtpExtension& ext) {
+                                return ext.id == extension.id;
+                              });
+    if (it == historical_rtp_header_extensions_.end()) {
+      historical_rtp_header_extensions_.push_back(extension);
+    } else {
+      RTC_DCHECK_EQ(it->uri, extension.uri);
+    }
+  }
+
   RemoveExtensionMapForMid(mid, header_extensions_by_mid_);
-  header_extensions_by_mid_.emplace_back(std::string(mid), header_extensions);
+  header_extensions_by_mid_.emplace_back(std::string(mid), extensions);
 
   RebuildMergedMap();
+  return RTCError::OK();
 }
 
 void RtpTransport::UnregisterRtpHeaderExtensionMap(absl::string_view mid) {
diff --git a/pc/rtp_transport.h b/pc/rtp_transport.h
index 4e1fd93..1b0d65b 100644
--- a/pc/rtp_transport.h
+++ b/pc/rtp_transport.h
@@ -22,6 +22,7 @@
 
 #include "absl/strings/string_view.h"
 #include "api/field_trials_view.h"
+#include "api/rtc_error.h"
 #include "api/sequence_checker.h"
 #include "api/task_queue/pending_task_safety_flag.h"
 #include "api/transport/ecn_marking.h"
@@ -92,9 +93,12 @@
 
   bool IsSrtpActive() const override { return false; }
 
-  void RegisterRtpHeaderExtensionMap(
+  RTCError VerifyRtpHeaderExtensionMap(
+      const RtpHeaderExtensions& extensions) const override;
+
+  RTCError RegisterRtpHeaderExtensionMap(
       absl::string_view mid,
-      const RtpHeaderExtensions& header_extensions) override;
+      const RtpHeaderExtensions& extensions) override;
 
   // Currently only used for testing. In production, unregistration isn't needed
   // because leaving the registered extensions in `RtpTransport` is harmless
@@ -175,6 +179,9 @@
   std::vector<std::pair<std::string, RtpHeaderExtensions>>
       header_extensions_by_mid_ RTC_GUARDED_BY(network_thread_checker_);
 
+  RtpHeaderExtensions historical_rtp_header_extensions_
+      RTC_GUARDED_BY(network_thread_checker_);
+
   // Guard against recursive "ready to send" signals
   bool processing_ready_to_send_ = false;
   RTC_NO_UNIQUE_ADDRESS SequenceChecker network_thread_checker_;
diff --git a/pc/rtp_transport_internal.h b/pc/rtp_transport_internal.h
index a2c9967..96aaa3e 100644
--- a/pc/rtp_transport_internal.h
+++ b/pc/rtp_transport_internal.h
@@ -17,6 +17,7 @@
 
 #include "absl/functional/any_invocable.h"
 #include "absl/strings/string_view.h"
+#include "api/rtc_error.h"
 #include "api/task_queue/pending_task_safety_flag.h"
 #include "api/transport/ecn_marking.h"
 #include "api/units/timestamp.h"
@@ -134,9 +135,12 @@
   //   UpdateSendEncryptedHeaderExtensionIds,
   //   UpdateRecvEncryptedHeaderExtensionIds,
   //   CacheRtpAbsSendTimeHeaderExtension,
-  virtual void RegisterRtpHeaderExtensionMap(
+  virtual RTCError RegisterRtpHeaderExtensionMap(
       absl::string_view mid,
-      const RtpHeaderExtensions& header_extensions) = 0;
+      const RtpHeaderExtensions& extensions) = 0;
+
+  virtual RTCError VerifyRtpHeaderExtensionMap(
+      const RtpHeaderExtensions& extensions) const = 0;
 
   virtual void UnregisterRtpHeaderExtensionMap(absl::string_view mid) = 0;
 
diff --git a/pc/rtp_transport_unittest.cc b/pc/rtp_transport_unittest.cc
index 4c2aca6..adbe708 100644
--- a/pc/rtp_transport_unittest.cc
+++ b/pc/rtp_transport_unittest.cc
@@ -14,6 +14,7 @@
 #include <cstdint>
 #include <optional>
 
+#include "api/rtc_error.h"
 #include "api/rtp_parameters.h"
 #include "api/test/rtc_error_matchers.h"
 #include "api/transport/ecn_marking.h"
@@ -44,6 +45,7 @@
 
 using ::testing::Eq;
 using ::testing::Ge;
+using ::testing::HasSubstr;
 
 constexpr bool kMuxDisabled = false;
 constexpr bool kMuxEnabled = true;
@@ -365,6 +367,30 @@
   transport.UnregisterRtpDemuxerSink(&observer);
 }
 
+TEST(RtpTransportTest, VerifyRtpHeaderExtensionMapRejectsIdReassignment) {
+  test::RunLoop loop;
+  RtpTransport transport(kMuxDisabled, CreateTestFieldTrials());
+  RtpHeaderExtensions extensions1 = {
+      RtpExtension("urn:ietf:params:rtp-hdrext:ssrc-audio-level", 1)};
+  RtpHeaderExtensions extensions2 = {RtpExtension(
+      "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time", 1)};
+
+  // Registering the first map should succeed.
+  EXPECT_TRUE(
+      transport.RegisterRtpHeaderExtensionMap("audio", extensions1).ok());
+
+  // Verifying a map that tries to reassign ID 1 to a different URI should fail.
+  RTCError error = transport.VerifyRtpHeaderExtensionMap(extensions2);
+  EXPECT_FALSE(error.ok());
+  EXPECT_EQ(error.type(), RTCErrorType::INVALID_PARAMETER);
+  EXPECT_THAT(error.message(), HasSubstr("RTP extension ID reassignment"));
+
+  // Registering the second map should also fail.
+  error = transport.RegisterRtpHeaderExtensionMap("video", extensions2);
+  EXPECT_FALSE(error.ok());
+  EXPECT_EQ(error.type(), RTCErrorType::INVALID_PARAMETER);
+}
+
 // Test that SignalPacketReceived fires with rtcp=true when a RTCP packet is
 // received.
 TEST(RtpTransportTest, SignalDemuxedRtcp) {