Unbind VirtualSocket from rtc::MessageHandler

Instead protect pending tasks with a shared object.
Some tests destroy VirtualSocket on a different thread than it is used on,
Some tests destroy VirtualSocket together with VirtualSocketServer after
associated thread is deleted, thus complicated check is used to ensure
VirtualSockets are safe to use.

Bug: webrtc:9702
Change-Id: I1a19cd24ac6a598a1cde64434104cad0b750096e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/274460
Reviewed-by: Tomas Gunnarsson <tommi@webrtc.org>
Commit-Queue: Danil Chapovalov <danilchap@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#38103}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index 0e397df..7dc32d3 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -1399,6 +1399,10 @@
     ":stringutils",
     ":threading",
     ":timeutils",
+    "../api:make_ref_counted",
+    "../api:refcountedbase",
+    "../api:scoped_refptr",
+    "../api/task_queue",
     "../api/units:time_delta",
     "../api/units:timestamp",
     "../test:scoped_key_value_config",
@@ -1411,6 +1415,7 @@
     "//third_party/abseil-cpp/absl/algorithm:container",
     "//third_party/abseil-cpp/absl/memory",
     "//third_party/abseil-cpp/absl/strings",
+    "//third_party/abseil-cpp/absl/types:optional",
   ]
 }
 
diff --git a/rtc_base/virtual_socket_server.cc b/rtc_base/virtual_socket_server.cc
index 6e5eeb2..efc206b 100644
--- a/rtc_base/virtual_socket_server.cc
+++ b/rtc_base/virtual_socket_server.cc
@@ -29,6 +29,11 @@
 #include "rtc_base/time_utils.h"
 
 namespace rtc {
+
+using ::webrtc::MutexLock;
+using ::webrtc::TaskQueueBase;
+using ::webrtc::TimeDelta;
+
 #if defined(WEBRTC_WIN)
 const in_addr kInitialNextIPv4 = {{{0x01, 0, 0, 0}}};
 #else
@@ -53,16 +58,9 @@
 // Note: The current algorithm doesn't work for sample sizes smaller than this.
 const int NUM_SAMPLES = 1000;
 
-enum {
-  MSG_ID_PACKET,
-  MSG_ID_CONNECT,
-  MSG_ID_DISCONNECT,
-  MSG_ID_SIGNALREADEVENT,
-};
-
 // Packets are passed between sockets as messages.  We copy the data just like
 // the kernel does.
-class Packet : public MessageData {
+class Packet {
  public:
   Packet(const char* data, size_t size, const SocketAddress& from)
       : size_(size), consumed_(0), from_(from) {
@@ -71,7 +69,7 @@
     memcpy(data_, data, size_);
   }
 
-  ~Packet() override { delete[] data_; }
+  ~Packet() { delete[] data_; }
 
   const char* data() const { return data_ + consumed_; }
   size_t size() const { return size_ - consumed_; }
@@ -89,17 +87,11 @@
   SocketAddress from_;
 };
 
-struct MessageAddress : public MessageData {
-  explicit MessageAddress(const SocketAddress& a) : addr(a) {}
-  SocketAddress addr;
-};
-
 VirtualSocket::VirtualSocket(VirtualSocketServer* server, int family, int type)
     : server_(server),
       type_(type),
       state_(CS_CLOSED),
       error_(0),
-      listen_queue_(nullptr),
       network_size_(0),
       recv_buffer_size_(0),
       bound_(false),
@@ -111,11 +103,6 @@
 
 VirtualSocket::~VirtualSocket() {
   Close();
-
-  for (RecvBuffer::iterator it = recv_buffer_.begin(); it != recv_buffer_.end();
-       ++it) {
-    delete *it;
-  }
 }
 
 SocketAddress VirtualSocket::GetLocalAddress() const {
@@ -151,6 +138,75 @@
   return InitiateConnect(addr, true);
 }
 
+VirtualSocket::SafetyBlock::SafetyBlock(VirtualSocket* socket)
+    : socket_(*socket) {}
+
+VirtualSocket::SafetyBlock::~SafetyBlock() {
+  // Ensure `SetNotAlive` was called and there is nothing left to cleanup.
+  RTC_DCHECK(!alive_);
+  RTC_DCHECK(posted_connects_.empty());
+  RTC_DCHECK(recv_buffer_.empty());
+  RTC_DCHECK(!listen_queue_.has_value());
+}
+
+void VirtualSocket::SafetyBlock::SetNotAlive() {
+  VirtualSocketServer* const server = socket_.server_;
+  const SocketAddress& local_addr = socket_.local_addr_;
+
+  MutexLock lock(&mutex_);
+  // Cancel pending sockets
+  if (listen_queue_.has_value()) {
+    for (const SocketAddress& remote_addr : *listen_queue_) {
+      server->Disconnect(remote_addr);
+    }
+    listen_queue_ = absl::nullopt;
+  }
+
+  // Cancel potential connects
+  for (const SocketAddress& remote_addr : posted_connects_) {
+    // Lookup remote side.
+    VirtualSocket* lookup_socket =
+        server->LookupConnection(local_addr, remote_addr);
+    if (lookup_socket) {
+      // Server socket, remote side is a socket retreived by accept. Accepted
+      // sockets are not bound so we will not find it by looking in the
+      // bindings table.
+      server->Disconnect(lookup_socket);
+      server->RemoveConnection(local_addr, remote_addr);
+    } else {
+      server->Disconnect(remote_addr);
+    }
+  }
+  posted_connects_.clear();
+
+  recv_buffer_.clear();
+
+  alive_ = false;
+}
+
+void VirtualSocket::SafetyBlock::PostSignalReadEvent() {
+  if (pending_read_signal_event_) {
+    // Avoid posting multiple times.
+    return;
+  }
+
+  pending_read_signal_event_ = true;
+  rtc::scoped_refptr<SafetyBlock> safety(this);
+  socket_.server_->msg_queue_->PostTask(
+      [safety = std::move(safety)] { safety->MaybeSignalReadEvent(); });
+}
+
+void VirtualSocket::SafetyBlock::MaybeSignalReadEvent() {
+  {
+    MutexLock lock(&mutex_);
+    pending_read_signal_event_ = false;
+    if (!alive_ || recv_buffer_.empty()) {
+      return;
+    }
+  }
+  socket_.SignalReadEvent(&socket_);
+}
+
 int VirtualSocket::Close() {
   if (!local_addr_.IsNil() && bound_) {
     // Remove from the binding table.
@@ -158,30 +214,12 @@
     bound_ = false;
   }
 
-  if (SOCK_STREAM == type_) {
-    webrtc::MutexLock lock(&mutex_);
-
-    // Cancel pending sockets
-    if (listen_queue_) {
-      while (!listen_queue_->empty()) {
-        SocketAddress addr = listen_queue_->front();
-
-        // Disconnect listening socket.
-        server_->Disconnect(addr);
-        listen_queue_->pop_front();
-      }
-      listen_queue_ = nullptr;
-    }
-    // Disconnect stream sockets
-    if (CS_CONNECTED == state_) {
-      server_->Disconnect(local_addr_, remote_addr_);
-    }
-    // Cancel potential connects
-    server_->CancelConnects(this);
+  // Disconnect stream sockets
+  if (state_ == CS_CONNECTED && type_ == SOCK_STREAM) {
+    server_->Disconnect(local_addr_, remote_addr_);
   }
 
-  // Clear incoming packets and disconnect messages
-  server_->Clear(this);
+  safety_->SetNotAlive();
 
   state_ = CS_CLOSED;
   local_addr_.Clear();
@@ -228,33 +266,13 @@
     *timestamp = -1;
   }
 
-  webrtc::MutexLock lock(&mutex_);
-  // If we don't have a packet, then either error or wait for one to arrive.
-  if (recv_buffer_.empty()) {
+  int data_read = safety_->RecvFrom(pv, cb, *paddr);
+  if (data_read < 0) {
     error_ = EAGAIN;
     return -1;
   }
 
-  // Return the packet at the front of the queue.
-  Packet* packet = recv_buffer_.front();
-  size_t data_read = std::min(cb, packet->size());
-  memcpy(pv, packet->data(), data_read);
-  *paddr = packet->from();
-
-  if (data_read < packet->size()) {
-    packet->Consume(data_read);
-  } else {
-    recv_buffer_.pop_front();
-    delete packet;
-  }
-
-  // To behave like a real socket, SignalReadEvent should fire in the next
-  // message loop pass if there's still data buffered.
-  if (!recv_buffer_.empty()) {
-    server_->PostSignalReadEvent(this);
-  }
-
-  if (SOCK_STREAM == type_) {
+  if (type_ == SOCK_STREAM) {
     bool was_full = (recv_buffer_size_ == server_->recv_buffer_capacity());
     recv_buffer_size_ -= data_read;
     if (was_full) {
@@ -262,51 +280,97 @@
     }
   }
 
-  return static_cast<int>(data_read);
+  return data_read;
+}
+
+int VirtualSocket::SafetyBlock::RecvFrom(void* buffer,
+                                         size_t size,
+                                         SocketAddress& addr) {
+  MutexLock lock(&mutex_);
+  // If we don't have a packet, then either error or wait for one to arrive.
+  if (recv_buffer_.empty()) {
+    return -1;
+  }
+
+  // Return the packet at the front of the queue.
+  Packet& packet = *recv_buffer_.front();
+  size_t data_read = std::min(size, packet.size());
+  memcpy(buffer, packet.data(), data_read);
+  addr = packet.from();
+
+  if (data_read < packet.size()) {
+    packet.Consume(data_read);
+  } else {
+    recv_buffer_.pop_front();
+  }
+
+  // To behave like a real socket, SignalReadEvent should fire if there's still
+  // data buffered.
+  if (!recv_buffer_.empty()) {
+    PostSignalReadEvent();
+  }
+
+  return data_read;
 }
 
 int VirtualSocket::Listen(int backlog) {
-  webrtc::MutexLock lock(&mutex_);
   RTC_DCHECK(SOCK_STREAM == type_);
   RTC_DCHECK(CS_CLOSED == state_);
   if (local_addr_.IsNil()) {
     error_ = EINVAL;
     return -1;
   }
-  RTC_DCHECK(nullptr == listen_queue_);
-  listen_queue_ = std::make_unique<ListenQueue>();
+  safety_->Listen();
   state_ = CS_CONNECTING;
   return 0;
 }
 
+void VirtualSocket::SafetyBlock::Listen() {
+  MutexLock lock(&mutex_);
+  RTC_DCHECK(!listen_queue_.has_value());
+  listen_queue_.emplace();
+}
+
 VirtualSocket* VirtualSocket::Accept(SocketAddress* paddr) {
-  webrtc::MutexLock lock(&mutex_);
-  if (nullptr == listen_queue_) {
-    error_ = EINVAL;
+  SafetyBlock::AcceptResult result = safety_->Accept();
+  if (result.error != 0) {
+    error_ = result.error;
     return nullptr;
   }
+  if (paddr) {
+    *paddr = result.remote_addr;
+  }
+  return result.socket.release();
+}
+
+VirtualSocket::SafetyBlock::AcceptResult VirtualSocket::SafetyBlock::Accept() {
+  AcceptResult result;
+  MutexLock lock(&mutex_);
+  RTC_DCHECK(alive_);
+  if (!listen_queue_.has_value()) {
+    result.error = EINVAL;
+    return result;
+  }
   while (!listen_queue_->empty()) {
-    VirtualSocket* socket = new VirtualSocket(server_, AF_INET, type_);
+    auto socket = std::make_unique<VirtualSocket>(socket_.server_, AF_INET,
+                                                  socket_.type_);
 
     // Set the new local address to the same as this server socket.
-    socket->SetLocalAddress(local_addr_);
+    socket->SetLocalAddress(socket_.local_addr_);
     // Sockets made from a socket that 'was Any' need to inherit that.
-    socket->set_was_any(was_any_);
-    SocketAddress remote_addr(listen_queue_->front());
-    int result = socket->InitiateConnect(remote_addr, false);
+    socket->set_was_any(socket_.was_any());
+    SocketAddress remote_addr = listen_queue_->front();
     listen_queue_->pop_front();
-    if (result != 0) {
-      delete socket;
+    if (socket->InitiateConnect(remote_addr, false) != 0) {
       continue;
     }
     socket->CompleteConnect(remote_addr);
-    if (paddr) {
-      *paddr = remote_addr;
-    }
-    return socket;
+    result.socket = std::move(socket);
+    result.remote_addr = remote_addr;
+    return result;
   }
-  error_ = EWOULDBLOCK;
-  return nullptr;
+  result.error = EWOULDBLOCK;
+  return result;
 }
 
 int VirtualSocket::GetError() const {
@@ -335,59 +399,109 @@
   return 0;  // 0 is success to emulate setsockopt()
 }
 
-void VirtualSocket::OnMessage(Message* pmsg) {
-  bool signal_read_event = false;
-  bool signal_close_event = false;
-  bool signal_connect_event = false;
-  int error_to_signal = 0;
-  {
-    webrtc::MutexLock lock(&mutex_);
-    if (pmsg->message_id == MSG_ID_PACKET) {
-      RTC_DCHECK(nullptr != pmsg->pdata);
-      Packet* packet = static_cast<Packet*>(pmsg->pdata);
+void VirtualSocket::PostPacket(TimeDelta delay,
+                               std::unique_ptr<Packet> packet) {
+  rtc::scoped_refptr<SafetyBlock> safety = safety_;
+  VirtualSocket* socket = this;
+  server_->msg_queue_->PostDelayedTask(
+      [safety = std::move(safety), socket,
+       packet = std::move(packet)]() mutable {
+        if (safety->AddPacket(std::move(packet))) {
+          socket->SignalReadEvent(socket);
+        }
+      },
+      delay);
+}
 
-      recv_buffer_.push_back(packet);
-      signal_read_event = true;
-    } else if (pmsg->message_id == MSG_ID_CONNECT) {
-      RTC_DCHECK(nullptr != pmsg->pdata);
-      MessageAddress* data = static_cast<MessageAddress*>(pmsg->pdata);
-      if (listen_queue_ != nullptr) {
-        listen_queue_->push_back(data->addr);
-        signal_read_event = true;
-      } else if ((SOCK_STREAM == type_) && (CS_CONNECTING == state_)) {
-        CompleteConnect(data->addr);
-        signal_connect_event = true;
-      } else {
-        RTC_LOG(LS_VERBOSE)
-            << "Socket at " << local_addr_.ToString() << " is not listening";
-        server_->Disconnect(data->addr);
-      }
-      delete data;
-    } else if (pmsg->message_id == MSG_ID_DISCONNECT) {
-      RTC_DCHECK(SOCK_STREAM == type_);
-      if (CS_CLOSED != state_) {
-        error_to_signal = (CS_CONNECTING == state_) ? ECONNREFUSED : 0;
-        state_ = CS_CLOSED;
-        remote_addr_.Clear();
-        signal_close_event = true;
-      }
-    } else if (pmsg->message_id == MSG_ID_SIGNALREADEVENT) {
-      signal_read_event = !recv_buffer_.empty();
-    } else {
-      RTC_DCHECK_NOTREACHED();
+bool VirtualSocket::SafetyBlock::AddPacket(std::unique_ptr<Packet> packet) {
+  MutexLock lock(&mutex_);
+  if (alive_) {
+    recv_buffer_.push_back(std::move(packet));
+  }
+  return alive_;
+}
+
+void VirtualSocket::PostConnect(TimeDelta delay,
+                                const SocketAddress& remote_addr) {
+  safety_->PostConnect(delay, remote_addr);
+}
+
+void VirtualSocket::SafetyBlock::PostConnect(TimeDelta delay,
+                                             const SocketAddress& remote_addr) {
+  rtc::scoped_refptr<SafetyBlock> safety(this);
+
+  MutexLock lock(&mutex_);
+  RTC_DCHECK(alive_);
+  // Save addresses of the pending connects to allow propertly disconnect them
+  // if socket closes before delayed task below runs.
+  // `posted_connects_` is an std::list, thus its iterators are valid while the
+  // element is in the list. It can be removed either in the `Connect` just
+  // below or by calling SetNotAlive function, thus inside `Connect` `it` should
+  // be valid when alive_ == true.
+  auto it = posted_connects_.insert(posted_connects_.end(), remote_addr);
+  auto task = [safety = std::move(safety), it] {
+    switch (safety->Connect(it)) {
+      case Signal::kNone:
+        break;
+      case Signal::kReadEvent:
+        safety->socket_.SignalReadEvent(&safety->socket_);
+        break;
+      case Signal::kConnectEvent:
+        safety->socket_.SignalConnectEvent(&safety->socket_);
+        break;
     }
+  };
+  socket_.server_->msg_queue_->PostDelayedTask(std::move(task), delay);
+}
+
+VirtualSocket::SafetyBlock::Signal VirtualSocket::SafetyBlock::Connect(
+    VirtualSocket::SafetyBlock::PostedConnects::iterator remote_addr_it) {
+  MutexLock lock(&mutex_);
+  if (!alive_) {
+    return Signal::kNone;
   }
-  // Signal events without holding `mutex_`, to avoid recursive locking, as well
-  // as issues with sigslot and lock order.
-  if (signal_read_event) {
-    SignalReadEvent(this);
+  RTC_DCHECK(!posted_connects_.empty());
+  SocketAddress remote_addr = *remote_addr_it;
+  posted_connects_.erase(remote_addr_it);
+
+  if (listen_queue_.has_value()) {
+    listen_queue_->push_back(remote_addr);
+    return Signal::kReadEvent;
   }
-  if (signal_close_event) {
-    SignalCloseEvent(this, error_to_signal);
+  if (socket_.type_ == SOCK_STREAM && socket_.state_ == CS_CONNECTING) {
+    socket_.CompleteConnect(remote_addr);
+    return Signal::kConnectEvent;
   }
-  if (signal_connect_event) {
-    SignalConnectEvent(this);
-  }
+  RTC_LOG(LS_VERBOSE) << "Socket at " << socket_.local_addr_.ToString()
+                      << " is not listening";
+  socket_.server_->Disconnect(remote_addr);
+  return Signal::kNone;
+}
+
+bool VirtualSocket::SafetyBlock::IsAlive() {
+  MutexLock lock(&mutex_);
+  return alive_;
+}
+
+void VirtualSocket::PostDisconnect(TimeDelta delay) {
+  // Posted task may outlive this. Use different name for `this` inside the task
+  // to avoid accidental unsafe `this->safety_` instead of safe `safety`
+  VirtualSocket* socket = this;
+  rtc::scoped_refptr<SafetyBlock> safety = safety_;
+  auto task = [safety = std::move(safety), socket] {
+    if (!safety->IsAlive()) {
+      return;
+    }
+    RTC_DCHECK_EQ(socket->type_, SOCK_STREAM);
+    if (socket->state_ == CS_CLOSED) {
+      return;
+    }
+    int error_to_signal = (socket->state_ == CS_CONNECTING) ? ECONNREFUSED : 0;
+    socket->state_ = CS_CLOSED;
+    socket->remote_addr_.Clear();
+    socket->SignalCloseEvent(socket, error_to_signal);
+  };
+  server_->msg_queue_->PostDelayedTask(std::move(task), delay);
 }
 
 int VirtualSocket::InitiateConnect(const SocketAddress& addr, bool use_delay) {
@@ -478,7 +592,6 @@
 }
 
 void VirtualSocket::SetToBlocked() {
-  webrtc::MutexLock lock(&mutex_);
   ready_to_send_ = false;
   error_ = EWOULDBLOCK;
 }
@@ -528,8 +641,6 @@
 }
 
 size_t VirtualSocket::PurgeNetworkPackets(int64_t cur_time) {
-  webrtc::MutexLock lock(&mutex_);
-
   while (!network_.empty() && (network_.front().done_time <= cur_time)) {
     RTC_DCHECK(network_size_ >= network_.front().size);
     network_size_ -= network_.front().size;
@@ -787,7 +898,7 @@
                                  bool use_delay) {
   RTC_DCHECK(msg_queue_);
 
-  uint32_t delay = use_delay ? GetTransitDelay(socket) : 0;
+  TimeDelta delay = TimeDelta::Millis(use_delay ? GetTransitDelay(socket) : 0);
   VirtualSocket* remote = LookupBinding(remote_addr);
   if (!CanInteractWith(socket, remote)) {
     RTC_LOG(LS_INFO) << "Address family mismatch between "
@@ -796,12 +907,10 @@
     return -1;
   }
   if (remote != nullptr) {
-    SocketAddress addr = socket->GetLocalAddress();
-    msg_queue_->PostDelayed(RTC_FROM_HERE, delay, remote, MSG_ID_CONNECT,
-                            new MessageAddress(addr));
+    remote->PostConnect(delay, socket->GetLocalAddress());
   } else {
     RTC_LOG(LS_INFO) << "No one listening at " << remote_addr.ToString();
-    msg_queue_->PostDelayed(RTC_FROM_HERE, delay, socket, MSG_ID_DISCONNECT);
+    socket->PostDisconnect(delay);
   }
   return 0;
 }
@@ -812,9 +921,7 @@
 
   // If we simulate packets being delayed, we should simulate the
   // equivalent of a FIN being delayed as well.
-  uint32_t delay = GetTransitDelay(socket);
-  // Remove the mapping.
-  msg_queue_->PostDelayed(RTC_FROM_HERE, delay, socket, MSG_ID_DISCONNECT);
+  socket->PostDisconnect(TimeDelta::Millis(GetTransitDelay(socket)));
   return true;
 }
 
@@ -841,46 +948,6 @@
   return socket != nullptr;
 }
 
-void VirtualSocketServer::CancelConnects(VirtualSocket* socket) {
-  MessageList msgs;
-  if (msg_queue_) {
-    msg_queue_->Clear(socket, MSG_ID_CONNECT, &msgs);
-  }
-  for (MessageList::iterator it = msgs.begin(); it != msgs.end(); ++it) {
-    RTC_DCHECK(nullptr != it->pdata);
-    MessageAddress* data = static_cast<MessageAddress*>(it->pdata);
-    SocketAddress local_addr = socket->GetLocalAddress();
-    // Lookup remote side.
-    VirtualSocket* lookup_socket = LookupConnection(local_addr, data->addr);
-    if (lookup_socket) {
-      // Server socket, remote side is a socket retreived by
-      // accept. Accepted sockets are not bound so we will not
-      // find it by looking in the bindings table.
-      Disconnect(lookup_socket);
-      RemoveConnection(local_addr, data->addr);
-    } else {
-      Disconnect(data->addr);
-    }
-    delete data;
-  }
-}
-
-void VirtualSocketServer::Clear(VirtualSocket* socket) {
-  // Clear incoming packets and disconnect messages
-  if (msg_queue_) {
-    msg_queue_->Clear(socket);
-  }
-}
-
-void VirtualSocketServer::PostSignalReadEvent(VirtualSocket* socket) {
-  if (!msg_queue_)
-    return;
-
-  // Clear the message so it doesn't end up posted multiple times.
-  msg_queue_->Clear(socket, MSG_ID_SIGNALREADEVENT);
-  msg_queue_->Post(RTC_FROM_HERE, socket, MSG_ID_SIGNALREADEVENT);
-}
-
 int VirtualSocketServer::SendUdp(VirtualSocket* socket,
                                  const char* data,
                                  size_t data_size,
@@ -1031,14 +1098,12 @@
     sender_addr.SetIP(default_ip);
   }
 
-  // Post the packet as a message to be delivered (on our own thread)
-  Packet* p = new Packet(data, data_size, sender_addr);
-
-  int64_t ts = TimeAfter(send_delay + transit_delay);
+  int64_t ts = cur_time + send_delay + transit_delay;
   if (ordered) {
     ts = sender->UpdateOrderedDelivery(ts);
   }
-  msg_queue_->PostAt(RTC_FROM_HERE, ts, recipient, MSG_ID_PACKET, p);
+  recipient->PostPacket(TimeDelta::Millis(ts - cur_time),
+                        std::make_unique<Packet>(data, data_size, sender_addr));
 }
 
 uint32_t VirtualSocketServer::SendDelay(uint32_t size) {
diff --git a/rtc_base/virtual_socket_server.h b/rtc_base/virtual_socket_server.h
index eb9cfc1..93ef288 100644
--- a/rtc_base/virtual_socket_server.h
+++ b/rtc_base/virtual_socket_server.h
@@ -15,10 +15,14 @@
 #include <map>
 #include <vector>
 
+#include "absl/types/optional.h"
+#include "api/make_ref_counted.h"
+#include "api/ref_counted_base.h"
+#include "api/scoped_refptr.h"
+#include "api/task_queue/task_queue_base.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/event.h"
 #include "rtc_base/fake_clock.h"
-#include "rtc_base/message_handler.h"
 #include "rtc_base/socket_server.h"
 #include "rtc_base/synchronization/mutex.h"
 
@@ -28,11 +32,9 @@
 class VirtualSocketServer;
 class SocketAddressPair;
 
-// Implements the socket interface using the virtual network.  Packets are
-// passed as messages using the message queue of the socket server.
-class VirtualSocket : public Socket,
-                      public MessageHandler,
-                      public sigslot::has_slots<> {
+// Implements the socket interface using the virtual network. Packets are
+// passed in tasks using the thread of the socket server.
+class VirtualSocket : public Socket, public sigslot::has_slots<> {
  public:
   VirtualSocket(VirtualSocketServer* server, int family, int type);
   ~VirtualSocket() override;
@@ -58,7 +60,6 @@
   ConnState GetState() const override;
   int GetOption(Option opt, int* value) override;
   int SetOption(Option opt, int value) override;
-  void OnMessage(Message* pmsg) override;
 
   size_t recv_buffer_size() const { return recv_buffer_size_; }
   size_t send_buffer_size() const { return send_buffer_.size(); }
@@ -85,16 +86,82 @@
   // Removes stale packets from the network. Returns current size.
   size_t PurgeNetworkPackets(int64_t cur_time);
 
+  void PostPacket(webrtc::TimeDelta delay, std::unique_ptr<Packet> packet);
+  void PostConnect(webrtc::TimeDelta delay, const SocketAddress& remote_addr);
+  void PostDisconnect(webrtc::TimeDelta delay);
+
  private:
+  // Struct shared with pending tasks that may outlive VirtualSocket.
+  class SafetyBlock : public RefCountedNonVirtual<SafetyBlock> {
+   public:
+    explicit SafetyBlock(VirtualSocket* socket);
+    SafetyBlock(const SafetyBlock&) = delete;
+    SafetyBlock& operator=(const SafetyBlock&) = delete;
+    ~SafetyBlock();
+
+    // Prohibits posted delayed task to access owning VirtualSocket and
+    // cleanups members protected by the `mutex`.
+    void SetNotAlive();
+    bool IsAlive();
+
+    // Copies up to `size` bytes into buffer from the next received packet
+    // and fills `addr` with remote address of that received packet.
+    // Returns number of bytes copied or negative value on failure.
+    int RecvFrom(void* buffer, size_t size, SocketAddress& addr);
+
+    void Listen();
+
+    struct AcceptResult {
+      int error = 0;
+      std::unique_ptr<VirtualSocket> socket;
+      SocketAddress remote_addr;
+    };
+    AcceptResult Accept();
+
+    bool AddPacket(std::unique_ptr<Packet> packet);
+    void PostConnect(webrtc::TimeDelta delay, const SocketAddress& remote_addr);
+
+   private:
+    enum class Signal { kNone, kReadEvent, kConnectEvent };
+    // `PostConnect` rely on the fact that std::list iterators are not
+    // invalidated on any changes to other elements in the container.
+    using PostedConnects = std::list<SocketAddress>;
+
+    void PostSignalReadEvent() RTC_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+    void MaybeSignalReadEvent();
+    Signal Connect(PostedConnects::iterator remote_addr_it);
+
+    webrtc::Mutex mutex_;
+    VirtualSocket& socket_;
+    bool alive_ RTC_GUARDED_BY(mutex_) = true;
+    // Flag indicating if async Task to signal SignalReadEvent is posted.
+    // To avoid posting multiple such tasks.
+    bool pending_read_signal_event_ RTC_GUARDED_BY(mutex_) = false;
+
+    // Members below do not need to outlive VirtualSocket, but are used by the
+    // posted tasks. Keeping them in the VirtualSocket confuses thread
+    // annotations because they can't detect that locked mutex is the same mutex
+    // this members are guarded by.
+
+    // Addresses of the sockets for potential connect. For each address there
+    // is a posted task that should finilze the connect.
+    PostedConnects posted_connects_ RTC_GUARDED_BY(mutex_);
+
+    // Data which has been received from the network
+    std::list<std::unique_ptr<Packet>> recv_buffer_ RTC_GUARDED_BY(mutex_);
+
+    // Pending sockets which can be Accepted
+    absl::optional<std::deque<SocketAddress>> listen_queue_
+        RTC_GUARDED_BY(mutex_);
+  };
+
   struct NetworkEntry {
     size_t size;
     int64_t done_time;
   };
 
-  typedef std::deque<SocketAddress> ListenQueue;
   typedef std::deque<NetworkEntry> NetworkQueue;
   typedef std::vector<char> SendBuffer;
-  typedef std::list<Packet*> RecvBuffer;
   typedef std::map<Option, int> OptionsMap;
 
   int InitiateConnect(const SocketAddress& addr, bool use_delay);
@@ -111,9 +178,8 @@
   SocketAddress local_addr_;
   SocketAddress remote_addr_;
 
-  // Pending sockets which can be Accepted
-  std::unique_ptr<ListenQueue> listen_queue_ RTC_GUARDED_BY(mutex_)
-      RTC_PT_GUARDED_BY(mutex_);
+  const scoped_refptr<SafetyBlock> safety_ =
+      make_ref_counted<SafetyBlock>(this);
 
   // Data which tcp has buffered for sending
   SendBuffer send_buffer_;
@@ -121,9 +187,6 @@
   // Set back to true when the socket can send again.
   bool ready_to_send_ = true;
 
-  // Mutex to protect the recv_buffer and listen_queue_
-  webrtc::Mutex mutex_;
-
   // Network model that enforces bandwidth and capacity constraints
   NetworkQueue network_;
   size_t network_size_;
@@ -131,8 +194,6 @@
   // It is used to ensure ordered delivery of packets sent on this socket.
   int64_t last_delivery_time_ = 0;
 
-  // Data which has been received from the network
-  RecvBuffer recv_buffer_ RTC_GUARDED_BY(mutex_);
   // The amount of data which is in flight or in recv_buffer_
   size_t recv_buffer_size_;
 
@@ -308,14 +369,6 @@
   // Computes the number of milliseconds required to send a packet of this size.
   uint32_t SendDelay(uint32_t size) RTC_LOCKS_EXCLUDED(mutex_);
 
-  // Cancel attempts to connect to a socket that is being closed.
-  void CancelConnects(VirtualSocket* socket);
-
-  // Clear incoming messages for a socket that is being closed.
-  void Clear(VirtualSocket* socket);
-
-  void PostSignalReadEvent(VirtualSocket* socket);
-
   // Sending was previously blocked, but now isn't.
   sigslot::signal0<> SignalReadyToSend;
 
@@ -327,6 +380,7 @@
   VirtualSocket* LookupBinding(const SocketAddress& addr);
 
  private:
+  friend VirtualSocket;
   uint16_t GetNextPort();
 
   // Find the socket pair corresponding to this server address.