Ensure VirtualSocketServer can propagate ECN

VirtualSocketServer is used for testing much of the logic related to Turn/Stun ports.

Bug: webrtc:453581251
Change-Id: I4d099ed3f06925cb43f8b9c4ad6610c94ab28094
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/418741
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#45986}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index 8651e33..016c99e 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -1759,6 +1759,7 @@
     ":socket_address",
     ":threading",
     ":timeutils",
+    "../api/transport:ecn_marking",
     "../api/units:time_delta",
     "../api/units:timestamp",
     "../test:wait_until",
@@ -1848,6 +1849,7 @@
     "../api/environment",
     "../api/task_queue",
     "../api/task_queue:pending_task_safety_flag",
+    "../api/transport:ecn_marking",
     "../api/units:time_delta",
     "../api/units:timestamp",
     "memory:always_valid_pointer",
@@ -2176,6 +2178,7 @@
         "../api:scoped_refptr",
         "../api/environment",
         "../api/numerics",
+        "../api/transport:ecn_marking",
         "../api/units:data_rate",
         "../api/units:data_size",
         "../api/units:frequency",
diff --git a/rtc_base/test_client.cc b/rtc_base/test_client.cc
index b48cda4..3608136 100644
--- a/rtc_base/test_client.cc
+++ b/rtc_base/test_client.cc
@@ -153,11 +153,13 @@
     : addr(received_packet.source_address()),
       // Copy received_packet payload to a buffer owned by Packet.
       buf(received_packet.payload().data(), received_packet.payload().size()),
+      ecn(received_packet.ecn()),
       packet_time(received_packet.arrival_time()) {}
 
 TestClient::Packet::Packet(const Packet& p)
     : addr(p.addr),
       buf(p.buf.data(), p.buf.size()),
+      ecn(p.ecn),
       packet_time(p.packet_time) {}
 
 }  // namespace webrtc
diff --git a/rtc_base/test_client.h b/rtc_base/test_client.h
index 738cecc..68c7705 100644
--- a/rtc_base/test_client.h
+++ b/rtc_base/test_client.h
@@ -16,6 +16,7 @@
 #include <optional>
 #include <vector>
 
+#include "api/transport/ecn_marking.h"
 #include "api/units/time_delta.h"
 #include "api/units/timestamp.h"
 #include "rtc_base/async_packet_socket.h"
@@ -40,6 +41,7 @@
 
     SocketAddress addr;
     Buffer buf;
+    EcnMarking ecn;
     std::optional<Timestamp> packet_time;
   };
 
@@ -77,7 +79,7 @@
 
   // Returns the next packet received by the client or null if none is received
   // within the specified timeout.
-  std::unique_ptr<Packet> NextPacket(int timeout_ms);
+  std::unique_ptr<Packet> NextPacket(int timeout_ms = kTimeoutMs);
 
   // Checks that the next packet has the given contents. Returns the remote
   // address that the packet was sent from.
diff --git a/rtc_base/virtual_socket_server.cc b/rtc_base/virtual_socket_server.cc
index c86aee9..041b0fb 100644
--- a/rtc_base/virtual_socket_server.cc
+++ b/rtc_base/virtual_socket_server.cc
@@ -26,7 +26,9 @@
 #include "absl/algorithm/container.h"
 #include "api/scoped_refptr.h"
 #include "api/sequence_checker.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/time_delta.h"
+#include "rtc_base/buffer.h"
 #include "rtc_base/byte_order.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/event.h"
@@ -71,8 +73,11 @@
 // the kernel does.
 class VirtualSocketPacket {
  public:
-  VirtualSocketPacket(const char* data, size_t size, const SocketAddress& from)
-      : size_(size), consumed_(0), from_(from) {
+  VirtualSocketPacket(const char* data,
+                      size_t size,
+                      EcnMarking ecn,
+                      const SocketAddress& from)
+      : size_(size), consumed_(0), ecn_(ecn), from_(from) {
     RTC_DCHECK(nullptr != data);
     data_ = new char[size_];
     memcpy(data_, data, size_);
@@ -82,6 +87,7 @@
 
   const char* data() const { return data_ + consumed_; }
   size_t size() const { return size_ - consumed_; }
+  EcnMarking ecn() const { return ecn_; }
   const SocketAddress& from() const { return from_; }
 
   // Remove the first size bytes from the data.
@@ -93,6 +99,7 @@
  private:
   char* data_;
   size_t size_, consumed_;
+  EcnMarking ecn_;
   SocketAddress from_;
 };
 
@@ -271,11 +278,28 @@
                             size_t cb,
                             SocketAddress* paddr,
                             int64_t* timestamp) {
-  if (timestamp) {
-    *timestamp = -1;
+  Buffer payload;
+  payload.EnsureCapacity(cb);
+  ReceiveBuffer receive_buffer(payload);
+  int bytes_received = DoRecvFrom(receive_buffer);
+  if (bytes_received > 0) {
+    memcpy(pv, payload.data(), bytes_received);
   }
+  *paddr = receive_buffer.source_address;
+  return bytes_received;
+}
 
-  int data_read = safety_->RecvFrom(pv, cb, *paddr);
+int VirtualSocket::RecvFrom(ReceiveBuffer& buffer) {
+  static constexpr int BUF_SIZE = 64 * 1024;
+  buffer.payload.EnsureCapacity(BUF_SIZE);
+  return DoRecvFrom(buffer);
+}
+
+int VirtualSocket::DoRecvFrom(ReceiveBuffer& buffer) {
+  int data_read = safety_->RecvFrom(buffer);
+  if (options_map_[OPT_RECV_ECN] != 1) {
+    buffer.ecn = EcnMarking::kNotEct;
+  }
   if (data_read < 0) {
     error_ = EAGAIN;
     return -1;
@@ -292,9 +316,7 @@
   return data_read;
 }
 
-int VirtualSocket::SafetyBlock::RecvFrom(void* buffer,
-                                         size_t size,
-                                         SocketAddress& addr) {
+int VirtualSocket::SafetyBlock::RecvFrom(ReceiveBuffer& buffer) {
   MutexLock lock(&mutex_);
   // If we don't have a packet, then either error or wait for one to arrive.
   if (recv_buffer_.empty()) {
@@ -303,9 +325,10 @@
 
   // Return the packet at the front of the queue.
   VirtualSocketPacket& packet = *recv_buffer_.front();
-  size_t data_read = std::min(size, packet.size());
-  memcpy(buffer, packet.data(), data_read);
-  addr = packet.from();
+  size_t data_read = std::min(buffer.payload.capacity(), packet.size());
+  buffer.payload.SetData(packet.data(), data_read);
+  buffer.source_address = packet.from();
+  buffer.ecn = packet.ecn();
 
   if (data_read < packet.size()) {
     packet.Consume(data_read);
@@ -566,9 +589,12 @@
       return result;
     }
   }
+  EcnMarking ecn = (options_map_[Socket::OPT_SEND_ECN] == 1)
+                       ? EcnMarking::kEct1
+                       : EcnMarking::kNotEct;
 
   // Send the data in a message to the appropriate socket.
-  return server_->SendUdp(this, static_cast<const char*>(pv), cb, addr);
+  return server_->SendUdp(this, static_cast<const char*>(pv), cb, ecn, addr);
 }
 
 int VirtualSocket::SendTcp(const void* pv, size_t cb) {
@@ -961,6 +987,7 @@
 int VirtualSocketServer::SendUdp(VirtualSocket* socket,
                                  const char* data,
                                  size_t data_size,
+                                 EcnMarking ecn,
                                  const SocketAddress& remote_addr) {
   {
     MutexLock lock(&mutex_);
@@ -1029,7 +1056,7 @@
     }
 
     AddPacketToNetwork(socket, recipient, cur_time, data, data_size,
-                       UDP_HEADER_SIZE, false);
+                       UDP_HEADER_SIZE, false, ecn);
 
     return static_cast<int>(data_size);
   }
@@ -1072,7 +1099,7 @@
       break;
 
     AddPacketToNetwork(socket, recipient, cur_time, socket->send_buffer_data(),
-                       data_size, TCP_HEADER_SIZE, true);
+                       data_size, TCP_HEADER_SIZE, true, EcnMarking::kNotEct);
     recipient->UpdateRecv(data_size);
     socket->UpdateSend(data_size);
   }
@@ -1092,7 +1119,8 @@
                                              const char* data,
                                              size_t data_size,
                                              size_t header_size,
-                                             bool ordered) {
+                                             bool ordered,
+                                             EcnMarking ecn) {
   RTC_DCHECK(msg_queue_);
   uint32_t send_delay = sender->AddPacket(cur_time, data_size + header_size);
 
@@ -1114,7 +1142,7 @@
   }
   recipient->PostPacket(
       TimeDelta::Millis(ts - cur_time),
-      std::make_unique<VirtualSocketPacket>(data, data_size, sender_addr));
+      std::make_unique<VirtualSocketPacket>(data, data_size, ecn, sender_addr));
 }
 
 uint32_t VirtualSocketServer::SendDelay(uint32_t size) {
diff --git a/rtc_base/virtual_socket_server.h b/rtc_base/virtual_socket_server.h
index c191d4a..c144264 100644
--- a/rtc_base/virtual_socket_server.h
+++ b/rtc_base/virtual_socket_server.h
@@ -21,9 +21,11 @@
 #include <utility>
 #include <vector>
 
+#include "absl/functional/any_invocable.h"
 #include "api/make_ref_counted.h"
 #include "api/ref_counted_base.h"
 #include "api/scoped_refptr.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/time_delta.h"
 #include "rtc_base/event.h"
 #include "rtc_base/fake_clock.h"
@@ -62,6 +64,7 @@
                size_t cb,
                SocketAddress* paddr,
                int64_t* timestamp) override;
+  int RecvFrom(ReceiveBuffer& buffer) override;
   int Listen(int backlog) override;
   VirtualSocket* Accept(SocketAddress* paddr) override;
 
@@ -114,10 +117,11 @@
     void SetNotAlive();
     bool IsAlive();
 
-    // Copies up to `size` bytes into buffer from the next received packet
-    // and fills `addr` with remote address of that received packet.
-    // Returns number of bytes copied or negative value on failure.
-    int RecvFrom(void* buffer, size_t size, SocketAddress& addr);
+    // Copies up to `buffer.payload.capacity()` bytes into `buffer.payload()`
+    // from the next received packet and fills `addr` with remote address of
+    // that received packet. Returns number of bytes copied or negative value on
+    // failure.
+    int RecvFrom(ReceiveBuffer& buffer);
 
     void Listen();
 
@@ -179,6 +183,7 @@
   void CompleteConnect(const SocketAddress& addr);
   int SendUdp(const void* pv, size_t cb, const SocketAddress& addr);
   int SendTcp(const void* pv, size_t cb);
+  int DoRecvFrom(ReceiveBuffer& buffer);
 
   void OnSocketServerReadyToSend();
 
@@ -369,6 +374,7 @@
   int SendUdp(VirtualSocket* socket,
               const char* data,
               size_t data_size,
+              EcnMarking ecn,
               const SocketAddress& remote_addr);
 
   // Moves as much data as possible from the sender's buffer to the network
@@ -414,7 +420,8 @@
                           const char* data,
                           size_t data_size,
                           size_t header_size,
-                          bool ordered);
+                          bool ordered,
+                          EcnMarking ecn);
 
   // If the delay has been set for the address of the socket, returns the set
   // delay. Otherwise, returns a random transit delay chosen from the
diff --git a/rtc_base/virtual_socket_unittest.cc b/rtc_base/virtual_socket_unittest.cc
index f4404a0..9fadd19 100644
--- a/rtc_base/virtual_socket_unittest.cc
+++ b/rtc_base/virtual_socket_unittest.cc
@@ -19,6 +19,7 @@
 
 #include "absl/memory/memory.h"
 #include "api/environment/environment.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/time_delta.h"
 #include "rtc_base/async_packet_socket.h"
 #include "rtc_base/async_udp_socket.h"
@@ -36,11 +37,13 @@
 #include "rtc_base/thread.h"
 #include "rtc_base/virtual_socket_server.h"
 #include "test/create_test_environment.h"
+#include "test/gmock.h"
 #include "test/gtest.h"
 
 namespace webrtc {
 namespace {
 
+using ::testing::NotNull;
 using testing::SSE_CLOSE;
 using testing::SSE_ERROR;
 using testing::SSE_OPEN;
@@ -853,6 +856,41 @@
     }
   }
 
+  void SendReceiveEcn(const SocketAddress& initial_addr) {
+    std::unique_ptr<Socket> socket =
+        ss_.Create(initial_addr.family(), SOCK_DGRAM);
+    socket->Bind(initial_addr);
+    SocketAddress server_addr = socket->GetLocalAddress();
+
+    TestClient client1(
+        std::make_unique<AsyncUDPSocket>(env_, std::move(socket)),
+        &fake_clock_);
+
+    SocketAddress client2_addr;
+    std::unique_ptr<Socket> socket2 =
+        ss_.Create(initial_addr.family(), SOCK_DGRAM);
+    TestClient client2(
+        std::make_unique<AsyncUDPSocket>(env_, std::move(socket2)),
+        &fake_clock_);
+
+    client2.SendTo("foo", 3, server_addr);
+    std::unique_ptr<TestClient::Packet> packet_1 = client1.NextPacket();
+    ASSERT_THAT(packet_1.get(), NotNull());
+    EXPECT_EQ(packet_1->ecn, EcnMarking::kNotEct);
+
+    client2.SetOption(Socket::OPT_SEND_ECN, 1);
+    client2.SendTo("bar", 3, server_addr);
+    std::unique_ptr<TestClient::Packet> packet_2 = client1.NextPacket();
+    ASSERT_THAT(packet_2.get(), NotNull());
+    EXPECT_EQ(packet_2->ecn, EcnMarking::kNotEct);
+
+    client1.SetOption(Socket::OPT_RECV_ECN, 1);
+    client2.SendTo("bar", 3, server_addr);
+    std::unique_ptr<TestClient::Packet> packet_3 = client1.NextPacket();
+    ASSERT_THAT(packet_3.get(), NotNull());
+    EXPECT_EQ(packet_3->ecn, EcnMarking::kEct1);
+  }
+
  protected:
   ScopedFakeClock fake_clock_;
   const Environment env_ = CreateTestEnvironment();
@@ -916,6 +954,10 @@
   CloseTest(kIPv6AnyAddress);
 }
 
+TEST_F(VirtualSocketServerTest, SendReceiveEcn) {
+  SendReceiveEcn(kIPv4AnyAddress);
+}
+
 TEST_F(VirtualSocketServerTest, tcp_send_v4) {
   TcpSendTest(kIPv4AnyAddress);
 }