Refactor NatServer to use rtc::ReceivedPackets
Instead of using raw pointers.
Also, ensure callbacks are registered on the correct thread.
Nat servers are test only code.
Bug: webrtc:11943
Change-Id: Ib70a5966acb512f1a07212a07aaedab70aa20f9b
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/331260
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@google.com>
Cr-Commit-Position: refs/heads/main@{#41372}
diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc
index 96c1bd5..f5f3ee0 100644
--- a/p2p/base/port_unittest.cc
+++ b/p2p/base/port_unittest.cc
@@ -620,8 +620,8 @@
std::unique_ptr<rtc::NATServer> CreateNatServer(const SocketAddress& addr,
rtc::NATType type) {
- return std::make_unique<rtc::NATServer>(type, ss_.get(), addr, addr,
- ss_.get(), addr);
+ return std::make_unique<rtc::NATServer>(type, main_, ss_.get(), addr, addr,
+ main_, ss_.get(), addr);
}
static const char* StunName(NATType type) {
switch (type) {
diff --git a/p2p/client/basic_port_allocator_unittest.cc b/p2p/client/basic_port_allocator_unittest.cc
index defcab0..65f8e43 100644
--- a/p2p/client/basic_port_allocator_unittest.cc
+++ b/p2p/client/basic_port_allocator_unittest.cc
@@ -496,8 +496,8 @@
bool with_nat) {
if (with_nat) {
nat_server_.reset(new rtc::NATServer(
- rtc::NAT_OPEN_CONE, vss_.get(), kNatUdpAddr, kNatTcpAddr, vss_.get(),
- rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0)));
+ rtc::NAT_OPEN_CONE, thread_, vss_.get(), kNatUdpAddr, kNatTcpAddr,
+ thread_, vss_.get(), rtc::SocketAddress(kNatUdpAddr.ipaddr(), 0)));
} else {
nat_socket_factory_ =
std::make_unique<rtc::BasicPacketSocketFactory>(fss_.get());
diff --git a/rtc_base/nat_server.cc b/rtc_base/nat_server.cc
index b818685..c274ced 100644
--- a/rtc_base/nat_server.cc
+++ b/rtc_base/nat_server.cc
@@ -10,12 +10,15 @@
#include "rtc_base/nat_server.h"
+#include <cstddef>
#include <memory>
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/nat_socket_factory.h"
+#include "rtc_base/network/received_packet.h"
#include "rtc_base/socket_adapters.h"
+#include "rtc_base/socket_address.h"
namespace rtc {
@@ -125,17 +128,27 @@
};
NATServer::NATServer(NATType type,
+ rtc::Thread& internal_socket_thread,
SocketFactory* internal,
const SocketAddress& internal_udp_addr,
const SocketAddress& internal_tcp_addr,
+ rtc::Thread& external_socket_thread,
SocketFactory* external,
const SocketAddress& external_ip)
- : external_(external), external_ip_(external_ip.ipaddr(), 0) {
+ : internal_socket_thread_(internal_socket_thread),
+ external_socket_thread_(external_socket_thread),
+ external_(external),
+ external_ip_(external_ip.ipaddr(), 0) {
nat_ = NAT::Create(type);
- udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
- udp_server_socket_->SignalReadPacket.connect(this,
- &NATServer::OnInternalUDPPacket);
+ internal_socket_thread_.BlockingCall([&] {
+ udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
+ udp_server_socket_->RegisterReceivedPacketCallback(
+ [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) {
+ OnInternalUDPPacket(socket, packet);
+ });
+ });
+
tcp_proxy_server_ =
new NATProxyServer(internal, internal_tcp_addr, external, external_ip);
@@ -156,10 +169,11 @@
}
void NATServer::OnInternalUDPPacket(AsyncPacketSocket* socket,
- const char* buf,
- size_t size,
- const SocketAddress& addr,
- const int64_t& /* packet_time_us */) {
+ const rtc::ReceivedPacket& packet) {
+ RTC_DCHECK(internal_socket_thread_.IsCurrent());
+ const char* buf = reinterpret_cast<const char*>(packet.payload().data());
+ size_t size = packet.payload().size();
+ const SocketAddress& addr = packet.source_address();
// Read the intended destination from the wire.
SocketAddress dest_addr;
size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
@@ -182,10 +196,8 @@
}
void NATServer::OnExternalUDPPacket(AsyncPacketSocket* socket,
- const char* buf,
- size_t size,
- const SocketAddress& remote_addr,
- const int64_t& /* packet_time_us */) {
+ const rtc::ReceivedPacket& packet) {
+ RTC_DCHECK(external_socket_thread_.IsCurrent());
SocketAddress local_addr = socket->GetLocalAddress();
// Find the translation for this addresses.
@@ -193,36 +205,46 @@
RTC_DCHECK(iter != ext_map_->end());
// Allow the NAT to reject this packet.
- if (ShouldFilterOut(iter->second, remote_addr)) {
- RTC_LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
+ if (ShouldFilterOut(iter->second, packet.source_address())) {
+ RTC_LOG(LS_INFO) << "Packet from "
+ << packet.source_address().ToSensitiveString()
<< " was filtered out by the NAT.";
return;
}
// Forward this packet to the internal address.
// First prepend the address in a quasi-STUN format.
- std::unique_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
+ std::unique_ptr<char[]> real_buf(
+ new char[packet.payload().size() + kNATEncodedIPv6AddressSize]);
size_t addrlength = PackAddressForNAT(
- real_buf.get(), size + kNATEncodedIPv6AddressSize, remote_addr);
+ real_buf.get(), packet.payload().size() + kNATEncodedIPv6AddressSize,
+ packet.source_address());
// Copy the data part after the address.
rtc::PacketOptions options;
- memcpy(real_buf.get() + addrlength, buf, size);
- udp_server_socket_->SendTo(real_buf.get(), size + addrlength,
+ memcpy(real_buf.get() + addrlength, packet.payload().data(),
+ packet.payload().size());
+ udp_server_socket_->SendTo(real_buf.get(),
+ packet.payload().size() + addrlength,
iter->second->route.source(), options);
}
void NATServer::Translate(const SocketAddressPair& route) {
- AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
+ external_socket_thread_.BlockingCall([&] {
+ AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
- if (!socket) {
- RTC_LOG(LS_ERROR) << "Couldn't find a free port!";
- return;
- }
+ if (!socket) {
+ RTC_LOG(LS_ERROR) << "Couldn't find a free port!";
+ return;
+ }
- TransEntry* entry = new TransEntry(route, socket, nat_);
- (*int_map_)[route] = entry;
- (*ext_map_)[socket->GetLocalAddress()] = entry;
- socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket);
+ TransEntry* entry = new TransEntry(route, socket, nat_);
+ (*int_map_)[route] = entry;
+ (*ext_map_)[socket->GetLocalAddress()] = entry;
+ socket->RegisterReceivedPacketCallback(
+ [&](rtc::AsyncPacketSocket* socket, const rtc::ReceivedPacket& packet) {
+ OnExternalUDPPacket(socket, packet);
+ });
+ });
}
bool NATServer::ShouldFilterOut(TransEntry* entry,
diff --git a/rtc_base/nat_server.h b/rtc_base/nat_server.h
index acbd62a..d179efa 100644
--- a/rtc_base/nat_server.h
+++ b/rtc_base/nat_server.h
@@ -58,15 +58,17 @@
const int NAT_SERVER_UDP_PORT = 4237;
const int NAT_SERVER_TCP_PORT = 4238;
-class NATServer : public sigslot::has_slots<> {
+class NATServer {
public:
NATServer(NATType type,
+ rtc::Thread& internal_socket_thread,
SocketFactory* internal,
const SocketAddress& internal_udp_addr,
const SocketAddress& internal_tcp_addr,
+ rtc::Thread& external_socket_thread,
SocketFactory* external,
const SocketAddress& external_ip);
- ~NATServer() override;
+ ~NATServer();
NATServer(const NATServer&) = delete;
NATServer& operator=(const NATServer&) = delete;
@@ -81,15 +83,9 @@
// Packets received on one of the networks.
void OnInternalUDPPacket(AsyncPacketSocket* socket,
- const char* buf,
- size_t size,
- const SocketAddress& addr,
- const int64_t& packet_time_us);
+ const rtc::ReceivedPacket& packet);
void OnExternalUDPPacket(AsyncPacketSocket* socket,
- const char* buf,
- size_t size,
- const SocketAddress& remote_addr,
- const int64_t& packet_time_us);
+ const rtc::ReceivedPacket& packet);
private:
typedef std::set<SocketAddress, AddrCmp> AddressSet;
@@ -118,6 +114,8 @@
bool ShouldFilterOut(TransEntry* entry, const SocketAddress& ext_addr);
NAT* nat_;
+ rtc::Thread& internal_socket_thread_;
+ rtc::Thread& external_socket_thread_;
SocketFactory* external_;
SocketAddress external_ip_;
AsyncUDPSocket* udp_server_socket_;
diff --git a/rtc_base/nat_socket_factory.cc b/rtc_base/nat_socket_factory.cc
index fe021b9..83ec2bc 100644
--- a/rtc_base/nat_socket_factory.cc
+++ b/rtc_base/nat_socket_factory.cc
@@ -368,7 +368,8 @@
if (nats_.Get(ext_ip))
return nullptr;
- return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
+ return nats_.Add(
+ ext_ip, new Translator(this, type, int_ip, *msg_queue_, server_, ext_ip));
}
void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
@@ -413,6 +414,7 @@
NATSocketServer::Translator::Translator(NATSocketServer* server,
NATType type,
const SocketAddress& int_ip,
+ Thread& external_socket_thread,
SocketFactory* ext_factory,
const SocketAddress& ext_ip)
: server_(server) {
@@ -422,7 +424,8 @@
internal_server_ = std::make_unique<VirtualSocketServer>();
internal_server_->SetMessageQueue(server_->queue());
nat_server_ = std::make_unique<NATServer>(
- type, internal_server_.get(), int_ip, int_ip, ext_factory, ext_ip);
+ type, *server->queue(), internal_server_.get(), int_ip, int_ip,
+ external_socket_thread, ext_factory, ext_ip);
}
NATSocketServer::Translator::~Translator() {
@@ -443,8 +446,8 @@
return nullptr;
AddClient(ext_ip);
- return nats_.Add(ext_ip,
- new Translator(server_, type, int_ip, server_, ext_ip));
+ return nats_.Add(ext_ip, new Translator(server_, type, int_ip,
+ *server_->queue(), server_, ext_ip));
}
void NATSocketServer::Translator::RemoveTranslator(
const SocketAddress& ext_ip) {
diff --git a/rtc_base/nat_socket_factory.h b/rtc_base/nat_socket_factory.h
index 0b301b5..f803496 100644
--- a/rtc_base/nat_socket_factory.h
+++ b/rtc_base/nat_socket_factory.h
@@ -102,6 +102,7 @@
Translator(NATSocketServer* server,
NATType type,
const SocketAddress& int_addr,
+ Thread& external_socket_thread,
SocketFactory* ext_factory,
const SocketAddress& ext_addr);
~Translator();
diff --git a/rtc_base/nat_unittest.cc b/rtc_base/nat_unittest.cc
index 432985d..742e0d6 100644
--- a/rtc_base/nat_unittest.cc
+++ b/rtc_base/nat_unittest.cc
@@ -76,16 +76,17 @@
Thread th_int(internal);
Thread th_ext(external);
- SocketAddress server_addr = internal_addr;
- server_addr.SetPort(0); // Auto-select a port
- NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
- external, external_addrs[0]);
- NATSocketFactory* natsf = new NATSocketFactory(
- internal, nat->internal_udp_address(), nat->internal_tcp_address());
-
th_int.Start();
th_ext.Start();
+ SocketAddress server_addr = internal_addr;
+ server_addr.SetPort(0); // Auto-select a port
+ NATServer* nat =
+ new NATServer(nat_type, th_int, internal, server_addr, server_addr,
+ th_ext, external, external_addrs[0]);
+ NATSocketFactory* natsf = new NATSocketFactory(
+ internal, nat->internal_udp_address(), nat->internal_tcp_address());
+
TestClient* in;
th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); });
@@ -139,13 +140,13 @@
SocketAddress server_addr = internal_addr;
server_addr.SetPort(0); // Auto-select a port
- NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
- external, external_addrs[0]);
- NATSocketFactory* natsf = new NATSocketFactory(
- internal, nat->internal_udp_address(), nat->internal_tcp_address());
-
th_int.Start();
th_ext.Start();
+ NATServer* nat =
+ new NATServer(nat_type, th_int, internal, server_addr, server_addr,
+ th_ext, external, external_addrs[0]);
+ NATSocketFactory* natsf = new NATSocketFactory(
+ internal, nat->internal_udp_address(), nat->internal_tcp_address());
TestClient* in = nullptr;
th_int.BlockingCall([&] { in = CreateTestClient(natsf, internal_addr); });
@@ -355,9 +356,11 @@
int_thread_(new Thread(int_vss_.get())),
ext_thread_(new Thread(ext_vss_.get())),
nat_(new NATServer(NAT_OPEN_CONE,
+ *int_thread_,
int_vss_.get(),
int_addr_,
int_addr_,
+ *ext_thread_,
ext_vss_.get(),
ext_addr_)),
natsf_(new NATSocketFactory(int_vss_.get(),