Revert "Fix race between destroying SctpTransport and receiving notification on timer thread."

This reverts commit a88fe7be146b9b85575504d4d5193c007f2e3de4.

Reason for revert: Breaks downstream test, still investigating.

Original change's description:
> Fix race between destroying SctpTransport and receiving notification on timer thread.
>
> This gets rid of the SctpTransportMap::Retrieve method and forces
> everything to go through PostToTransportThread, which behaves safely
> with relation to the transport's destruction.
>
> Bug: webrtc:12467
> Change-Id: Id4a723c2c985be2a368d2cc5c5e62deb04c509ab
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/208800
> Reviewed-by: Niels Moller <nisse@webrtc.org>
> Commit-Queue: Taylor <deadbeef@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#33364}

TBR=nisse@webrtc.org

Bug: webrtc:12467
Change-Id: Ib5d815a2cbca4feb25f360bff7ed62c02d1910a0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/209820
Reviewed-by: Taylor <deadbeef@webrtc.org>
Commit-Queue: Taylor <deadbeef@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33386}
diff --git a/media/sctp/sctp_transport.cc b/media/sctp/sctp_transport.cc
index 3a1574c..5878f45 100644
--- a/media/sctp/sctp_transport.cc
+++ b/media/sctp/sctp_transport.cc
@@ -20,7 +20,6 @@
 // Successful return value from usrsctp callbacks. Is not actually used by
 // usrsctp, but all example programs for usrsctp use 1 as their return value.
 constexpr int kSctpSuccessReturn = 1;
-constexpr int kSctpErrorReturn = 0;
 
 }  // namespace
 
@@ -28,6 +27,7 @@
 #include <stdio.h>
 #include <usrsctp.h>
 
+#include <functional>
 #include <memory>
 #include <unordered_map>
 
@@ -252,20 +252,31 @@
     return map_.erase(id) > 0;
   }
 
+  // Must be called on the transport's network thread to protect against
+  // simultaneous deletion/deregistration of the transport; if that's not
+  // guaranteed, use ExecuteWithLock.
+  SctpTransport* Retrieve(uintptr_t id) const {
+    webrtc::MutexLock lock(&lock_);
+    SctpTransport* transport = RetrieveWhileHoldingLock(id);
+    if (transport) {
+      RTC_DCHECK_RUN_ON(transport->network_thread());
+    }
+    return transport;
+  }
+
   // Posts |action| to the network thread of the transport identified by |id|
   // and returns true if found, all while holding a lock to protect against the
   // transport being simultaneously deleted/deregistered, or returns false if
   // not found.
-  template <typename F>
-  bool PostToTransportThread(uintptr_t id, F action) const {
+  bool PostToTransportThread(uintptr_t id,
+                             std::function<void(SctpTransport*)> action) const {
     webrtc::MutexLock lock(&lock_);
     SctpTransport* transport = RetrieveWhileHoldingLock(id);
     if (!transport) {
       return false;
     }
     transport->network_thread_->PostTask(ToQueuedTask(
-        transport->task_safety_,
-        [transport, action{std::move(action)}]() { action(transport); }));
+        transport->task_safety_, [transport, action]() { action(transport); }));
     return true;
   }
 
@@ -418,7 +429,7 @@
     if (!found) {
       RTC_LOG(LS_ERROR)
           << "OnSctpOutboundPacket: Failed to get transport for socket ID "
-          << addr << "; possibly was already destroyed.";
+          << addr;
       return EINVAL;
     }
 
@@ -436,49 +447,28 @@
                                  struct sctp_rcvinfo rcv,
                                  int flags,
                                  void* ulp_info) {
-    struct DeleteByFree {
-      void operator()(void* p) const { free(p); }
-    };
-    std::unique_ptr<void, DeleteByFree> owned_data(data, DeleteByFree());
-
-    absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
-    if (!id) {
+    SctpTransport* transport = GetTransportFromSocket(sock);
+    if (!transport) {
       RTC_LOG(LS_ERROR)
-          << "OnSctpInboundPacket: Failed to get transport ID from socket "
-          << sock;
-      return kSctpErrorReturn;
+          << "OnSctpInboundPacket: Failed to get transport for socket " << sock
+          << "; possibly was already destroyed.";
+      free(data);
+      return 0;
     }
-
-    if (!g_transport_map_) {
-      RTC_LOG(LS_ERROR)
-          << "OnSctpInboundPacket called after usrsctp uninitialized?";
-      return kSctpErrorReturn;
-    }
-    // PostsToTransportThread protects against the transport being
-    // simultaneously deregistered/deleted, since this callback may come from
-    // the SCTP timer thread and thus race with the network thread.
-    bool found = g_transport_map_->PostToTransportThread(
-        *id, [owned_data{std::move(owned_data)}, length, rcv,
-              flags](SctpTransport* transport) {
-          transport->OnDataOrNotificationFromSctp(owned_data.get(), length, rcv,
-                                                  flags);
-        });
-    if (!found) {
-      RTC_LOG(LS_ERROR)
-          << "OnSctpInboundPacket: Failed to get transport for socket ID "
-          << *id << "; possibly was already destroyed.";
-      return kSctpErrorReturn;
-    }
-    return kSctpSuccessReturn;
+    // Sanity check that both methods of getting the SctpTransport pointer
+    // yield the same result.
+    RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info));
+    int result =
+        transport->OnDataOrNotificationFromSctp(data, length, rcv, flags);
+    free(data);
+    return result;
   }
 
-  static absl::optional<uintptr_t> GetTransportIdFromSocket(
-      struct socket* sock) {
-    absl::optional<uintptr_t> ret;
+  static SctpTransport* GetTransportFromSocket(struct socket* sock) {
     struct sockaddr* addrs = nullptr;
     int naddrs = usrsctp_getladdrs(sock, 0, &addrs);
     if (naddrs <= 0 || addrs[0].sa_family != AF_CONN) {
-      return ret;
+      return nullptr;
     }
     // usrsctp_getladdrs() returns the addresses bound to this socket, which
     // contains the SctpTransport id as sconn_addr.  Read the id,
@@ -487,10 +477,17 @@
     // id of the transport that created them, so [0] is as good as any other.
     struct sockaddr_conn* sconn =
         reinterpret_cast<struct sockaddr_conn*>(&addrs[0]);
-    ret = reinterpret_cast<uintptr_t>(sconn->sconn_addr);
+    if (!g_transport_map_) {
+      RTC_LOG(LS_ERROR)
+          << "GetTransportFromSocket called after usrsctp uninitialized?";
+      usrsctp_freeladdrs(addrs);
+      return nullptr;
+    }
+    SctpTransport* transport = g_transport_map_->Retrieve(
+        reinterpret_cast<uintptr_t>(sconn->sconn_addr));
     usrsctp_freeladdrs(addrs);
 
-    return ret;
+    return transport;
   }
 
   // TODO(crbug.com/webrtc/11899): This is a legacy callback signature, remove
@@ -499,26 +496,14 @@
     // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
     // a packet containing acknowledgments, which goes into usrsctp_conninput,
     // and then back here.
-    absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
-    if (!id) {
+    SctpTransport* transport = GetTransportFromSocket(sock);
+    if (!transport) {
       RTC_LOG(LS_ERROR)
-          << "SendThresholdCallback: Failed to get transport ID from socket "
-          << sock;
+          << "SendThresholdCallback: Failed to get transport for socket "
+          << sock << "; possibly was already destroyed.";
       return 0;
     }
-    if (!g_transport_map_) {
-      RTC_LOG(LS_ERROR)
-          << "SendThresholdCallback called after usrsctp uninitialized?";
-      return 0;
-    }
-    bool found = g_transport_map_->PostToTransportThread(
-        *id,
-        [](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
-    if (!found) {
-      RTC_LOG(LS_ERROR)
-          << "SendThresholdCallback: Failed to get transport for socket ID "
-          << *id << "; possibly was already destroyed.";
-    }
+    transport->OnSendThresholdCallback();
     return 0;
   }
 
@@ -528,26 +513,17 @@
     // Fired on our I/O thread. SctpTransport::OnPacketReceived() gets
     // a packet containing acknowledgments, which goes into usrsctp_conninput,
     // and then back here.
-    absl::optional<uintptr_t> id = GetTransportIdFromSocket(sock);
-    if (!id) {
+    SctpTransport* transport = GetTransportFromSocket(sock);
+    if (!transport) {
       RTC_LOG(LS_ERROR)
-          << "SendThresholdCallback: Failed to get transport ID from socket "
-          << sock;
+          << "SendThresholdCallback: Failed to get transport for socket "
+          << sock << "; possibly was already destroyed.";
       return 0;
     }
-    if (!g_transport_map_) {
-      RTC_LOG(LS_ERROR)
-          << "SendThresholdCallback called after usrsctp uninitialized?";
-      return 0;
-    }
-    bool found = g_transport_map_->PostToTransportThread(
-        *id,
-        [](SctpTransport* transport) { transport->OnSendThresholdCallback(); });
-    if (!found) {
-      RTC_LOG(LS_ERROR)
-          << "SendThresholdCallback: Failed to get transport for socket ID "
-          << *id << "; possibly was already destroyed.";
-    }
+    // Sanity check that both methods of getting the SctpTransport pointer
+    // yield the same result.
+    RTC_CHECK_EQ(transport, static_cast<SctpTransport*>(ulp_info));
+    transport->OnSendThresholdCallback();
     return 0;
   }
 };
@@ -1199,25 +1175,24 @@
                          rtc::PacketOptions(), PF_NORMAL);
 }
 
-void SctpTransport::InjectDataOrNotificationFromSctpForTesting(
+int SctpTransport::InjectDataOrNotificationFromSctpForTesting(
     const void* data,
     size_t length,
     struct sctp_rcvinfo rcv,
     int flags) {
-  OnDataOrNotificationFromSctp(data, length, rcv, flags);
+  return OnDataOrNotificationFromSctp(data, length, rcv, flags);
 }
 
-void SctpTransport::OnDataOrNotificationFromSctp(const void* data,
-                                                 size_t length,
-                                                 struct sctp_rcvinfo rcv,
-                                                 int flags) {
-  RTC_DCHECK_RUN_ON(network_thread_);
+int SctpTransport::OnDataOrNotificationFromSctp(const void* data,
+                                                size_t length,
+                                                struct sctp_rcvinfo rcv,
+                                                int flags) {
   // If data is NULL, the SCTP association has been closed.
   if (!data) {
     RTC_LOG(LS_INFO) << debug_name_
                      << "->OnDataOrNotificationFromSctp(...): "
                         "No data; association closed.";
-    return;
+    return kSctpSuccessReturn;
   }
 
   // Handle notifications early.
@@ -1230,10 +1205,14 @@
         << "->OnDataOrNotificationFromSctp(...): SCTP notification"
         << " length=" << length;
 
+    // Copy and dispatch asynchronously
     rtc::CopyOnWriteBuffer notification(reinterpret_cast<const uint8_t*>(data),
                                         length);
-    OnNotificationFromSctp(notification);
-    return;
+    network_thread_->PostTask(ToQueuedTask(
+        task_safety_, [this, notification = std::move(notification)]() {
+          OnNotificationFromSctp(notification);
+        }));
+    return kSctpSuccessReturn;
   }
 
   // Log data chunk
@@ -1251,7 +1230,7 @@
     // Unexpected PPID, dropping
     RTC_LOG(LS_ERROR) << "Received an unknown PPID " << ppid
                       << " on an SCTP packet.  Dropping.";
-    return;
+    return kSctpSuccessReturn;
   }
 
   // Expect only continuation messages belonging to the same SID. The SCTP
@@ -1287,7 +1266,7 @@
     if (partial_incoming_message_.size() < kSctpSendBufferSize) {
       // We still have space in the buffer. Continue buffering chunks until
       // the message is complete before handing it out.
-      return;
+      return kSctpSuccessReturn;
     } else {
       // The sender is exceeding the maximum message size that we announced.
       // Spit out a warning but still hand out the partial message. Note that
@@ -1301,9 +1280,18 @@
     }
   }
 
-  // Dispatch the complete message and reset the message buffer.
-  OnDataFromSctpToTransport(params, partial_incoming_message_);
+  // Dispatch the complete message.
+  // The ownership of the packet transfers to |invoker_|. Using
+  // CopyOnWriteBuffer is the most convenient way to do this.
+  network_thread_->PostTask(webrtc::ToQueuedTask(
+      task_safety_, [this, params = std::move(params),
+                     message = partial_incoming_message_]() {
+        OnDataFromSctpToTransport(params, message);
+      }));
+
+  // Reset the message buffer
   partial_incoming_message_.Clear();
+  return kSctpSuccessReturn;
 }
 
 void SctpTransport::OnDataFromSctpToTransport(
diff --git a/media/sctp/sctp_transport.h b/media/sctp/sctp_transport.h
index e357e70..bd166ef 100644
--- a/media/sctp/sctp_transport.h
+++ b/media/sctp/sctp_transport.h
@@ -96,10 +96,10 @@
   void set_debug_name_for_testing(const char* debug_name) override {
     debug_name_ = debug_name;
   }
-  void InjectDataOrNotificationFromSctpForTesting(const void* data,
-                                                  size_t length,
-                                                  struct sctp_rcvinfo rcv,
-                                                  int flags);
+  int InjectDataOrNotificationFromSctpForTesting(const void* data,
+                                                 size_t length,
+                                                 struct sctp_rcvinfo rcv,
+                                                 int flags);
 
   // Exposed to allow Post call from c-callbacks.
   // TODO(deadbeef): Remove this or at least make it return a const pointer.
@@ -180,12 +180,12 @@
   // Called using |invoker_| to send packet on the network.
   void OnPacketFromSctpToNetwork(const rtc::CopyOnWriteBuffer& buffer);
 
-  // Called on the network thread.
+  // Called on the SCTP thread.
   // Flags are standard socket API flags (RFC 6458).
-  void OnDataOrNotificationFromSctp(const void* data,
-                                    size_t length,
-                                    struct sctp_rcvinfo rcv,
-                                    int flags);
+  int OnDataOrNotificationFromSctp(const void* data,
+                                   size_t length,
+                                   struct sctp_rcvinfo rcv,
+                                   int flags);
   // Called using |invoker_| to decide what to do with the data.
   void OnDataFromSctpToTransport(const ReceiveDataParams& params,
                                  const rtc::CopyOnWriteBuffer& buffer);
diff --git a/media/sctp/sctp_transport_unittest.cc b/media/sctp/sctp_transport_unittest.cc
index 120f4e5..98a9122 100644
--- a/media/sctp/sctp_transport_unittest.cc
+++ b/media/sctp/sctp_transport_unittest.cc
@@ -282,8 +282,8 @@
   meta.rcv_tsn = 42;
   meta.rcv_cumtsn = 42;
   chunk.SetData("meow?", 5);
-  transport1->InjectDataOrNotificationFromSctpForTesting(chunk.data(),
-                                                         chunk.size(), meta, 0);
+  EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
+                   chunk.data(), chunk.size(), meta, 0));
 
   // Inject a notification in between chunks.
   union sctp_notification notification;
@@ -292,15 +292,15 @@
   notification.sn_header.sn_type = SCTP_PEER_ADDR_CHANGE;
   notification.sn_header.sn_flags = 0;
   notification.sn_header.sn_length = sizeof(notification);
-  transport1->InjectDataOrNotificationFromSctpForTesting(
-      &notification, sizeof(notification), {0}, MSG_NOTIFICATION);
+  EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
+                   &notification, sizeof(notification), {0}, MSG_NOTIFICATION));
 
   // Inject chunk 2/2
   meta.rcv_tsn = 42;
   meta.rcv_cumtsn = 43;
   chunk.SetData(" rawr!", 6);
-  transport1->InjectDataOrNotificationFromSctpForTesting(
-      chunk.data(), chunk.size(), meta, MSG_EOR);
+  EXPECT_EQ(1, transport1->InjectDataOrNotificationFromSctpForTesting(
+                   chunk.data(), chunk.size(), meta, MSG_EOR));
 
   // Expect the message to contain both chunks.
   EXPECT_TRUE_WAIT(ReceivedData(&recv1, 1, "meow? rawr!"), kDefaultTimeout);