Propagate ECN information through Network Emulation

Bug: webrtc:42225697
Change-Id: Idbd1ded3b5401c86d9afc6fd74f6da58e47bf5cd
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/368862
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#43441}
diff --git a/api/BUILD.gn b/api/BUILD.gn
index 904447f..9cab2b0 100644
--- a/api/BUILD.gn
+++ b/api/BUILD.gn
@@ -857,6 +857,7 @@
   deps = [
     "../rtc_base:macromagic",
     "../rtc_base:random",
+    "transport:ecn_marking",
     "units:data_rate",
     "//third_party/abseil-cpp/absl/functional:any_invocable",
   ]
diff --git a/api/test/network_emulation/BUILD.gn b/api/test/network_emulation/BUILD.gn
index bb4af45..3534ed5 100644
--- a/api/test/network_emulation/BUILD.gn
+++ b/api/test/network_emulation/BUILD.gn
@@ -52,6 +52,7 @@
     "../../../rtc_base:socket_address",
     "../../numerics",
     "../../task_queue",
+    "../../transport:ecn_marking",
     "../../units:data_rate",
     "../../units:data_size",
     "../../units:time_delta",
diff --git a/api/test/network_emulation/network_emulation_interfaces.cc b/api/test/network_emulation/network_emulation_interfaces.cc
index 1086b96..3986f82 100644
--- a/api/test/network_emulation/network_emulation_interfaces.cc
+++ b/api/test/network_emulation/network_emulation_interfaces.cc
@@ -25,13 +25,15 @@
                                    const rtc::SocketAddress& to,
                                    rtc::CopyOnWriteBuffer data,
                                    Timestamp arrival_time,
-                                   uint16_t application_overhead)
+                                   uint16_t application_overhead,
+                                   EcnMarking ecn)
     : from(from),
       to(to),
       data(data),
       headers_size(to.ipaddr().overhead() + application_overhead +
                    cricket::kUdpHeaderSize),
-      arrival_time(arrival_time) {
+      arrival_time(arrival_time),
+      ecn(ecn) {
   RTC_DCHECK(to.family() == AF_INET || to.family() == AF_INET6);
 }
 
diff --git a/api/test/network_emulation/network_emulation_interfaces.h b/api/test/network_emulation/network_emulation_interfaces.h
index 1789e04..d02d739 100644
--- a/api/test/network_emulation/network_emulation_interfaces.h
+++ b/api/test/network_emulation/network_emulation_interfaces.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "api/numerics/samples_stats_counter.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/data_rate.h"
 #include "api/units/data_size.h"
 #include "api/units/timestamp.h"
@@ -33,7 +34,8 @@
                    const rtc::SocketAddress& to,
                    rtc::CopyOnWriteBuffer data,
                    Timestamp arrival_time,
-                   uint16_t application_overhead = 0);
+                   uint16_t application_overhead = 0,
+                   EcnMarking ecn = EcnMarking::kNotEct);
   ~EmulatedIpPacket() = default;
   // This object is not copyable or assignable.
   EmulatedIpPacket(const EmulatedIpPacket&) = delete;
@@ -52,6 +54,7 @@
   rtc::CopyOnWriteBuffer data;
   uint16_t headers_size;
   Timestamp arrival_time;
+  EcnMarking ecn;
 };
 
 // Interface for handling IP packets from an emulated network. This is used with
@@ -254,7 +257,8 @@
   virtual void SendPacket(const rtc::SocketAddress& from,
                           const rtc::SocketAddress& to,
                           rtc::CopyOnWriteBuffer packet_data,
-                          uint16_t application_overhead = 0) = 0;
+                          uint16_t application_overhead = 0,
+                          EcnMarking ecn = EcnMarking::kNotEct) = 0;
 
   // Binds receiver to this endpoint to send and receive data.
   // `desired_port` is a port that should be used. If it is equal to 0,
diff --git a/api/test/simulated_network.h b/api/test/simulated_network.h
index 7b572a7..2b75b99 100644
--- a/api/test/simulated_network.h
+++ b/api/test/simulated_network.h
@@ -19,24 +19,40 @@
 #include <vector>
 
 #include "absl/functional/any_invocable.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/data_rate.h"
 
 namespace webrtc {
 
 struct PacketInFlightInfo {
+  PacketInFlightInfo(size_t size,
+                     int64_t send_time_us,
+                     uint64_t packet_id,
+                     webrtc::EcnMarking ecn)
+      : size(size),
+        send_time_us(send_time_us),
+        packet_id(packet_id),
+        ecn(ecn) {}
+
   PacketInFlightInfo(size_t size, int64_t send_time_us, uint64_t packet_id)
-      : size(size), send_time_us(send_time_us), packet_id(packet_id) {}
+      : PacketInFlightInfo(size,
+                           send_time_us,
+                           packet_id,
+                           webrtc::EcnMarking::kNotEct) {}
 
   size_t size;
   int64_t send_time_us;
   // Unique identifier for the packet in relation to other packets in flight.
   uint64_t packet_id;
+  webrtc::EcnMarking ecn;
 };
 
 struct PacketDeliveryInfo {
   static constexpr int kNotReceived = -1;
   PacketDeliveryInfo(PacketInFlightInfo source, int64_t receive_time_us)
-      : receive_time_us(receive_time_us), packet_id(source.packet_id) {}
+      : receive_time_us(receive_time_us),
+        packet_id(source.packet_id),
+        ecn(source.ecn) {}
 
   bool operator==(const PacketDeliveryInfo& other) const {
     return receive_time_us == other.receive_time_us &&
@@ -45,6 +61,7 @@
 
   int64_t receive_time_us;
   uint64_t packet_id;
+  webrtc::EcnMarking ecn;
 };
 
 // BuiltInNetworkBehaviorConfig is a built-in network behavior configuration
diff --git a/test/network/BUILD.gn b/test/network/BUILD.gn
index 8de643e..18d1ca6 100644
--- a/test/network/BUILD.gn
+++ b/test/network/BUILD.gn
@@ -51,6 +51,7 @@
     "../../api/task_queue",
     "../../api/task_queue:pending_task_safety_flag",
     "../../api/test/network_emulation",
+    "../../api/transport:ecn_marking",
     "../../api/transport:stun_types",
     "../../api/units:data_rate",
     "../../api/units:data_size",
@@ -60,6 +61,7 @@
     "../../p2p:p2p_server_utils",
     "../../p2p:rtc_p2p",
     "../../rtc_base:async_packet_socket",
+    "../../rtc_base:checks",
     "../../rtc_base:copy_on_write_buffer",
     "../../rtc_base:ip_address",
     "../../rtc_base:logging",
@@ -151,11 +153,14 @@
       "../../api:create_time_controller",
       "../../api:simulated_network_api",
       "../../api/task_queue:task_queue",
+      "../../api/transport:ecn_marking",
       "../../api/units:time_delta",
       "../../api/units:timestamp",
+      "../../rtc_base:buffer",
       "../../rtc_base:gunit_helpers",
       "../../rtc_base:logging",
       "../../rtc_base:rtc_event",
+      "../../rtc_base:socket",
       "../../rtc_base:task_queue_for_test",
       "../../rtc_base/synchronization:mutex",
     ]
diff --git a/test/network/fake_network_socket_server.cc b/test/network/fake_network_socket_server.cc
index 8dcca34..773a324 100644
--- a/test/network/fake_network_socket_server.cc
+++ b/test/network/fake_network_socket_server.cc
@@ -18,6 +18,8 @@
 #include "absl/algorithm/container.h"
 #include "api/scoped_refptr.h"
 #include "api/task_queue/pending_task_safety_flag.h"
+#include "api/transport/ecn_marking.h"
+#include "rtc_base/checks.h"
 #include "rtc_base/event.h"
 #include "rtc_base/logging.h"
 #include "rtc_base/thread.h"
@@ -52,11 +54,11 @@
   int SendTo(const void* pv,
              size_t cb,
              const rtc::SocketAddress& addr) override;
-  int Recv(void* pv, size_t cb, int64_t* timestamp) override;
-  int RecvFrom(void* pv,
-               size_t cb,
-               rtc::SocketAddress* paddr,
-               int64_t* timestamp) override;
+  int Recv(void* pv, size_t cb, int64_t* timestamp) override {
+    RTC_DCHECK_NOTREACHED() << " Use RecvFrom instead.";
+    return 0;
+  }
+  int RecvFrom(ReceiveBuffer& buffer) override;
   int Listen(int backlog) override;
   rtc::Socket* Accept(rtc::SocketAddress* paddr) override;
   int GetError() const override;
@@ -175,47 +177,26 @@
     return -1;
   }
   rtc::CopyOnWriteBuffer packet(static_cast<const uint8_t*>(pv), cb);
-  endpoint_->SendPacket(local_addr_, addr, packet);
+  EcnMarking ecn = EcnMarking::kNotEct;
+  auto it = options_map_.find(OPT_SEND_ECN);
+  if (it != options_map_.end() && it->second == 1) {
+    ecn = EcnMarking::kEct1;
+  }
+
+  endpoint_->SendPacket(local_addr_, addr, packet, /*application_overhead=*/0,
+                        ecn);
   return cb;
 }
 
-int FakeNetworkSocket::Recv(void* pv, size_t cb, int64_t* timestamp) {
-  rtc::SocketAddress paddr;
-  return RecvFrom(pv, cb, &paddr, timestamp);
-}
-
-// Reads 1 packet from internal queue. Reads up to `cb` bytes into `pv`
-// and returns the length of received packet.
-int FakeNetworkSocket::RecvFrom(void* pv,
-                                size_t cb,
-                                rtc::SocketAddress* paddr,
-                                int64_t* timestamp) {
+int FakeNetworkSocket::RecvFrom(ReceiveBuffer& buffer) {
   RTC_DCHECK_RUN_ON(thread_);
-
-  if (timestamp) {
-    *timestamp = -1;
-  }
   RTC_CHECK(pending_);
-
-  *paddr = pending_->from;
-  size_t data_read = std::min(cb, pending_->size());
-  memcpy(pv, pending_->cdata(), data_read);
-  *timestamp = pending_->arrival_time.us();
-
-  // According to RECV(2) Linux Man page
-  // real socket will discard data, that won't fit into provided buffer,
-  // but we won't to skip such error, so we will assert here.
-  RTC_CHECK(data_read == pending_->size())
-      << "Too small buffer is provided for socket read. "
-         "Received data size: "
-      << pending_->size() << "; Provided buffer size: " << cb;
-
+  buffer.source_address = pending_->from;
+  buffer.arrival_time = pending_->arrival_time;
+  buffer.payload.SetData(pending_->cdata(), pending_->size());
+  buffer.ecn = pending_->ecn;
   pending_.reset();
-
-  // According to RECV(2) Linux Man page
-  // real socket will return message length, not data read. In our case it is
-  // actually the same value.
-  return static_cast<int>(data_read);
+  return buffer.payload.size();
 }
 
 int FakeNetworkSocket::Listen(int backlog) {
diff --git a/test/network/network_emulation.cc b/test/network/network_emulation.cc
index 5a1d1dd..62168c2 100644
--- a/test/network/network_emulation.cc
+++ b/test/network/network_emulation.cc
@@ -28,6 +28,7 @@
 #include "api/task_queue/task_queue_base.h"
 #include "api/test/network_emulation/network_emulation_interfaces.h"
 #include "api/test/network_emulation_manager.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/data_size.h"
 #include "api/units/time_delta.h"
 #include "rtc_base/logging.h"
@@ -370,7 +371,7 @@
     uint64_t packet_id = next_packet_id_++;
     bool sent = network_behavior_->EnqueuePacket(
         PacketInFlightInfo(GetPacketSizeForEmulation(packet),
-                           packet.arrival_time.us(), packet_id));
+                           packet.arrival_time.us(), packet_id, packet.ecn));
     if (sent) {
       packets_.emplace_back(StoredPacket{.id = packet_id,
                                          .sent_time = clock_->CurrentTime(),
@@ -410,6 +411,8 @@
     if (delivery_info.receive_time_us != PacketDeliveryInfo::kNotReceived) {
       packet->packet.arrival_time =
           Timestamp::Micros(delivery_info.receive_time_us);
+      // Link may have changed ECN.
+      packet->packet.ecn = delivery_info.ecn;
       receiver_->OnPacketReceived(std::move(packet->packet));
     }
     while (!packets_.empty() && packets_.front().removed) {
@@ -615,12 +618,13 @@
 void EmulatedEndpointImpl::SendPacket(const rtc::SocketAddress& from,
                                       const rtc::SocketAddress& to,
                                       rtc::CopyOnWriteBuffer packet_data,
-                                      uint16_t application_overhead) {
+                                      uint16_t application_overhead,
+                                      EcnMarking ecn) {
   if (!options_.allow_send_packet_with_different_source_ip) {
     RTC_CHECK(from.ipaddr() == options_.ip);
   }
   EmulatedIpPacket packet(from, to, std::move(packet_data),
-                          clock_->CurrentTime(), application_overhead);
+                          clock_->CurrentTime(), application_overhead, ecn);
   task_queue_->PostTask([this, packet = std::move(packet)]() mutable {
     RTC_DCHECK_RUN_ON(task_queue_);
     stats_builder_.OnPacketSent(packet.arrival_time, clock_->CurrentTime(),
diff --git a/test/network/network_emulation.h b/test/network/network_emulation.h
index 6dec782..10bca42 100644
--- a/test/network/network_emulation.h
+++ b/test/network/network_emulation.h
@@ -29,6 +29,7 @@
 #include "api/test/network_emulation/network_emulation_interfaces.h"
 #include "api/test/network_emulation_manager.h"
 #include "api/test/simulated_network.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/time_delta.h"
 #include "api/units/timestamp.h"
 #include "rtc_base/copy_on_write_buffer.h"
@@ -297,7 +298,8 @@
   void SendPacket(const rtc::SocketAddress& from,
                   const rtc::SocketAddress& to,
                   rtc::CopyOnWriteBuffer packet_data,
-                  uint16_t application_overhead = 0) override;
+                  uint16_t application_overhead = 0,
+                  EcnMarking ecn = EcnMarking::kNotEct) override;
 
   std::optional<uint16_t> BindReceiver(
       uint16_t desired_port,
diff --git a/test/network/network_emulation_unittest.cc b/test/network/network_emulation_unittest.cc
index 1057d6d..dfacbf2 100644
--- a/test/network/network_emulation_unittest.cc
+++ b/test/network/network_emulation_unittest.cc
@@ -19,10 +19,12 @@
 #include "api/task_queue/task_queue_base.h"
 #include "api/test/create_time_controller.h"
 #include "api/test/simulated_network.h"
+#include "api/transport/ecn_marking.h"
 #include "api/units/time_delta.h"
 #include "api/units/timestamp.h"
-#include "rtc_base/event.h"
+#include "rtc_base/buffer.h"
 #include "rtc_base/gunit.h"
+#include "rtc_base/socket.h"
 #include "rtc_base/synchronization/mutex.h"
 #include "rtc_base/task_queue_for_test.h"
 #include "test/gmock.h"
@@ -45,34 +47,37 @@
   explicit SocketReader(rtc::Socket* socket, rtc::Thread* network_thread)
       : socket_(socket), network_thread_(network_thread) {
     socket_->SignalReadEvent.connect(this, &SocketReader::OnReadEvent);
-    size_ = 128 * 1024;
-    buf_ = new char[size_];
   }
-  ~SocketReader() override { delete[] buf_; }
 
   void OnReadEvent(rtc::Socket* socket) {
     RTC_DCHECK(socket_ == socket);
     RTC_DCHECK(network_thread_->IsCurrent());
-    int64_t timestamp;
-    len_ = socket_->Recv(buf_, size_, &timestamp);
+
+    rtc::Socket::ReceiveBuffer receive_buffer(payload_);
+    socket_->RecvFrom(receive_buffer);
+    last_ecn_mark_ = receive_buffer.ecn;
 
     MutexLock lock(&lock_);
     received_count_++;
   }
 
-  int ReceivedCount() {
+  int ReceivedCount() const {
     MutexLock lock(&lock_);
     return received_count_;
   }
 
+  webrtc::EcnMarking LastEcnMarking() const {
+    MutexLock lock(&lock_);
+    return last_ecn_mark_;
+  }
+
  private:
   rtc::Socket* const socket_;
   rtc::Thread* const network_thread_;
-  char* buf_;
-  size_t size_;
-  int len_;
+  rtc::Buffer payload_;
+  webrtc::EcnMarking last_ecn_mark_;
 
-  Mutex lock_;
+  mutable Mutex lock_;
   int received_count_ RTC_GUARDED_BY(lock_) = 0;
 };
 
@@ -359,6 +364,69 @@
                            *network_manager.time_controller());
 }
 
+TEST(NetworkEmulationManagerTest, EcnMarkingIsPropagated) {
+  NetworkEmulationManagerImpl network_manager(
+      {.time_mode = TimeMode::kRealTime});
+
+  EmulatedNetworkNode* alice_node = network_manager.CreateEmulatedNode(
+      std::make_unique<SimulatedNetwork>(BuiltInNetworkBehaviorConfig()));
+  EmulatedNetworkNode* bob_node = network_manager.CreateEmulatedNode(
+      std::make_unique<SimulatedNetwork>(BuiltInNetworkBehaviorConfig()));
+  EmulatedEndpoint* alice_endpoint =
+      network_manager.CreateEndpoint(EmulatedEndpointConfig());
+  EmulatedEndpoint* bob_endpoint =
+      network_manager.CreateEndpoint(EmulatedEndpointConfig());
+  network_manager.CreateRoute(alice_endpoint, {alice_node}, bob_endpoint);
+  network_manager.CreateRoute(bob_endpoint, {bob_node}, alice_endpoint);
+
+  EmulatedNetworkManagerInterface* nt1 =
+      network_manager.CreateEmulatedNetworkManagerInterface({alice_endpoint});
+  EmulatedNetworkManagerInterface* nt2 =
+      network_manager.CreateEmulatedNetworkManagerInterface({bob_endpoint});
+
+  rtc::Thread* t1 = nt1->network_thread();
+  rtc::Thread* t2 = nt2->network_thread();
+
+  rtc::Socket* s1 = nullptr;
+  rtc::Socket* s2 = nullptr;
+  SendTask(t1,
+           [&] { s1 = t1->socketserver()->CreateSocket(AF_INET, SOCK_DGRAM); });
+  SendTask(t2,
+           [&] { s2 = t2->socketserver()->CreateSocket(AF_INET, SOCK_DGRAM); });
+
+  SocketReader r1(s1, t1);
+  SocketReader r2(s2, t2);
+
+  rtc::SocketAddress a1(alice_endpoint->GetPeerLocalAddress(), 0);
+  rtc::SocketAddress a2(bob_endpoint->GetPeerLocalAddress(), 0);
+
+  SendTask(t1, [&] {
+    s1->Bind(a1);
+    a1 = s1->GetLocalAddress();
+  });
+  SendTask(t2, [&] {
+    s2->Bind(a2);
+    a2 = s2->GetLocalAddress();
+  });
+
+  SendTask(t1, [&] { s1->Connect(a2); });
+  SendTask(t2, [&] { s2->Connect(a1); });
+
+  t1->PostTask([&]() {
+    s1->SetOption(rtc::Socket::Option::OPT_SEND_ECN, 1);
+    rtc::CopyOnWriteBuffer data("Hello");
+    s1->Send(data.data(), data.size());
+  });
+
+  network_manager.time_controller()->AdvanceTime(TimeDelta::Seconds(1));
+
+  EXPECT_EQ(r2.ReceivedCount(), 1);
+  EXPECT_EQ(r2.LastEcnMarking(), webrtc::EcnMarking::kEct1);
+
+  SendTask(t1, [&] { delete s1; });
+  SendTask(t2, [&] { delete s2; });
+}
+
 TEST(NetworkEmulationManagerTest, DebugStatsCollectedInDebugMode) {
   NetworkEmulationManagerImpl network_manager(
       {.time_mode = TimeMode::kSimulated,