Fix race with SctpTransport destruction and usrsctp timer thread.
The race occurs if the transport is being destroyed at the same time as
a callback occurs on the usrsctp timer thread (for example, for a
retransmission). Fixed by slightly extending the scope of mutex
acquisition to include posting a task to the network thread, where it's
safe to do further work.
Bug: chromium:1162424
Change-Id: Ia25c96fa51cd4ba2d8690ba03de8af9e9f1605ea
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/202560
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Taylor <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33048}
diff --git a/media/sctp/sctp_transport.cc b/media/sctp/sctp_transport.cc
index 6bb4a8f..4a99144 100644
--- a/media/sctp/sctp_transport.cc
+++ b/media/sctp/sctp_transport.cc
@@ -27,6 +27,7 @@
#include <stdio.h>
#include <usrsctp.h>
+#include <functional>
#include <memory>
#include <unordered_map>
@@ -81,58 +82,8 @@
PPID_TEXT_LAST = 51
};
-// Maps SCTP transport ID to SctpTransport object, necessary in send threshold
-// callback and outgoing packet callback.
-// TODO(crbug.com/1076703): Remove once the underlying problem is fixed or
-// workaround is provided in usrsctp.
-class SctpTransportMap {
- public:
- SctpTransportMap() = default;
-
- // Assigns a new unused ID to the following transport.
- uintptr_t Register(cricket::SctpTransport* transport) {
- webrtc::MutexLock lock(&lock_);
- // usrsctp_connect fails with a value of 0...
- if (next_id_ == 0) {
- ++next_id_;
- }
- // In case we've wrapped around and need to find an empty spot from a
- // removed transport. Assumes we'll never be full.
- while (map_.find(next_id_) != map_.end()) {
- ++next_id_;
- if (next_id_ == 0) {
- ++next_id_;
- }
- };
- map_[next_id_] = transport;
- return next_id_++;
- }
-
- // Returns true if found.
- bool Deregister(uintptr_t id) {
- webrtc::MutexLock lock(&lock_);
- return map_.erase(id) > 0;
- }
-
- cricket::SctpTransport* Retrieve(uintptr_t id) const {
- webrtc::MutexLock lock(&lock_);
- auto it = map_.find(id);
- if (it == map_.end()) {
- return nullptr;
- }
- return it->second;
- }
-
- private:
- mutable webrtc::Mutex lock_;
-
- uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0;
- std::unordered_map<uintptr_t, cricket::SctpTransport*> map_
- RTC_GUARDED_BY(lock_);
-};
-
// Should only be modified by UsrSctpWrapper.
-ABSL_CONST_INIT SctpTransportMap* g_transport_map_ = nullptr;
+ABSL_CONST_INIT cricket::SctpTransportMap* g_transport_map_ = nullptr;
// Helper for logging SCTP messages.
#if defined(__GNUC__)
@@ -258,6 +209,82 @@
namespace cricket {
+// Maps SCTP transport ID to SctpTransport object, necessary in send threshold
+// callback and outgoing packet callback. It also provides a facility to
+// safely post a task to an SctpTransport's network thread from another thread.
+class SctpTransportMap {
+ public:
+ SctpTransportMap() = default;
+
+ // Assigns a new unused ID to the following transport.
+ uintptr_t Register(cricket::SctpTransport* transport) {
+ webrtc::MutexLock lock(&lock_);
+ // usrsctp_connect fails with a value of 0...
+ if (next_id_ == 0) {
+ ++next_id_;
+ }
+ // In case we've wrapped around and need to find an empty spot from a
+ // removed transport. Assumes we'll never be full.
+ while (map_.find(next_id_) != map_.end()) {
+ ++next_id_;
+ if (next_id_ == 0) {
+ ++next_id_;
+ }
+ };
+ map_[next_id_] = transport;
+ return next_id_++;
+ }
+
+ // Returns true if found.
+ bool Deregister(uintptr_t id) {
+ webrtc::MutexLock lock(&lock_);
+ return map_.erase(id) > 0;
+ }
+
+ // Must be called on the transport's network thread to protect against
+ // simultaneous deletion/deregistration of the transport; if that's not
+ // guaranteed, use ExecuteWithLock.
+ SctpTransport* Retrieve(uintptr_t id) const {
+ webrtc::MutexLock lock(&lock_);
+ SctpTransport* transport = RetrieveWhileHoldingLock(id);
+ if (transport) {
+ RTC_DCHECK_RUN_ON(transport->network_thread());
+ }
+ return transport;
+ }
+
+ // Posts |action| to the network thread of the transport identified by |id|
+ // and returns true if found, all while holding a lock to protect against the
+ // transport being simultaneously deleted/deregistered, or returns false if
+ // not found.
+ bool PostToTransportThread(uintptr_t id,
+ std::function<void(SctpTransport*)> action) const {
+ webrtc::MutexLock lock(&lock_);
+ SctpTransport* transport = RetrieveWhileHoldingLock(id);
+ if (!transport) {
+ return false;
+ }
+ transport->network_thread_->PostTask(ToQueuedTask(
+ transport->task_safety_, [transport, action]() { action(transport); }));
+ return true;
+ }
+
+ private:
+ SctpTransport* RetrieveWhileHoldingLock(uintptr_t id) const
+ RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_) {
+ auto it = map_.find(id);
+ if (it == map_.end()) {
+ return nullptr;
+ }
+ return it->second;
+ }
+
+ mutable webrtc::Mutex lock_;
+
+ uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0;
+ std::unordered_map<uintptr_t, SctpTransport*> map_ RTC_GUARDED_BY(lock_);
+};
+
// Handles global init/deinit, and mapping from usrsctp callbacks to
// SctpTransport calls.
class SctpTransport::UsrSctpWrapper {
@@ -370,14 +397,6 @@
<< "OnSctpOutboundPacket called after usrsctp uninitialized?";
return EINVAL;
}
- SctpTransport* transport =
- g_transport_map_->Retrieve(reinterpret_cast<uintptr_t>(addr));
- if (!transport) {
- RTC_LOG(LS_ERROR)
- << "OnSctpOutboundPacket: Failed to get transport for socket ID "
- << addr;
- return EINVAL;
- }
RTC_LOG(LS_VERBOSE) << "global OnSctpOutboundPacket():"
"addr: "
<< addr << "; length: " << length
@@ -385,13 +404,23 @@
<< "; set_df: " << rtc::ToHex(set_df);
VerboseLogPacket(data, length, SCTP_DUMP_OUTBOUND);
+
// Note: We have to copy the data; the caller will delete it.
rtc::CopyOnWriteBuffer buf(reinterpret_cast<uint8_t*>(data), length);
- transport->network_thread_->PostTask(ToQueuedTask(
- transport->task_safety_, [transport, buf = std::move(buf)]() {
+ // PostsToTransportThread protects against the transport being
+ // simultaneously deregistered/deleted, since this callback may come from
+ // the SCTP timer thread and thus race with the network thread.
+ bool found = g_transport_map_->PostToTransportThread(
+ reinterpret_cast<uintptr_t>(addr), [buf](SctpTransport* transport) {
transport->OnPacketFromSctpToNetwork(buf);
- }));
+ });
+ if (!found) {
+ RTC_LOG(LS_ERROR)
+ << "OnSctpOutboundPacket: Failed to get transport for socket ID "
+ << addr;
+ return EINVAL;
+ }
return 0;
}
diff --git a/media/sctp/sctp_transport.h b/media/sctp/sctp_transport.h
index 38a89fc..bd166ef 100644
--- a/media/sctp/sctp_transport.h
+++ b/media/sctp/sctp_transport.h
@@ -281,6 +281,8 @@
// various callbacks.
uintptr_t id_ = 0;
+ friend class SctpTransportMap;
+
RTC_DISALLOW_COPY_AND_ASSIGN(SctpTransport);
};
@@ -299,6 +301,8 @@
rtc::Thread* network_thread_;
};
+class SctpTransportMap;
+
} // namespace cricket
#endif // MEDIA_SCTP_SCTP_TRANSPORT_H_