Add async cancellation support to RTCStatsCollector This reduces blocking call operations in the PC destructor down to a consistent 2 and removes the synchronous Wait() operation in RTCStatsCollector from that path (it's still in Close()). The implementation uses PendingTaskSafetyFlag and SafeTask to invalidate pending tasks across the signaling and network threads. When CancelPendingRequest is called or the collector is destroyed, in-flight tasks are safely dropped rather than being synchronously waited upon. Key modifications: - Integrate ScopedTaskSafety within RTCStatsCollector for signaling thread task management. - Require a network thread safety flag during RTCStatsCollector initialization to ensure safe cross-thread operations. - Replace WaitForPendingRequest with CancelPendingRequest in the PeerConnection destructor. - Wrap cross-thread PostTask calls in SafeTask to prevent callbacks from executing after the collector has been cancelled or destroyed. Bug: webrtc:42222804 Change-Id: I48f30e839d1d9b4d58ad829e3e9d29962d742a7b Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/435440 Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org> Reviewed-by: Harald Alvestrand <hta@webrtc.org> Cr-Commit-Position: refs/heads/main@{#46536}
diff --git a/pc/BUILD.gn b/pc/BUILD.gn index e42bfae..da203c5 100644 --- a/pc/BUILD.gn +++ b/pc/BUILD.gn
@@ -1034,6 +1034,7 @@ "../api/audio:audio_processing_statistics", "../api/environment", "../api/task_queue", + "../api/task_queue:pending_task_safety_flag", "../api/transport:enums", "../api/units:time_delta", "../api/units:timestamp", @@ -1063,8 +1064,7 @@ "../rtc_base:threading", "../rtc_base:timeutils", "../rtc_base/containers:flat_set", - "../rtc_base/synchronization:mutex", - "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/functional:any_invocable", "//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/strings:string_view", ]
diff --git a/pc/peer_connection.cc b/pc/peer_connection.cc index 315927d..b7520b6 100644 --- a/pc/peer_connection.cc +++ b/pc/peer_connection.cc
@@ -676,36 +676,27 @@ PeerConnection::~PeerConnection() { TRACE_EVENT0("webrtc", "PeerConnection::~PeerConnection"); RTC_DCHECK_RUN_ON(signaling_thread()); + RTC_LOG_THREAD_BLOCK_COUNT(); - if (sdp_handler_) { - sdp_handler_->PrepareForShutdown(); - } + sdp_handler_->PrepareForShutdown(); // In case `Close()` wasn't called, always make sure the controller cancels // potentially pending operations. data_channel_controller_.PrepareForShutdown(); - // Need to stop transceivers before destroying the stats collector because - // AudioRtpSender has a reference to the LegacyStatsCollector it will update - // when stopping. - if (rtp_manager()) { - for (const auto& transceiver : rtp_manager()->transceivers()->List()) { - transceiver->StopInternal(); - } - } + std::vector<absl::AnyInvocable<void() &&>> network_tasks; + std::vector<absl::AnyInvocable<void() &&>> worker_tasks; + + // Stop transceivers before destroying the stats collector because + // AudioRtpSender has a reference to the LegacyStatsCollector that it will + // update when stopping. The BaseChannels will eventually be deleted below + // when all the network and worker tasks are executed. + sdp_handler_->GetMediaChannelTeardownTasks(network_tasks, worker_tasks); legacy_stats_.reset(nullptr); if (stats_collector_) { - stats_collector_->WaitForPendingRequest(); - stats_collector_ = nullptr; - } - - std::vector<absl::AnyInvocable<void() &&>> network_tasks; - std::vector<absl::AnyInvocable<void() &&>> worker_tasks; - if (sdp_handler_) { - // Don't destroy BaseChannels until after stats has been cleaned up so that - // the last stats request can still read from the channels. - sdp_handler_->GetMediaChannelTeardownTasks(network_tasks, worker_tasks); + network_tasks.push_back( + stats_collector_->CancelPendingRequestAndGetShutdownTask()); } CloseOnNetworkThread(network_tasks); @@ -727,6 +718,10 @@ } data_channel_controller_.PrepareForShutdown(); + + // The expectation is that there will have been 1 blocking call for the worker + // thread and optionally 1 task for the network thread. + RTC_DCHECK_BLOCK_COUNT_NO_MORE_THAN(2); } JsepTransportController* PeerConnection::InitializeNetworkThread( @@ -866,30 +861,33 @@ void PeerConnection::CloseOnNetworkThread( std::vector<absl::AnyInvocable<void() &&>>& network_tasks) { RTC_DCHECK_RUN_ON(signaling_thread()); - if (!transport_controller_copy_) { - // If the transport has been torn down then there should not be any - // pending network tasks to run. - RTC_DCHECK(network_tasks.empty()); - RTC_DCHECK(!sctp_mid_s_.has_value()) << "Should already be reset."; - return; + if (transport_controller_copy_ || !network_tasks.empty()) { + network_thread()->BlockingCall([&] { + RTC_DCHECK_RUN_ON(network_thread()); + for (auto& task : network_tasks) { + std::move(task)(); + task = nullptr; + } + if (network_thread_safety_->alive()) { + // port_allocator_ and transport_controller_ live on the network thread + // and must be destroyed there. + TeardownDataChannelTransport_n(RTCError::OK()); + port_allocator_->DiscardCandidatePool(); + transport_controller_.reset(); + port_allocator_.reset(); + network_thread_safety_->SetNotAlive(); + } + }); } - // port_allocator_ and transport_controller_ live on the network thread and - // should be destroyed there. - transport_controller_copy_ = nullptr; - network_thread()->BlockingCall([&] { - RTC_DCHECK_RUN_ON(network_thread()); - for (auto& task : network_tasks) { - std::move(task)(); - task = nullptr; - } - TeardownDataChannelTransport_n(RTCError::OK()); - port_allocator_->DiscardCandidatePool(); - transport_controller_.reset(); - port_allocator_.reset(); - network_thread_safety_->SetNotAlive(); - }); - sctp_mid_s_.reset(); - SetSctpTransportName(""); + + if (transport_controller_copy_) { + transport_controller_copy_ = nullptr; + sctp_mid_s_.reset(); + SetSctpTransportName(""); + } else { + RTC_DCHECK(!sctp_mid_s_); + RTC_DCHECK(sctp_transport_name_s_.empty()); + } } JsepTransportController* PeerConnection::InitializeTransportController_n(
diff --git a/pc/peer_connection.h b/pc/peer_connection.h index 8551c1e..c2be9f1 100644 --- a/pc/peer_connection.h +++ b/pc/peer_connection.h
@@ -452,10 +452,7 @@ bool ConfiguredForMedia() const; // Functions made public for testing. - void ReturnHistogramVeryQuicklyForTesting() { - RTC_DCHECK_RUN_ON(signaling_thread()); - return_histogram_very_quickly_ = true; - } + void RequestUsagePatternReportForTesting(); int FeedbackAccordingToRfc8888CountForTesting() const; int FeedbackAccordingToTransportCcCountForTesting() const; @@ -659,8 +656,7 @@ const bool is_unified_plan_; const bool dtls_enabled_; - bool return_histogram_very_quickly_ RTC_GUARDED_BY(signaling_thread()) = - false; + // Did the connectionState ever change to `connected`? // Used to gather metrics only the first such state change. bool was_ever_connected_ RTC_GUARDED_BY(signaling_thread()) = false;
diff --git a/pc/rtc_stats_collector.cc b/pc/rtc_stats_collector.cc index 25b5ffe..2cd0731 100644 --- a/pc/rtc_stats_collector.cc +++ b/pc/rtc_stats_collector.cc
@@ -21,7 +21,7 @@ #include <utility> #include <vector> -#include "absl/functional/bind_front.h" +#include "absl/functional/any_invocable.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "api/array_view.h" @@ -42,6 +42,7 @@ #include "api/stats/rtc_stats_collector_callback.h" #include "api/stats/rtc_stats_report.h" #include "api/stats/rtcstats_objects.h" +#include "api/task_queue/pending_task_safety_flag.h" #include "api/transport/enums.h" #include "api/units/time_delta.h" #include "api/units/timestamp.h" @@ -75,8 +76,8 @@ #include "rtc_base/ssl_certificate.h" #include "rtc_base/ssl_stream_adapter.h" #include "rtc_base/strings/string_builder.h" -#include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread.h" +#include "rtc_base/thread_annotations.h" #include "rtc_base/time_utils.h" #include "rtc_base/trace_event.h" @@ -1216,7 +1217,13 @@ network_report_event_(true /* manual_reset */, true /* initially_signaled */), cache_timestamp_us_(0), - cache_lifetime_us_(cache_lifetime_us) { + cache_lifetime_us_(cache_lifetime_us), + signaling_safety_( + PendingTaskSafetyFlag::CreateAttachedToTaskQueue(/*alive=*/true, + signaling_thread_)), + network_safety_( + PendingTaskSafetyFlag::CreateAttachedToTaskQueue(/*alive=*/true, + network_thread_)) { RTC_DCHECK(pc_); RTC_DCHECK(signaling_thread_); RTC_DCHECK(worker_thread_); @@ -1224,9 +1231,7 @@ RTC_DCHECK_GE(cache_lifetime_us_, 0); } -RTCStatsCollector::~RTCStatsCollector() { - RTC_DCHECK_EQ(num_pending_partial_reports_, 0); -} +RTCStatsCollector::~RTCStatsCollector() = default; void RTCStatsCollector::GetStatsReport( scoped_refptr<RTCStatsCollectorCallback> callback) { @@ -1257,15 +1262,22 @@ // We have a fresh cached report to deliver. Deliver asynchronously, since // the caller may not be expecting a synchronous callback, and it avoids // reentrancy problems. - signaling_thread_->PostTask( - absl::bind_front(&RTCStatsCollector::DeliverCachedReport, - scoped_refptr<RTCStatsCollector>(this), cached_report_, - std::move(requests_))); + signaling_thread_->PostTask(SafeTask( + signaling_safety_, [this, report = cached_report_, + requests = std::move(requests_)]() mutable { + DeliverCachedReport(std::move(report), std::move(requests)); + })); } else if (!num_pending_partial_reports_) { // Only start gathering stats if we're not already gathering stats. In the // case of already gathering stats, `callback_` will be invoked when there // are no more pending partial reports. + // Initialize common variables for the stats gather operation. + // As a future improvement, these could be owned by a dedicated stats + // gathering object that is used across the async steps. This would include + // moving variables such as partial_report_, network_report_, + // transceiver_stats_infos_, etc to that object rather than keep it as + // unguarded member variables. Timestamp timestamp = stats_timestamp_with_environment_clock_ ? @@ -1276,34 +1288,51 @@ // 1970, UTC), in microseconds. The system clock could be modified // and is not necessarily monotonically increasing. Timestamp::Micros(TimeUTCMicros()); - num_pending_partial_reports_ = 2; partial_report_timestamp_us_ = cache_now_us; + network_report_event_.Reset(); // Prepare `transceiver_stats_infos_` and `call_stats_` for use in // `ProducePartialResultsOnNetworkThread` and // `ProducePartialResultsOnSignalingThread`. PrepareTransceiverStatsInfosAndCallStats_s_w_n(); - // Don't touch `network_report_` on the signaling thread until - // ProducePartialResultsOnNetworkThread() has signaled the - // `network_report_event_`. - network_report_event_.Reset(); - scoped_refptr<RTCStatsCollector> collector(this); - network_thread_->PostTask([collector, - sctp_transport_name = pc_->sctp_transport_name(), - timestamp]() mutable { - collector->ProducePartialResultsOnNetworkThread( - timestamp, std::move(sctp_transport_name)); - }); + + // Create the initial `partial_report_` for the gathering operation. ProducePartialResultsOnSignalingThread(timestamp); + + std::set<std::string> transport_names; + auto sctp_transport_name = pc_->sctp_transport_name(); + if (sctp_transport_name) { + transport_names.emplace(std::move(*sctp_transport_name)); + } + + for (const auto& info : transceiver_stats_infos_) { + if (info.transport_name) + transport_names.insert(*info.transport_name); + } + + std::vector<RtpTransceiverStatsInfo>* cheating = &transceiver_stats_infos_; + network_thread_->PostTask(SafeTask( + network_safety_, + [this, transport_names = std::move(transport_names), timestamp, + signaling_flag = signaling_safety_, cheating = cheating]() mutable { + ProducePartialResultsOnNetworkThread( + std::move(signaling_flag), timestamp, std::move(transport_names), + *cheating); + })); } } void RTCStatsCollector::ClearCachedStatsReport() { RTC_DCHECK_RUN_ON(signaling_thread_); cached_report_ = nullptr; - MutexLock lock(&cached_certificates_mutex_); - cached_certificates_by_transport_.clear(); + // If we're not shutting down, clear the cache on the network thread. + if (signaling_safety_->alive()) { + network_thread_->PostTask(SafeTask(network_safety_, [this]() { + RTC_DCHECK_RUN_ON(network_thread_); + cached_certificates_by_transport_.clear(); + })); + } } void RTCStatsCollector::WaitForPendingRequest() { @@ -1313,6 +1342,13 @@ MergeNetworkReport_s(); } +absl::AnyInvocable<void() &&> +RTCStatsCollector::CancelPendingRequestAndGetShutdownTask() { + RTC_DCHECK_RUN_ON(signaling_thread_); + signaling_safety_->SetNotAlive(); + return [flag = network_safety_]() { flag->SetNotAlive(); }; +} + void RTCStatsCollector::ProducePartialResultsOnSignalingThread( Timestamp timestamp) { RTC_DCHECK_RUN_ON(signaling_thread_); @@ -1322,10 +1358,10 @@ ProducePartialResultsOnSignalingThreadImpl(timestamp, partial_report_.get()); - // ProducePartialResultsOnSignalingThread() is running synchronously on the - // signaling thread, so it is always the first partial result delivered on the + // ProducePartialResultsOnSignalingThread() runs synchronously on the + // signaling thread. So it is always the first partial result delivered on the // signaling thread. The request is not complete until MergeNetworkReport_s() - // happens; we don't have to do anything here. + // runs. We don't have to do anything here. RTC_DCHECK_GT(num_pending_partial_reports_, 1); --num_pending_partial_reports_; } @@ -1334,16 +1370,16 @@ Timestamp timestamp, RTCStatsReport* partial_report) { RTC_DCHECK_RUN_ON(signaling_thread_); - Thread::ScopedDisallowBlockingCalls no_blocking_calls; - ProduceMediaSourceStats_s(timestamp, partial_report); ProducePeerConnectionStats_s(timestamp, partial_report); ProduceAudioPlayoutStats_s(timestamp, partial_report); } void RTCStatsCollector::ProducePartialResultsOnNetworkThread( + scoped_refptr<PendingTaskSafetyFlag> signaling_safety, Timestamp timestamp, - std::optional<std::string> sctp_transport_name) { + std::set<std::string> transport_names, + std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos) { TRACE_EVENT0("webrtc", "RTCStatsCollector::ProducePartialResultsOnNetworkThread"); RTC_DCHECK_RUN_ON(network_thread_); @@ -1355,37 +1391,27 @@ ProduceDataChannelStats_n(timestamp, network_report_.get()); - std::set<std::string> transport_names; - if (sctp_transport_name) { - transport_names.emplace(std::move(*sctp_transport_name)); - } - - for (const auto& info : transceiver_stats_infos_) { - if (info.transport_name) - transport_names.insert(*info.transport_name); - } - std::map<std::string, TransportStats> transport_stats_by_name = pc_->GetTransportStatsByNames(transport_names); std::map<std::string, CertificateStatsPair> transport_cert_stats = PrepareTransportCertificateStats_n(transport_stats_by_name); - ProducePartialResultsOnNetworkThreadImpl(timestamp, transport_stats_by_name, - transport_cert_stats, - network_report_.get()); + ProducePartialResultsOnNetworkThreadImpl( + timestamp, transport_stats_by_name, transport_cert_stats, + transceiver_stats_infos, network_report_.get()); // Signal that it is now safe to touch `network_report_` on the signaling // thread, and post a task to merge it into the final results. network_report_event_.Set(); - scoped_refptr<RTCStatsCollector> collector(this); - signaling_thread_->PostTask( - [collector] { collector->MergeNetworkReport_s(); }); + signaling_thread_->PostTask(SafeTask(std::move(signaling_safety), + [this] { MergeNetworkReport_s(); })); } void RTCStatsCollector::ProducePartialResultsOnNetworkThreadImpl( Timestamp timestamp, const std::map<std::string, TransportStats>& transport_stats_by_name, const std::map<std::string, CertificateStatsPair>& transport_cert_stats, + const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos, RTCStatsReport* partial_report) { RTC_DCHECK_RUN_ON(network_thread_); Thread::ScopedDisallowBlockingCalls no_blocking_calls; @@ -1395,11 +1421,12 @@ call_stats_, partial_report); ProduceTransportStats_n(timestamp, transport_stats_by_name, transport_cert_stats, call_stats_, partial_report); - ProduceRTPStreamStats_n(timestamp, transceiver_stats_infos_, partial_report); + ProduceRTPStreamStats_n(timestamp, transceiver_stats_infos, partial_report); } void RTCStatsCollector::MergeNetworkReport_s() { RTC_DCHECK_RUN_ON(signaling_thread_); + // The `network_report_event_` must be signaled for it to be safe to touch // `network_report_`. This is normally not blocking, but if // WaitForPendingRequest() is called while a request is pending, we might have @@ -2115,15 +2142,12 @@ Thread::ScopedDisallowBlockingCalls no_blocking_calls; std::map<std::string, CertificateStatsPair> transport_cert_stats; - { - MutexLock lock(&cached_certificates_mutex_); - // Copy the certificate info from the cache, avoiding expensive - // webrtc::SSLCertChain::GetStats() calls. - for (const auto& pair : cached_certificates_by_transport_) { - transport_cert_stats.insert( - std::make_pair(pair.first, pair.second.Copy())); - } + // Copy the certificate info from the cache, avoiding expensive + // webrtc::SSLCertChain::GetStats() calls. + for (const auto& pair : cached_certificates_by_transport_) { + transport_cert_stats.insert(std::make_pair(pair.first, pair.second.Copy())); } + if (transport_cert_stats.empty()) { // Collect certificate info. for (const auto& entry : transport_stats_by_name) { @@ -2145,7 +2169,6 @@ std::make_pair(transport_name, std::move(certificate_stats_pair))); } // Copy the result into the certificate cache for future reference. - MutexLock lock(&cached_certificates_mutex_); for (const auto& pair : transport_cert_stats) { cached_certificates_by_transport_.insert( std::make_pair(pair.first, pair.second.Copy())); @@ -2181,52 +2204,65 @@ } // TODO(tommi): See if we can avoid synchronously blocking the signaling - // thread while we do this (or avoid the BlockingCall at all). - network_thread_->BlockingCall([&] { - Thread::ScopedDisallowBlockingCalls no_blocking_calls; + // thread while we do this (or avoid the BlockingCall at all). Note also that + // where PrepareTransceiverStatsInfosAndCallStats_s_w_n is called from, + // there's a PostTask() to the network thread to call + // ProducePartialResultsOnNetworkThread(). See if this block should be merged + // with that. + // Currently using RTC_NO_THREAD_SAFETY_ANALYSIS here and below due to use of + // transceiver_stats_infos_. Remove this and pass transceiver_stats_infos_ in + // an object that's used to gather the data from start to finish. + network_thread_->BlockingCall( + [&, &transceiver_stats_infos = transceiver_stats_infos_]() + RTC_NO_THREAD_SAFETY_ANALYSIS mutable { + Thread::ScopedDisallowBlockingCalls no_blocking_calls; - for (auto& stats : transceiver_stats_infos_) { - ChannelInterface* channel = stats.transceiver->channel(); - if (!channel) { - continue; - } + for (auto& stats : transceiver_stats_infos) { + ChannelInterface* channel = stats.transceiver->channel(); + if (!channel) { + continue; + } - stats.transport_name = std::string(channel->transport_name()); + stats.transport_name = std::string(channel->transport_name()); - if (stats.media_type == MediaType::AUDIO) { - auto voice_send_channel = channel->voice_media_send_channel(); - RTC_DCHECK(voice_send_stats.find(voice_send_channel) == - voice_send_stats.end()); - voice_send_stats.insert( - std::make_pair(voice_send_channel, VoiceMediaSendInfo())); + if (stats.media_type == MediaType::AUDIO) { + auto voice_send_channel = channel->voice_media_send_channel(); + RTC_DCHECK(voice_send_stats.find(voice_send_channel) == + voice_send_stats.end()); + voice_send_stats.insert( + std::make_pair(voice_send_channel, VoiceMediaSendInfo())); - auto voice_receive_channel = channel->voice_media_receive_channel(); - RTC_DCHECK(voice_receive_stats.find(voice_receive_channel) == - voice_receive_stats.end()); - voice_receive_stats.insert( - std::make_pair(voice_receive_channel, VoiceMediaReceiveInfo())); - } else if (stats.media_type == MediaType::VIDEO) { - auto video_send_channel = channel->video_media_send_channel(); - RTC_DCHECK(video_send_stats.find(video_send_channel) == - video_send_stats.end()); - video_send_stats.insert( - std::make_pair(video_send_channel, VideoMediaSendInfo())); - auto video_receive_channel = channel->video_media_receive_channel(); - RTC_DCHECK(video_receive_stats.find(video_receive_channel) == - video_receive_stats.end()); - video_receive_stats.insert( - std::make_pair(video_receive_channel, VideoMediaReceiveInfo())); - } else { - RTC_DCHECK_NOTREACHED(); - } - } - }); + auto voice_receive_channel = + channel->voice_media_receive_channel(); + RTC_DCHECK(voice_receive_stats.find(voice_receive_channel) == + voice_receive_stats.end()); + voice_receive_stats.insert(std::make_pair( + voice_receive_channel, VoiceMediaReceiveInfo())); + } else if (stats.media_type == MediaType::VIDEO) { + auto video_send_channel = channel->video_media_send_channel(); + RTC_DCHECK(video_send_stats.find(video_send_channel) == + video_send_stats.end()); + video_send_stats.insert( + std::make_pair(video_send_channel, VideoMediaSendInfo())); + auto video_receive_channel = + channel->video_media_receive_channel(); + RTC_DCHECK(video_receive_stats.find(video_receive_channel) == + video_receive_stats.end()); + video_receive_stats.insert(std::make_pair( + video_receive_channel, VideoMediaReceiveInfo())); + } else { + RTC_DCHECK_NOTREACHED(); + } + } + }); // We jump to the worker thread and call GetStats() on each media channel as // well as GetCallStats(). At the same time we construct the // TrackMediaInfoMaps, which also needs info from the worker thread. This // minimizes the number of thread jumps. - worker_thread_->BlockingCall([&] { + // Currently using RTC_NO_THREAD_SAFETY_ANALYSIS here too due to use of + // transceiver_stats_infos_ in a blocking call. + worker_thread_->BlockingCall([&]() RTC_NO_THREAD_SAFETY_ANALYSIS mutable { Thread::ScopedDisallowBlockingCalls no_blocking_calls; for (auto& pair : voice_send_stats) {
diff --git a/pc/rtc_stats_collector.h b/pc/rtc_stats_collector.h index 8217c2c..75ec0b7 100644 --- a/pc/rtc_stats_collector.h +++ b/pc/rtc_stats_collector.h
@@ -17,9 +17,11 @@ #include <map> #include <memory> #include <optional> +#include <set> #include <string> #include <vector> +#include "absl/functional/any_invocable.h" #include "api/audio/audio_device.h" #include "api/data_channel_interface.h" #include "api/environment/environment.h" @@ -29,6 +31,7 @@ #include "api/scoped_refptr.h" #include "api/stats/rtc_stats_collector_callback.h" #include "api/stats/rtc_stats_report.h" +#include "api/task_queue/pending_task_safety_flag.h" #include "api/task_queue/task_queue_base.h" #include "api/units/timestamp.h" #include "call/call.h" @@ -42,7 +45,6 @@ #include "rtc_base/containers/flat_set.h" #include "rtc_base/event.h" #include "rtc_base/ssl_certificate.h" -#include "rtc_base/synchronization/mutex.h" #include "rtc_base/thread.h" #include "rtc_base/thread_annotations.h" #include "rtc_base/time_utils.h" @@ -52,6 +54,22 @@ class RtpSenderInternal; class RtpReceiverInternal; +// Structure for tracking stats about each RtpTransceiver managed by the +// PeerConnection. This can either by a Plan B style or Unified Plan style +// transceiver (i.e., can have 0 or many senders and receivers). +// Some fields are copied from the RtpTransceiver/BaseChannel object so that +// they can be accessed safely on threads other than the signaling thread. +// If a BaseChannel is not available (e.g., if signaling has not started), +// then `mid` and `transport_name` will be null. +struct RtpTransceiverStatsInfo { + const scoped_refptr<RtpTransceiver> transceiver; + const MediaType media_type; + const std::optional<std::string> mid; + std::optional<std::string> transport_name; + TrackMediaInfoMap track_media_info_map; + const std::optional<RtpTransceiverDirection> current_direction; +}; + // All public methods of the collector are to be called on the signaling thread. // Stats are gathered on the signaling, worker and network threads // asynchronously. The callback is invoked on the signaling thread. Resulting @@ -91,6 +109,11 @@ // completed. Must be called on the signaling thread. void WaitForPendingRequest(); + // Cancels pending stats gathering operations and prepares for shutdown. + // This method returns a task that the caller needs to make sure is executed + // on the network thread before the RTCStatsCollector instance is deleted. + absl::AnyInvocable<void() &&> CancelPendingRequestAndGetShutdownTask(); + // Called by the PeerConnection instance when data channel states change. void OnSctpDataChannelStateChanged(int channel_id, DataChannelInterface::DataState state); @@ -112,10 +135,12 @@ virtual void ProducePartialResultsOnSignalingThreadImpl( Timestamp timestamp, RTCStatsReport* partial_report); + virtual void ProducePartialResultsOnNetworkThreadImpl( Timestamp timestamp, const std::map<std::string, TransportStats>& transport_stats_by_name, const std::map<std::string, CertificateStatsPair>& transport_cert_stats, + const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos, RTCStatsReport* partial_report); private: @@ -161,22 +186,6 @@ void GetStatsReportInternal(RequestInfo request); - // Structure for tracking stats about each RtpTransceiver managed by the - // PeerConnection. This can either by a Plan B style or Unified Plan style - // transceiver (i.e., can have 0 or many senders and receivers). - // Some fields are copied from the RtpTransceiver/BaseChannel object so that - // they can be accessed safely on threads other than the signaling thread. - // If a BaseChannel is not available (e.g., if signaling has not started), - // then `mid` and `transport_name` will be null. - struct RtpTransceiverStatsInfo { - scoped_refptr<RtpTransceiver> transceiver; - webrtc::MediaType media_type; - std::optional<std::string> mid; - std::optional<std::string> transport_name; - TrackMediaInfoMap track_media_info_map; - std::optional<RtpTransceiverDirection> current_direction; - }; - void DeliverCachedReport(scoped_refptr<const RTCStatsReport> cached_report, std::vector<RequestInfo> requests); @@ -237,8 +246,10 @@ // Stats gathering on a particular thread. void ProducePartialResultsOnSignalingThread(Timestamp timestamp); void ProducePartialResultsOnNetworkThread( + scoped_refptr<PendingTaskSafetyFlag> signaling_safety, Timestamp timestamp, - std::optional<std::string> sctp_transport_name); + std::set<std::string> transport_names, + std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos); // Merges `network_report_` into `partial_report_` and completes the request. // This is a NO-OP if `network_report_` is null. void MergeNetworkReport_s(); @@ -263,7 +274,7 @@ // merged into this report. It is only touched on the signaling thread. Once // all partial reports are merged this is the result of a request. scoped_refptr<RTCStatsReport> partial_report_; - std::vector<RequestInfo> requests_; + std::vector<RequestInfo> requests_ RTC_GUARDED_BY(signaling_thread_); // Holds the result of ProducePartialResultsOnNetworkThread(). It is merged // into `partial_report_` on the signaling thread and then nulled by // MergeNetworkReport_s(). Thread-safety is ensured by using @@ -278,19 +289,19 @@ // Cleared and set in `PrepareTransceiverStatsInfosAndCallStats_s_w_n`, // starting out on the signaling thread, then network. Later read on the // network and signaling threads as part of collecting stats and finally - // reset when the work is done. Initially this variable was added and not - // passed around as an arguments to avoid copies. This is thread safe due to - // how operations are sequenced and we don't start the stats collection - // sequence if one is in progress. As a future improvement though, we could - // now get rid of the variable and keep the data scoped within a stats - // collection sequence. - std::vector<RtpTransceiverStatsInfo> transceiver_stats_infos_; + // reset on the signaling thread when the work is done. + // Initially this variable was added and not passed around as an arguments to + // avoid copies. This is thread safe due to how operations are sequenced, + // sometimes blocking, and we don't start the stats collection sequence if one + // is in progress. As a future improvement though, we could now get rid of the + // variable and keep the data scoped within a stats collection sequence. + std::vector<RtpTransceiverStatsInfo> transceiver_stats_infos_ + RTC_GUARDED_BY(signaling_thread_); // This cache avoids having to call webrtc::SSLCertChain::GetStats(), which // can relatively expensive. ClearCachedStatsReport() needs to be called on // negotiation to ensure the cache is not obsolete. - Mutex cached_certificates_mutex_; std::map<std::string, CertificateStatsPair> cached_certificates_by_transport_ - RTC_GUARDED_BY(cached_certificates_mutex_); + RTC_GUARDED_BY(network_thread_); Call::Stats call_stats_; @@ -302,7 +313,8 @@ // report is. int64_t cache_timestamp_us_; int64_t cache_lifetime_us_; - scoped_refptr<const RTCStatsReport> cached_report_; + scoped_refptr<const RTCStatsReport> cached_report_ + RTC_GUARDED_BY(signaling_thread_); // Data recorded and maintained by the stats collector during its lifetime. // Some stats are produced from this record instead of other components. @@ -321,6 +333,8 @@ flat_set<int> opened_data_channels; }; InternalRecord internal_record_; + const scoped_refptr<PendingTaskSafetyFlag> signaling_safety_; + const scoped_refptr<PendingTaskSafetyFlag> network_safety_; }; } // namespace webrtc
diff --git a/pc/rtc_stats_collector_unittest.cc b/pc/rtc_stats_collector_unittest.cc index a6615e3..378c2e5 100644 --- a/pc/rtc_stats_collector_unittest.cc +++ b/pc/rtc_stats_collector_unittest.cc
@@ -94,6 +94,7 @@ #include "rtc_base/time_utils.h" #include "test/gmock.h" #include "test/gtest.h" +#include "test/run_loop.h" #include "test/wait_until.h" using ::testing::_; @@ -187,6 +188,12 @@ return candidate; } +class MockStatsCollectorCallback : public RTCStatsCollectorCallback { + public: + MOCK_METHOD1(OnStatsDelivered, + void(const scoped_refptr<const RTCStatsReport>&)); +}; + class FakeAudioProcessor : public AudioProcessorInterface { public: FakeAudioProcessor() {} @@ -388,7 +395,7 @@ const Environment& env) : pc_(pc), stats_collector_( - RTCStatsCollector::Create(pc.get(), + RTCStatsCollector::Create(pc_.get(), env, 50 * kNumMicrosecsPerMillisec)) {} @@ -3829,7 +3836,7 @@ static scoped_refptr<FakeRTCStatsCollector> Create( PeerConnectionInternal* pc, const Environment& env, - int64_t cache_lifetime_us) { + int64_t cache_lifetime_us = 50 * kNumMicrosecsPerMillisec) { return scoped_refptr<FakeRTCStatsCollector>( new RefCountedObject<FakeRTCStatsCollector>(pc, env, cache_lifetime_us)); @@ -3904,6 +3911,7 @@ Timestamp timestamp, const std::map<std::string, TransportStats>& transport_stats_by_name, const std::map<std::string, CertificateStatsPair>& transport_cert_stats, + const std::vector<RtpTransceiverStatsInfo>& transceiver_stats_infos, RTCStatsReport* partial_report) override { EXPECT_TRUE(network_thread_->IsCurrent()); { @@ -3927,8 +3935,77 @@ int produced_on_network_thread_ = 0; }; +// Simple test that verifies that GetStatsReport() can be called and async +// results delivered on the same thread. +// This covers the following steps: +// * Request stats +// * Task posted from signaling to network thread. +// * Task posted from network thread back to signaling +// * Issue the `OnStatsDelivered` callback on the signaling thread. +TEST(RTCStatsCollectorSafetyTest, WaitPendingRequestGetsCallback) { + test::RunLoop loop; + auto pc = make_ref_counted<FakePeerConnectionForStats>(); + auto env = CreateEnvironment(); + RTCStatsCollectorWrapper wrapper(pc, env); + auto callback = make_ref_counted<MockStatsCollectorCallback>(); + EXPECT_CALL(*callback, OnStatsDelivered(_)).WillOnce([&] { loop.Quit(); }); + wrapper.stats_collector()->GetStatsReport(callback); + loop.Run(); +} + +// Similar to WaitPendingRequestGetsCallback except that we call +// CancelPendingRequestAndGetShutdownTask to make sure the callback tasks won't +// run (and callbacks not issued). This covers the following steps: +// * Request stats +// * Task posted from signaling to network thread. +// * Task posted from network thread back to signaling +// * The task for the signaling thread will be dropped, no callback. +TEST(RTCStatsCollectorSafetyTest, CancelPendingRequestPreventsCallback) { + test::RunLoop loop; + auto pc = make_ref_counted<FakePeerConnectionForStats>(); + RTCStatsCollectorWrapper wrapper(pc, CreateEnvironment()); + auto callback = make_ref_counted<MockStatsCollectorCallback>(); + EXPECT_CALL(*callback, OnStatsDelivered(_)).Times(0); + // At this point, cancellation has not been made, this posts a task to the + // network thread. + wrapper.stats_collector()->GetStatsReport(callback); + // Now cancel any ongoing stats gathering operations. This should have the + // effect that the gathering that is ongoing on the network thread, will queue + // up a task for the signaling thread, but that task will be dropped. + auto network_task = + wrapper.stats_collector()->CancelPendingRequestAndGetShutdownTask(); + loop.Flush(); + // Run the network cleanup task for posterity. + std::move(network_task)(); +} + +// This covers the following steps: +// * Mark the network thread as not alive () +// * Request stats +// * Task posted from signaling to network thread. +// * The task for the network thread will be dropped, no further work done. +TEST(RTCStatsCollectorSafetyTest, NetworkThreadSafetyPreventsCallback) { + test::RunLoop loop; + auto pc = make_ref_counted<FakePeerConnectionForStats>(); + RTCStatsCollectorWrapper wrapper(pc, CreateEnvironment()); + auto callback = make_ref_counted<MockStatsCollectorCallback>(); + EXPECT_CALL(*callback, OnStatsDelivered(_)).Times(0); + // Start by canceling any ongoing tasks. There aren't actually any ongoing + // tasks, but this gives us the network cleanup task. + auto network_task = + wrapper.stats_collector()->CancelPendingRequestAndGetShutdownTask(); + // Clean up the state on the network thread. This will have the effect of + // dropping any tasks targeting the network thread. + std::move(network_task)(); + // Now, attempt to get a stats report. This will try to post a task to the + // network thread, which will be dropped. + wrapper.stats_collector()->GetStatsReport(callback); + + loop.Flush(); +} + TEST(RTCStatsCollectorTestWithFakeCollector, ThreadUsageAndResultsMerging) { - AutoThread main_thread_; + test::RunLoop loop; auto pc = make_ref_counted<FakePeerConnectionForStats>(); scoped_refptr<FakeRTCStatsCollector> stats_collector( FakeRTCStatsCollector::Create(pc.get(), CreateEnvironment(),
diff --git a/pc/webrtc_session_description_factory.cc b/pc/webrtc_session_description_factory.cc index 7f1c015..0725660 100644 --- a/pc/webrtc_session_description_factory.cc +++ b/pc/webrtc_session_description_factory.cc
@@ -178,6 +178,7 @@ WebRtcSessionDescriptionFactory::~WebRtcSessionDescriptionFactory() { RTC_DCHECK_RUN_ON(signaling_thread_); + RTC_DCHECK_DISALLOW_THREAD_BLOCKING_CALLS(); // Fail any requests that were asked for before identity generation completed. FailPendingRequests(kFailedDueToSessionShutdown);