Ensure VirtualSocketServer can propagate ECN
VirtualSocketServer is used for testing much of the logic related to Turn/Stun ports.
Bug: webrtc:453581251
Change-Id: I4d099ed3f06925cb43f8b9c4ad6610c94ab28094
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/418741
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#45986}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index 8651e33..016c99e 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -1759,6 +1759,7 @@
":socket_address",
":threading",
":timeutils",
+ "../api/transport:ecn_marking",
"../api/units:time_delta",
"../api/units:timestamp",
"../test:wait_until",
@@ -1848,6 +1849,7 @@
"../api/environment",
"../api/task_queue",
"../api/task_queue:pending_task_safety_flag",
+ "../api/transport:ecn_marking",
"../api/units:time_delta",
"../api/units:timestamp",
"memory:always_valid_pointer",
@@ -2176,6 +2178,7 @@
"../api:scoped_refptr",
"../api/environment",
"../api/numerics",
+ "../api/transport:ecn_marking",
"../api/units:data_rate",
"../api/units:data_size",
"../api/units:frequency",
diff --git a/rtc_base/test_client.cc b/rtc_base/test_client.cc
index b48cda4..3608136 100644
--- a/rtc_base/test_client.cc
+++ b/rtc_base/test_client.cc
@@ -153,11 +153,13 @@
: addr(received_packet.source_address()),
// Copy received_packet payload to a buffer owned by Packet.
buf(received_packet.payload().data(), received_packet.payload().size()),
+ ecn(received_packet.ecn()),
packet_time(received_packet.arrival_time()) {}
TestClient::Packet::Packet(const Packet& p)
: addr(p.addr),
buf(p.buf.data(), p.buf.size()),
+ ecn(p.ecn),
packet_time(p.packet_time) {}
} // namespace webrtc
diff --git a/rtc_base/test_client.h b/rtc_base/test_client.h
index 738cecc..68c7705 100644
--- a/rtc_base/test_client.h
+++ b/rtc_base/test_client.h
@@ -16,6 +16,7 @@
#include <optional>
#include <vector>
+#include "api/transport/ecn_marking.h"
#include "api/units/time_delta.h"
#include "api/units/timestamp.h"
#include "rtc_base/async_packet_socket.h"
@@ -40,6 +41,7 @@
SocketAddress addr;
Buffer buf;
+ EcnMarking ecn;
std::optional<Timestamp> packet_time;
};
@@ -77,7 +79,7 @@
// Returns the next packet received by the client or null if none is received
// within the specified timeout.
- std::unique_ptr<Packet> NextPacket(int timeout_ms);
+ std::unique_ptr<Packet> NextPacket(int timeout_ms = kTimeoutMs);
// Checks that the next packet has the given contents. Returns the remote
// address that the packet was sent from.
diff --git a/rtc_base/virtual_socket_server.cc b/rtc_base/virtual_socket_server.cc
index c86aee9..041b0fb 100644
--- a/rtc_base/virtual_socket_server.cc
+++ b/rtc_base/virtual_socket_server.cc
@@ -26,7 +26,9 @@
#include "absl/algorithm/container.h"
#include "api/scoped_refptr.h"
#include "api/sequence_checker.h"
+#include "api/transport/ecn_marking.h"
#include "api/units/time_delta.h"
+#include "rtc_base/buffer.h"
#include "rtc_base/byte_order.h"
#include "rtc_base/checks.h"
#include "rtc_base/event.h"
@@ -71,8 +73,11 @@
// the kernel does.
class VirtualSocketPacket {
public:
- VirtualSocketPacket(const char* data, size_t size, const SocketAddress& from)
- : size_(size), consumed_(0), from_(from) {
+ VirtualSocketPacket(const char* data,
+ size_t size,
+ EcnMarking ecn,
+ const SocketAddress& from)
+ : size_(size), consumed_(0), ecn_(ecn), from_(from) {
RTC_DCHECK(nullptr != data);
data_ = new char[size_];
memcpy(data_, data, size_);
@@ -82,6 +87,7 @@
const char* data() const { return data_ + consumed_; }
size_t size() const { return size_ - consumed_; }
+ EcnMarking ecn() const { return ecn_; }
const SocketAddress& from() const { return from_; }
// Remove the first size bytes from the data.
@@ -93,6 +99,7 @@
private:
char* data_;
size_t size_, consumed_;
+ EcnMarking ecn_;
SocketAddress from_;
};
@@ -271,11 +278,28 @@
size_t cb,
SocketAddress* paddr,
int64_t* timestamp) {
- if (timestamp) {
- *timestamp = -1;
+ Buffer payload;
+ payload.EnsureCapacity(cb);
+ ReceiveBuffer receive_buffer(payload);
+ int bytes_received = DoRecvFrom(receive_buffer);
+ if (bytes_received > 0) {
+ memcpy(pv, payload.data(), bytes_received);
}
+ *paddr = receive_buffer.source_address;
+ return bytes_received;
+}
- int data_read = safety_->RecvFrom(pv, cb, *paddr);
+int VirtualSocket::RecvFrom(ReceiveBuffer& buffer) {
+ static constexpr int BUF_SIZE = 64 * 1024;
+ buffer.payload.EnsureCapacity(BUF_SIZE);
+ return DoRecvFrom(buffer);
+}
+
+int VirtualSocket::DoRecvFrom(ReceiveBuffer& buffer) {
+ int data_read = safety_->RecvFrom(buffer);
+ if (options_map_[OPT_RECV_ECN] != 1) {
+ buffer.ecn = EcnMarking::kNotEct;
+ }
if (data_read < 0) {
error_ = EAGAIN;
return -1;
@@ -292,9 +316,7 @@
return data_read;
}
-int VirtualSocket::SafetyBlock::RecvFrom(void* buffer,
- size_t size,
- SocketAddress& addr) {
+int VirtualSocket::SafetyBlock::RecvFrom(ReceiveBuffer& buffer) {
MutexLock lock(&mutex_);
// If we don't have a packet, then either error or wait for one to arrive.
if (recv_buffer_.empty()) {
@@ -303,9 +325,10 @@
// Return the packet at the front of the queue.
VirtualSocketPacket& packet = *recv_buffer_.front();
- size_t data_read = std::min(size, packet.size());
- memcpy(buffer, packet.data(), data_read);
- addr = packet.from();
+ size_t data_read = std::min(buffer.payload.capacity(), packet.size());
+ buffer.payload.SetData(packet.data(), data_read);
+ buffer.source_address = packet.from();
+ buffer.ecn = packet.ecn();
if (data_read < packet.size()) {
packet.Consume(data_read);
@@ -566,9 +589,12 @@
return result;
}
}
+ EcnMarking ecn = (options_map_[Socket::OPT_SEND_ECN] == 1)
+ ? EcnMarking::kEct1
+ : EcnMarking::kNotEct;
// Send the data in a message to the appropriate socket.
- return server_->SendUdp(this, static_cast<const char*>(pv), cb, addr);
+ return server_->SendUdp(this, static_cast<const char*>(pv), cb, ecn, addr);
}
int VirtualSocket::SendTcp(const void* pv, size_t cb) {
@@ -961,6 +987,7 @@
int VirtualSocketServer::SendUdp(VirtualSocket* socket,
const char* data,
size_t data_size,
+ EcnMarking ecn,
const SocketAddress& remote_addr) {
{
MutexLock lock(&mutex_);
@@ -1029,7 +1056,7 @@
}
AddPacketToNetwork(socket, recipient, cur_time, data, data_size,
- UDP_HEADER_SIZE, false);
+ UDP_HEADER_SIZE, false, ecn);
return static_cast<int>(data_size);
}
@@ -1072,7 +1099,7 @@
break;
AddPacketToNetwork(socket, recipient, cur_time, socket->send_buffer_data(),
- data_size, TCP_HEADER_SIZE, true);
+ data_size, TCP_HEADER_SIZE, true, EcnMarking::kNotEct);
recipient->UpdateRecv(data_size);
socket->UpdateSend(data_size);
}
@@ -1092,7 +1119,8 @@
const char* data,
size_t data_size,
size_t header_size,
- bool ordered) {
+ bool ordered,
+ EcnMarking ecn) {
RTC_DCHECK(msg_queue_);
uint32_t send_delay = sender->AddPacket(cur_time, data_size + header_size);
@@ -1114,7 +1142,7 @@
}
recipient->PostPacket(
TimeDelta::Millis(ts - cur_time),
- std::make_unique<VirtualSocketPacket>(data, data_size, sender_addr));
+ std::make_unique<VirtualSocketPacket>(data, data_size, ecn, sender_addr));
}
uint32_t VirtualSocketServer::SendDelay(uint32_t size) {
diff --git a/rtc_base/virtual_socket_server.h b/rtc_base/virtual_socket_server.h
index c191d4a..c144264 100644
--- a/rtc_base/virtual_socket_server.h
+++ b/rtc_base/virtual_socket_server.h
@@ -21,9 +21,11 @@
#include <utility>
#include <vector>
+#include "absl/functional/any_invocable.h"
#include "api/make_ref_counted.h"
#include "api/ref_counted_base.h"
#include "api/scoped_refptr.h"
+#include "api/transport/ecn_marking.h"
#include "api/units/time_delta.h"
#include "rtc_base/event.h"
#include "rtc_base/fake_clock.h"
@@ -62,6 +64,7 @@
size_t cb,
SocketAddress* paddr,
int64_t* timestamp) override;
+ int RecvFrom(ReceiveBuffer& buffer) override;
int Listen(int backlog) override;
VirtualSocket* Accept(SocketAddress* paddr) override;
@@ -114,10 +117,11 @@
void SetNotAlive();
bool IsAlive();
- // Copies up to `size` bytes into buffer from the next received packet
- // and fills `addr` with remote address of that received packet.
- // Returns number of bytes copied or negative value on failure.
- int RecvFrom(void* buffer, size_t size, SocketAddress& addr);
+ // Copies up to `buffer.payload.capacity()` bytes into `buffer.payload()`
+ // from the next received packet and fills `addr` with remote address of
+ // that received packet. Returns number of bytes copied or negative value on
+ // failure.
+ int RecvFrom(ReceiveBuffer& buffer);
void Listen();
@@ -179,6 +183,7 @@
void CompleteConnect(const SocketAddress& addr);
int SendUdp(const void* pv, size_t cb, const SocketAddress& addr);
int SendTcp(const void* pv, size_t cb);
+ int DoRecvFrom(ReceiveBuffer& buffer);
void OnSocketServerReadyToSend();
@@ -369,6 +374,7 @@
int SendUdp(VirtualSocket* socket,
const char* data,
size_t data_size,
+ EcnMarking ecn,
const SocketAddress& remote_addr);
// Moves as much data as possible from the sender's buffer to the network
@@ -414,7 +420,8 @@
const char* data,
size_t data_size,
size_t header_size,
- bool ordered);
+ bool ordered,
+ EcnMarking ecn);
// If the delay has been set for the address of the socket, returns the set
// delay. Otherwise, returns a random transit delay chosen from the
diff --git a/rtc_base/virtual_socket_unittest.cc b/rtc_base/virtual_socket_unittest.cc
index f4404a0..9fadd19 100644
--- a/rtc_base/virtual_socket_unittest.cc
+++ b/rtc_base/virtual_socket_unittest.cc
@@ -19,6 +19,7 @@
#include "absl/memory/memory.h"
#include "api/environment/environment.h"
+#include "api/transport/ecn_marking.h"
#include "api/units/time_delta.h"
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/async_udp_socket.h"
@@ -36,11 +37,13 @@
#include "rtc_base/thread.h"
#include "rtc_base/virtual_socket_server.h"
#include "test/create_test_environment.h"
+#include "test/gmock.h"
#include "test/gtest.h"
namespace webrtc {
namespace {
+using ::testing::NotNull;
using testing::SSE_CLOSE;
using testing::SSE_ERROR;
using testing::SSE_OPEN;
@@ -853,6 +856,41 @@
}
}
+ void SendReceiveEcn(const SocketAddress& initial_addr) {
+ std::unique_ptr<Socket> socket =
+ ss_.Create(initial_addr.family(), SOCK_DGRAM);
+ socket->Bind(initial_addr);
+ SocketAddress server_addr = socket->GetLocalAddress();
+
+ TestClient client1(
+ std::make_unique<AsyncUDPSocket>(env_, std::move(socket)),
+ &fake_clock_);
+
+ SocketAddress client2_addr;
+ std::unique_ptr<Socket> socket2 =
+ ss_.Create(initial_addr.family(), SOCK_DGRAM);
+ TestClient client2(
+ std::make_unique<AsyncUDPSocket>(env_, std::move(socket2)),
+ &fake_clock_);
+
+ client2.SendTo("foo", 3, server_addr);
+ std::unique_ptr<TestClient::Packet> packet_1 = client1.NextPacket();
+ ASSERT_THAT(packet_1.get(), NotNull());
+ EXPECT_EQ(packet_1->ecn, EcnMarking::kNotEct);
+
+ client2.SetOption(Socket::OPT_SEND_ECN, 1);
+ client2.SendTo("bar", 3, server_addr);
+ std::unique_ptr<TestClient::Packet> packet_2 = client1.NextPacket();
+ ASSERT_THAT(packet_2.get(), NotNull());
+ EXPECT_EQ(packet_2->ecn, EcnMarking::kNotEct);
+
+ client1.SetOption(Socket::OPT_RECV_ECN, 1);
+ client2.SendTo("bar", 3, server_addr);
+ std::unique_ptr<TestClient::Packet> packet_3 = client1.NextPacket();
+ ASSERT_THAT(packet_3.get(), NotNull());
+ EXPECT_EQ(packet_3->ecn, EcnMarking::kEct1);
+ }
+
protected:
ScopedFakeClock fake_clock_;
const Environment env_ = CreateTestEnvironment();
@@ -916,6 +954,10 @@
CloseTest(kIPv6AnyAddress);
}
+TEST_F(VirtualSocketServerTest, SendReceiveEcn) {
+ SendReceiveEcn(kIPv4AnyAddress);
+}
+
TEST_F(VirtualSocketServerTest, tcp_send_v4) {
TcpSendTest(kIPv4AnyAddress);
}