Separate transceiver references from the stats info struct

Extract the transceiver and receivers members from the
RtpTransceiverStatsInfo struct into a new TransceiverReferences
struct.

The RtpTransceiverStatsInfo structure is used to transport statistics
data between the signaling, worker, and network threads. Because the
transceiver and receiver objects are signaling-thread-specific, they
previously had to be manually cleared from the struct before it was
posted to the network thread to avoid cross-thread access.

Separating these references into a distinct container makes the data
ownership and thread boundaries more explicit. This refactoring
removes the need for manual pointer nullification and reduces the
risk of accidental cross-thread access during the statistics
gathering process.

Bug: none
Change-Id: I29a5fe77074a139ee9437dbca52b5450c51fa673
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/449040
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#46906}
diff --git a/pc/rtc_stats_collector.cc b/pc/rtc_stats_collector.cc
index 8c8a865..15cb8646 100644
--- a/pc/rtc_stats_collector.cc
+++ b/pc/rtc_stats_collector.cc
@@ -1355,11 +1355,14 @@
   // `ProducePartialResultsOnSignalingThread`.
   auto worker_task = PrepareTransceiverStatsInfosAndCallStats_s_w();
 
-  auto signaling_task = [this](StatsGatheringResults results) mutable {
+  auto signaling_task = [this](WorkerThreadResult worker_result) mutable {
     RTC_DCHECK_RUN_ON(signaling_thread_);
     // Create the initial `partial_report_` for the gathering operation.
-    ProducePartialResultsOnSignalingThread(results.transceiver_stats_infos,
-                                           results.audio_device_stats);
+    ProducePartialResultsOnSignalingThread(
+        worker_result.results.transceiver_stats_infos,
+        worker_result.transceiver_references,
+        worker_result.results.audio_device_stats);
+
     Timestamp timestamp = collection_context_->partial_report->timestamp();
     std::set<std::string> transport_names;
     auto sctp_transport_name = pc_->sctp_transport_name();
@@ -1367,22 +1370,16 @@
       transport_names.emplace(std::move(*sctp_transport_name));
     }
 
-    for (const auto& info : results.transceiver_stats_infos) {
+    for (const auto& info : worker_result.results.transceiver_stats_infos) {
       if (info.transport_name)
         transport_names.insert(*info.transport_name);
     }
 
-    // Clear references to signaling thread objects (and proxies) before
-    // attempting to post to the network thread. We're done with these objects.
-    for (RtpTransceiverStatsInfo& stats : results.transceiver_stats_infos) {
-      stats.transceiver = nullptr;
-      stats.receivers.clear();
-    }
-
     network_thread_->PostTask(SafeTask(
-        network_safety_, [this, transport_names = std::move(transport_names),
-                          timestamp, signaling_flag = signaling_safety_,
-                          results = std::move(results)]() mutable {
+        network_safety_,
+        [this, transport_names = std::move(transport_names), timestamp,
+         signaling_flag = signaling_safety_,
+         results = std::move(worker_result.results)]() mutable {
           ProducePartialResultsOnNetworkThread(
               std::move(signaling_flag), timestamp, std::move(transport_names),
               std::move(results));
@@ -1392,12 +1389,13 @@
   worker_thread_->PostTask(SafeTask(
       worker_safety_, [this, worker_task = std::move(worker_task),
                        signaling_task = std::move(signaling_task)]() mutable {
-        auto results = std::move(worker_task)();
-        signaling_thread_->PostTask(SafeTask(
-            signaling_safety_, [signaling_task = std::move(signaling_task),
-                                results = std::move(results)]() mutable {
-              signaling_task(std::move(results));
-            }));
+        auto worker_result = std::move(worker_task)();
+        signaling_thread_->PostTask(
+            SafeTask(signaling_safety_,
+                     [signaling_task = std::move(signaling_task),
+                      worker_result = std::move(worker_result)]() mutable {
+                       signaling_task(std::move(worker_result));
+                     }));
       }));
 }
 
@@ -1424,6 +1422,7 @@
 
 void RTCStatsCollector::ProducePartialResultsOnSignalingThread(
     const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+    const std::vector<TransceiverReferences>& transceiver_references,
     const std::optional<AudioDeviceModule::Stats>& audio_device_stats) {
   RTC_DCHECK_RUN_ON(signaling_thread_);
   RTC_DCHECK(collection_context_);
@@ -1432,16 +1431,19 @@
 
   ProducePartialResultsOnSignalingThreadImpl(
       collection_context_->partial_report->timestamp(), transceiver_stats_infos,
-      audio_device_stats, collection_context_->partial_report.get());
+      transceiver_references, audio_device_stats,
+      collection_context_->partial_report.get());
 }
 
 void RTCStatsCollector::ProducePartialResultsOnSignalingThreadImpl(
     Timestamp timestamp,
     const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+    const std::vector<TransceiverReferences>& transceiver_references,
     const std::optional<AudioDeviceModule::Stats>& audio_device_stats,
     RTCStatsReport* partial_report) {
   RTC_DCHECK_RUN_ON(signaling_thread_);
-  ProduceMediaSourceStats_s(timestamp, transceiver_stats_infos, partial_report);
+  ProduceMediaSourceStats_s(timestamp, transceiver_stats_infos,
+                            transceiver_references, partial_report);
   ProducePeerConnectionStats_s(timestamp, partial_report);
   ProduceAudioPlayoutStats_s(timestamp, audio_device_stats, partial_report);
 }
@@ -1711,12 +1713,17 @@
 void RTCStatsCollector::ProduceMediaSourceStats_s(
     Timestamp timestamp,
     const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+    const std::vector<TransceiverReferences>& transceiver_references,
     RTCStatsReport* report) const {
   RTC_DCHECK_RUN_ON(signaling_thread_);
   Thread::ScopedDisallowBlockingCalls no_blocking_calls;
 
-  for (const RtpTransceiverStatsInfo& transceiver_stats_info :
-       transceiver_stats_infos) {
+  RTC_DCHECK_EQ(transceiver_stats_infos.size(), transceiver_references.size());
+  for (size_t i = 0; i < transceiver_stats_infos.size(); ++i) {
+    const RtpTransceiverStatsInfo& transceiver_stats_info =
+        transceiver_stats_infos[i];
+    const TransceiverReferences& refs = transceiver_references[i];
+
     // The transceiver will still exist but in a stopped state after pc.close().
     if (transceiver_stats_info.current_direction ==
         RtpTransceiverDirection::kStopped) {
@@ -1726,7 +1733,7 @@
     const TrackMediaInfoMap& track_media_info_map =
         *transceiver_stats_info.track_media_info_map;
 
-    for (const auto& sender : transceiver_stats_info.transceiver->senders()) {
+    for (const auto& sender : refs.transceiver->senders()) {
       const auto& sender_internal = sender->internal();
       const auto& track = sender_internal->track();
       if (!track)
@@ -2254,11 +2261,12 @@
   return transport_cert_stats;
 }
 
-absl::AnyInvocable<RTCStatsCollector::StatsGatheringResults()>
+absl::AnyInvocable<RTCStatsCollector::WorkerThreadResult()>
 RTCStatsCollector::PrepareTransceiverStatsInfosAndCallStats_s_w() {
   RTC_DCHECK_RUN_ON(signaling_thread_);
 
   std::vector<RtpTransceiverStatsInfo> transceiver_stats_infos;
+  std::vector<TransceiverReferences> transceiver_references;
   // These are used to invoke GetStats for all the media channels together in
   // one worker thread hop.
   std::map<VoiceMediaSendChannelInterface*, VoiceMediaSendInfo>
@@ -2276,13 +2284,15 @@
     RtpTransceiver* transceiver = transceiver_proxy->internal();
 
     RtpTransceiverStatsInfo stats{
-        .transceiver = scoped_refptr<RtpTransceiver>(transceiver),
         .media_type = transceiver->media_type(),
         .mid = transceiver->mid(),
         .transport_name = transceiver->transport_name(),
         .current_direction = transceiver->current_direction(),
         .has_channel = transceiver->HasChannel()};
 
+    TransceiverReferences refs{.transceiver =
+                                   scoped_refptr<RtpTransceiver>(transceiver)};
+
     for (const auto& sender : transceiver->senders()) {
       stats.sender_infos.push_back(
           {.ssrc = sender->ssrc(),
@@ -2294,10 +2304,10 @@
           {.track_id = receiver->track() ? receiver->track()->id() : "",
            .attachment_id = receiver->internal()->AttachmentId(),
            .media_type = receiver->media_type()});
-      stats.receivers.push_back(
+      refs.receivers.push_back(
           scoped_refptr<RtpReceiverInternal>(receiver->internal()));
     }
-    stats.has_receivers = !stats.receivers.empty();
+    stats.has_receivers = !refs.receivers.empty();
 
     if (stats.has_channel) {
       if (stats.media_type == MediaType::AUDIO) {
@@ -2316,6 +2326,7 @@
     }
 
     transceiver_stats_infos.push_back(std::move(stats));
+    transceiver_references.push_back(std::move(refs));
   }
 
   // Embed the collected information into this lambda which will run on the
@@ -2323,14 +2334,17 @@
   // as GetCallStats(). At the same time we construct the TrackMediaInfoMaps,
   // which also needs info from the worker thread.
   return [this, transceiver_stats_infos = std::move(transceiver_stats_infos),
+          transceiver_references = std::move(transceiver_references),
           voice_send_stats = std::move(voice_send_stats),
           voice_receive_stats = std::move(voice_receive_stats),
           video_send_stats = std::move(video_send_stats),
           video_receive_stats = std::move(video_receive_stats)]() mutable {
     Thread::ScopedDisallowBlockingCalls no_blocking_calls;
 
-    StatsGatheringResults results = {.transceiver_stats_infos =
-                                         std::move(transceiver_stats_infos)};
+    WorkerThreadResult worker_result;
+    worker_result.results.transceiver_stats_infos =
+        std::move(transceiver_stats_infos);
+    worker_result.transceiver_references = std::move(transceiver_references);
 
     for (auto& pair : voice_send_stats) {
       if (!pair.first->GetStats(&pair.second)) {
@@ -2357,7 +2371,12 @@
     // Create the TrackMediaInfoMap for each transceiver stats object
     // and keep track of whether we have at least one audio receiver.
     bool has_audio_receiver = false;
-    for (auto& stats : results.transceiver_stats_infos) {
+    RTC_DCHECK_EQ(worker_result.results.transceiver_stats_infos.size(),
+                  worker_result.transceiver_references.size());
+    for (size_t i = 0; i < worker_result.results.transceiver_stats_infos.size();
+         ++i) {
+      auto& stats = worker_result.results.transceiver_stats_infos[i];
+      auto& refs = worker_result.transceiver_references[i];
       // The transceiver will still exist but in a stopped state after
       // pc.close().
       if (stats.current_direction == RtpTransceiverDirection::kStopped) {
@@ -2365,14 +2384,14 @@
       }
 
       std::vector<RtpParameters> receiver_parameters;
-      for (const auto& receiver : stats.receivers) {
+      for (const auto& receiver : refs.receivers) {
         receiver_parameters.push_back(receiver->GetParameters());
       }
 
       std::optional<VoiceMediaInfo> voice_media_info;
       std::optional<VideoMediaInfo> video_media_info;
       if (stats.has_channel) {
-        auto& transceiver = stats.transceiver;
+        auto& transceiver = refs.transceiver;
         if (stats.media_type == MediaType::AUDIO) {
           auto voice_send_channel = transceiver->voice_media_send_channel();
           auto voice_receive_channel =
@@ -2399,10 +2418,10 @@
       }
     }
 
-    results.call_stats = pc_->GetCallStats();
-    results.audio_device_stats =
+    worker_result.results.call_stats = pc_->GetCallStats();
+    worker_result.results.audio_device_stats =
         has_audio_receiver ? pc_->GetAudioDeviceStats() : std::nullopt;
-    return results;
+    return worker_result;
   };
 }
 
diff --git a/pc/rtc_stats_collector.h b/pc/rtc_stats_collector.h
index 4276db6..5e9bbfe 100644
--- a/pc/rtc_stats_collector.h
+++ b/pc/rtc_stats_collector.h
@@ -61,19 +61,24 @@
 // If a BaseChannel is not available (e.g., if signaling has not started),
 // then `mid` and `transport_name` will be null.
 struct RtpTransceiverStatsInfo {
-  scoped_refptr<RtpTransceiver> transceiver;
   const MediaType media_type;
   const std::optional<std::string> mid;
   std::optional<std::string> transport_name;
   std::vector<TrackMediaInfoMap::RtpSenderSignalInfo> sender_infos;
   std::vector<TrackMediaInfoMap::RtpReceiverSignalInfo> receiver_infos;
-  std::vector<scoped_refptr<RtpReceiverInternal>> receivers;
   std::unique_ptr<TrackMediaInfoMap> track_media_info_map;
   const std::optional<RtpTransceiverDirection> current_direction;
   bool has_receivers = false;
   const bool has_channel;
 };
 
+// References to objects used on the signaling and worker threads for populating
+// RtpTransceiverStatsInfo but must always be released on the signaling thread
+struct TransceiverReferences {
+  scoped_refptr<RtpTransceiver> transceiver;
+  std::vector<scoped_refptr<RtpReceiverInternal>> receivers;
+};
+
 // All public methods of the collector are to be called on the signaling thread.
 // Stats are gathered on the signaling, worker and network threads
 // asynchronously. The callback is invoked on the signaling thread. Resulting
@@ -134,6 +139,7 @@
   virtual void ProducePartialResultsOnSignalingThreadImpl(
       Timestamp timestamp,
       const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+      const std::vector<TransceiverReferences>& transceiver_references,
       const std::optional<AudioDeviceModule::Stats>& audio_device_stats,
       RTCStatsReport* partial_report);
 
@@ -153,6 +159,10 @@
     std::optional<AudioDeviceModule::Stats> audio_device_stats;
   };
 
+  struct WorkerThreadResult {
+    StatsGatheringResults results;
+    std::vector<TransceiverReferences> transceiver_references;
+  };
   struct CollectionContext;
   class RequestInfo {
    public:
@@ -221,6 +231,7 @@
   void ProduceMediaSourceStats_s(
       Timestamp timestamp,
       const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+      const std::vector<TransceiverReferences>& transceiver_references,
       RTCStatsReport* report) const;
   // Produces `RTCPeerConnectionStats`.
   void ProducePeerConnectionStats_s(Timestamp timestamp,
@@ -267,12 +278,13 @@
   // Prepares the transceiver stats infos and call stats.
   // Returns a callback that should be executed on the worker thread to populate
   // the stats.
-  absl::AnyInvocable<StatsGatheringResults()>
+  absl::AnyInvocable<WorkerThreadResult()>
   PrepareTransceiverStatsInfosAndCallStats_s_w();
 
   // Stats gathering on a particular thread.
   void ProducePartialResultsOnSignalingThread(
       const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+      const std::vector<TransceiverReferences>& transceiver_references,
       const std::optional<AudioDeviceModule::Stats>& audio_device_stats);
   void ProducePartialResultsOnNetworkThread(
       scoped_refptr<PendingTaskSafetyFlag> signaling_safety,
diff --git a/pc/rtc_stats_collector_unittest.cc b/pc/rtc_stats_collector_unittest.cc
index 3fb8e03a..480ab70 100644
--- a/pc/rtc_stats_collector_unittest.cc
+++ b/pc/rtc_stats_collector_unittest.cc
@@ -3980,6 +3980,7 @@
   void ProducePartialResultsOnSignalingThreadImpl(
       Timestamp timestamp,
       const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos,
+      const std::vector<TransceiverReferences>& transceiver_references,
       const std::optional<AudioDeviceModule::Stats>& audio_device_stats,
       RTCStatsReport* partial_report) override {
     EXPECT_TRUE(signaling_thread_->IsCurrent());