Fix race with SctpTransport destruction and usrsctp timer thread.

The race occurs if the transport is being destroyed at the same time as
a callback occurs on the usrsctp timer thread (for example, for a
retransmission). Fixed by slightly extending the scope of mutex
acquisition to include posting a task to the network thread, where it's
safe to do further work.

Bug: chromium:1162424
Change-Id: Ia25c96fa51cd4ba2d8690ba03de8af9e9f1605ea
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/202560
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Taylor <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33048}
diff --git a/media/sctp/sctp_transport.cc b/media/sctp/sctp_transport.cc
index 6bb4a8f..4a99144 100644
--- a/media/sctp/sctp_transport.cc
+++ b/media/sctp/sctp_transport.cc
@@ -27,6 +27,7 @@
 #include <stdio.h>
 #include <usrsctp.h>
 
+#include <functional>
 #include <memory>
 #include <unordered_map>
 
@@ -81,58 +82,8 @@
   PPID_TEXT_LAST = 51
 };
 
-// Maps SCTP transport ID to SctpTransport object, necessary in send threshold
-// callback and outgoing packet callback.
-// TODO(crbug.com/1076703): Remove once the underlying problem is fixed or
-// workaround is provided in usrsctp.
-class SctpTransportMap {
- public:
-  SctpTransportMap() = default;
-
-  // Assigns a new unused ID to the following transport.
-  uintptr_t Register(cricket::SctpTransport* transport) {
-    webrtc::MutexLock lock(&lock_);
-    // usrsctp_connect fails with a value of 0...
-    if (next_id_ == 0) {
-      ++next_id_;
-    }
-    // In case we've wrapped around and need to find an empty spot from a
-    // removed transport. Assumes we'll never be full.
-    while (map_.find(next_id_) != map_.end()) {
-      ++next_id_;
-      if (next_id_ == 0) {
-        ++next_id_;
-      }
-    };
-    map_[next_id_] = transport;
-    return next_id_++;
-  }
-
-  // Returns true if found.
-  bool Deregister(uintptr_t id) {
-    webrtc::MutexLock lock(&lock_);
-    return map_.erase(id) > 0;
-  }
-
-  cricket::SctpTransport* Retrieve(uintptr_t id) const {
-    webrtc::MutexLock lock(&lock_);
-    auto it = map_.find(id);
-    if (it == map_.end()) {
-      return nullptr;
-    }
-    return it->second;
-  }
-
- private:
-  mutable webrtc::Mutex lock_;
-
-  uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0;
-  std::unordered_map<uintptr_t, cricket::SctpTransport*> map_
-      RTC_GUARDED_BY(lock_);
-};
-
 // Should only be modified by UsrSctpWrapper.
-ABSL_CONST_INIT SctpTransportMap* g_transport_map_ = nullptr;
+ABSL_CONST_INIT cricket::SctpTransportMap* g_transport_map_ = nullptr;
 
 // Helper for logging SCTP messages.
 #if defined(__GNUC__)
@@ -258,6 +209,82 @@
 
 namespace cricket {
 
+// Maps SCTP transport ID to SctpTransport object, necessary in send threshold
+// callback and outgoing packet callback. It also provides a facility to
+// safely post a task to an SctpTransport's network thread from another thread.
+class SctpTransportMap {
+ public:
+  SctpTransportMap() = default;
+
+  // Assigns a new unused ID to the following transport.
+  uintptr_t Register(cricket::SctpTransport* transport) {
+    webrtc::MutexLock lock(&lock_);
+    // usrsctp_connect fails with a value of 0...
+    if (next_id_ == 0) {
+      ++next_id_;
+    }
+    // In case we've wrapped around and need to find an empty spot from a
+    // removed transport. Assumes we'll never be full.
+    while (map_.find(next_id_) != map_.end()) {
+      ++next_id_;
+      if (next_id_ == 0) {
+        ++next_id_;
+      }
+    };
+    map_[next_id_] = transport;
+    return next_id_++;
+  }
+
+  // Returns true if found.
+  bool Deregister(uintptr_t id) {
+    webrtc::MutexLock lock(&lock_);
+    return map_.erase(id) > 0;
+  }
+
+  // Must be called on the transport's network thread to protect against
+  // simultaneous deletion/deregistration of the transport; if that's not
+  // guaranteed, use ExecuteWithLock.
+  SctpTransport* Retrieve(uintptr_t id) const {
+    webrtc::MutexLock lock(&lock_);
+    SctpTransport* transport = RetrieveWhileHoldingLock(id);
+    if (transport) {
+      RTC_DCHECK_RUN_ON(transport->network_thread());
+    }
+    return transport;
+  }
+
+  // Posts |action| to the network thread of the transport identified by |id|
+  // and returns true if found, all while holding a lock to protect against the
+  // transport being simultaneously deleted/deregistered, or returns false if
+  // not found.
+  bool PostToTransportThread(uintptr_t id,
+                             std::function<void(SctpTransport*)> action) const {
+    webrtc::MutexLock lock(&lock_);
+    SctpTransport* transport = RetrieveWhileHoldingLock(id);
+    if (!transport) {
+      return false;
+    }
+    transport->network_thread_->PostTask(ToQueuedTask(
+        transport->task_safety_, [transport, action]() { action(transport); }));
+    return true;
+  }
+
+ private:
+  SctpTransport* RetrieveWhileHoldingLock(uintptr_t id) const
+      RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_) {
+    auto it = map_.find(id);
+    if (it == map_.end()) {
+      return nullptr;
+    }
+    return it->second;
+  }
+
+  mutable webrtc::Mutex lock_;
+
+  uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0;
+  std::unordered_map<uintptr_t, SctpTransport*> map_ RTC_GUARDED_BY(lock_);
+};
+
 // Handles global init/deinit, and mapping from usrsctp callbacks to
 // SctpTransport calls.
 class SctpTransport::UsrSctpWrapper {
@@ -370,14 +397,6 @@
           << "OnSctpOutboundPacket called after usrsctp uninitialized?";
       return EINVAL;
     }
-    SctpTransport* transport =
-        g_transport_map_->Retrieve(reinterpret_cast<uintptr_t>(addr));
-    if (!transport) {
-      RTC_LOG(LS_ERROR)
-          << "OnSctpOutboundPacket: Failed to get transport for socket ID "
-          << addr;
-      return EINVAL;
-    }
     RTC_LOG(LS_VERBOSE) << "global OnSctpOutboundPacket():"
                            "addr: "
                         << addr << "; length: " << length
@@ -385,13 +404,23 @@
                         << "; set_df: " << rtc::ToHex(set_df);
 
     VerboseLogPacket(data, length, SCTP_DUMP_OUTBOUND);
+
     // Note: We have to copy the data; the caller will delete it.
     rtc::CopyOnWriteBuffer buf(reinterpret_cast<uint8_t*>(data), length);
 
-    transport->network_thread_->PostTask(ToQueuedTask(
-        transport->task_safety_, [transport, buf = std::move(buf)]() {
+    // PostsToTransportThread protects against the transport being
+    // simultaneously deregistered/deleted, since this callback may come from
+    // the SCTP timer thread and thus race with the network thread.
+    bool found = g_transport_map_->PostToTransportThread(
+        reinterpret_cast<uintptr_t>(addr), [buf](SctpTransport* transport) {
           transport->OnPacketFromSctpToNetwork(buf);
-        }));
+        });
+    if (!found) {
+      RTC_LOG(LS_ERROR)
+          << "OnSctpOutboundPacket: Failed to get transport for socket ID "
+          << addr;
+      return EINVAL;
+    }
 
     return 0;
   }
diff --git a/media/sctp/sctp_transport.h b/media/sctp/sctp_transport.h
index 38a89fc..bd166ef 100644
--- a/media/sctp/sctp_transport.h
+++ b/media/sctp/sctp_transport.h
@@ -281,6 +281,8 @@
   // various callbacks.
   uintptr_t id_ = 0;
 
+  friend class SctpTransportMap;
+
   RTC_DISALLOW_COPY_AND_ASSIGN(SctpTransport);
 };
 
@@ -299,6 +301,8 @@
   rtc::Thread* network_thread_;
 };
 
+class SctpTransportMap;
+
 }  // namespace cricket
 
 #endif  // MEDIA_SCTP_SCTP_TRANSPORT_H_