blob: 874e9387c688a182d230f64401b6201bc3039570 [file] [log] [blame]
/*
* Copyright 2004 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "rtc_base/nat_socket_factory.h"
#include "rtc_base/arraysize.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/nat_server.h"
#include "rtc_base/virtual_socket_server.h"
namespace rtc {
// Packs `remote_addr` into `buf` using the quasi-STUN framing that the NAT
// server uses:
//   byte 0    : always 0
//   byte 1    : address family (AF_INET or AF_INET6)
//   bytes 2-3 : port, in network byte order
//   bytes 4-  : the raw IPv4 or IPv6 address
// `buf_size` must be at least kNATEncodedIPv4AddressSize (IPv4) or
// kNATEncodedIPv6AddressSize (IPv6); this is only checked with RTC_DCHECK.
// Returns the number of bytes written, or 0 if the address family is neither
// AF_INET nor AF_INET6 (the 4-byte header is still written in that case).
size_t PackAddressForNAT(char* buf,
                         size_t buf_size,
                         const SocketAddress& remote_addr) {
  const IPAddress& ip = remote_addr.ipaddr();
  int family = ip.family();
  RTC_DCHECK(buf_size >= 4);
  buf[0] = 0;
  buf[1] = static_cast<char>(family);
  // Write the port with memcpy: &buf[2] has no alignment guarantee for a
  // uint16_t (callers pass plain char arrays), so storing through a
  // reinterpret_cast'ed pointer is undefined behavior.
  const uint16_t port = HostToNetwork16(remote_addr.port());
  memcpy(&buf[2], &port, sizeof(port));
  if (family == AF_INET) {
    RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
    in_addr v4addr = ip.ipv4_address();
    memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
    return kNATEncodedIPv4AddressSize;
  } else if (family == AF_INET6) {
    RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
    in6_addr v6addr = ip.ipv6_address();
    memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
    return kNATEncodedIPv6AddressSize;
  }
  return 0U;
}
// Decodes the remote address from a packet that has been encoded with the nat's
// quasi-STUN format. Returns the length of the address (i.e., the offset into
// data where the original packet starts).
size_t UnpackAddressFromNAT(const char* buf,
size_t buf_size,
SocketAddress* remote_addr) {
RTC_DCHECK(buf_size >= 8);
RTC_DCHECK(buf[0] == 0);
int family = buf[1];
uint16_t port =
NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
if (family == AF_INET) {
const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
*remote_addr = SocketAddress(IPAddress(*v4addr), port);
return kNATEncodedIPv4AddressSize;
} else if (family == AF_INET6) {
RTC_DCHECK(buf_size >= 20);
const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
*remote_addr = SocketAddress(IPAddress(*v6addr), port);
return kNATEncodedIPv6AddressSize;
}
return 0U;
}
// NATSocket
// Socket handed out by NATSocketFactory / NATSocketServer. The underlying
// socket is created lazily in BindInternal() through
// NATInternalSocketFactory::CreateInternalSocket(), which also reports the NAT
// server address to use (`server_addr_`). When `server_addr_` is nil every
// call is forwarded to the underlying socket. Otherwise:
//  - UDP datagrams are sent to the NAT server prefixed with the destination
//    address (PackAddressForNAT), and received datagrams have the source
//    address prefix stripped (UnpackAddressFromNAT);
//  - TCP connects to the NAT server, sends the real destination address once
//    that connection is up, and waits for a one-byte reply (0 == success).
class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
 public:
  explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
      : sf_(sf),
        family_(family),
        type_(type),
        connected_(false),
        socket_(nullptr),
        buf_(nullptr),
        size_(0) {}
  // `socket_` and `buf_` are owned through raw pointers; a copy would free
  // them twice.
  NATSocket(const NATSocket&) = delete;
  NATSocket& operator=(const NATSocket&) = delete;
  ~NATSocket() override {
    delete socket_;
    delete[] buf_;
  }

  SocketAddress GetLocalAddress() const override {
    return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
  }

  SocketAddress GetRemoteAddress() const override {
    return remote_addr_;  // will be NIL if not connected
  }

  int Bind(const SocketAddress& addr) override {
    if (socket_) {  // already bound, bubble up error
      return -1;
    }
    return BindInternal(addr);
  }

  int Connect(const SocketAddress& addr) override {
    int result = 0;
    // If we're not already bound (meaning `socket_` is null), bind to ANY
    // address.
    if (!socket_) {
      result = BindInternal(SocketAddress(GetAnyIP(family_), 0));
      if (result < 0) {
        return result;
      }
    }
    if (type_ == SOCK_STREAM) {
      // TCP goes to the NAT server (if any); the real destination is sent in
      // SendConnectRequest() once that connection completes.
      result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
    } else {
      connected_ = true;
    }
    if (result >= 0) {
      remote_addr_ = addr;
    }
    return result;
  }

  int Send(const void* data, size_t size) override {
    RTC_DCHECK(connected_);
    return SendTo(data, size, remote_addr_);
  }

  int SendTo(const void* data,
             size_t size,
             const SocketAddress& addr) override {
    RTC_DCHECK(!connected_ || addr == remote_addr_);
    if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
      return socket_->SendTo(data, size, addr);
    }
    // This array will be too large for IPv4 packets, but only by 12 bytes.
    std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
    size_t addrlength =
        PackAddressForNAT(buf.get(), size + kNATEncodedIPv6AddressSize, addr);
    size_t encoded_size = size + addrlength;
    memcpy(buf.get() + addrlength, data, size);
    int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
    if (result >= 0) {
      RTC_DCHECK(result == static_cast<int>(encoded_size));
      // Report only the caller's payload as sent, not the address header.
      result = result - static_cast<int>(addrlength);
    }
    return result;
  }

  int Recv(void* data, size_t size, int64_t* timestamp) override {
    SocketAddress addr;
    return RecvFrom(data, size, &addr, timestamp);
  }

  int RecvFrom(void* data,
               size_t size,
               SocketAddress* out_addr,
               int64_t* timestamp) override {
    if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
      return socket_->RecvFrom(data, size, out_addr, timestamp);
    }
    // Read at most the requested amount plus the largest possible address
    // header. `buf_` may be larger than this from an earlier call with a
    // bigger `size`; reading up to `size_` would let the payload overflow
    // `data`, so the read length is bounded explicitly.
    const size_t read_size = size + kNATEncodedIPv6AddressSize;
    Grow(read_size);
    SocketAddress remote_addr;
    int result = socket_->RecvFrom(buf_, read_size, &remote_addr, timestamp);
    if (result < 0) {
      return result;
    }
    RTC_DCHECK(remote_addr == server_addr_);
    // TODO: we need better framing so we know how many bytes we can
    // return before we need to read the next address. For UDP, this will be
    // fine as long as the reader always reads everything in the packet.
    RTC_DCHECK(static_cast<size_t>(result) < read_size);
    const size_t received = static_cast<size_t>(result);

    // Decode the wire packet into the actual results. Reject packets whose
    // header is unrecognized or longer than what was received; otherwise
    // `received - addrlength` would be wrong or underflow.
    SocketAddress real_remote_addr;
    const size_t addrlength =
        UnpackAddressFromNAT(buf_, received, &real_remote_addr);
    if (addrlength == 0 || addrlength > received) {
      RTC_LOG(LS_ERROR) << "Dropping malformed packet from NAT server.";
      return 0;  // Tell the caller we didn't read anything
    }

    // Make sure this packet should be delivered before returning it.
    if (!connected_ || (real_remote_addr == remote_addr_)) {
      // Datagram semantics: anything beyond `size` bytes is truncated.
      size_t payload_size = received - addrlength;
      if (payload_size > size) {
        payload_size = size;
      }
      memcpy(data, buf_ + addrlength, payload_size);
      if (out_addr)
        *out_addr = real_remote_addr;
      return static_cast<int>(payload_size);
    }
    RTC_LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
                      << real_remote_addr.ToString();
    return 0;  // Tell the caller we didn't read anything
  }

  int Close() override {
    int result = 0;
    if (socket_) {
      result = socket_->Close();
      if (result >= 0) {
        connected_ = false;
        remote_addr_ = SocketAddress();
        delete socket_;
        socket_ = nullptr;
      }
    }
    return result;
  }

  int Listen(int backlog) override { return socket_->Listen(backlog); }

  AsyncSocket* Accept(SocketAddress* paddr) override {
    return socket_->Accept(paddr);
  }

  int GetError() const override {
    return socket_ ? socket_->GetError() : error_;
  }

  void SetError(int error) override {
    if (socket_) {
      socket_->SetError(error);
    } else {
      error_ = error;
    }
  }

  ConnState GetState() const override {
    return connected_ ? CS_CONNECTED : CS_CLOSED;
  }

  int GetOption(Option opt, int* value) override {
    return socket_ ? socket_->GetOption(opt, value) : -1;
  }

  int SetOption(Option opt, int value) override {
    return socket_ ? socket_->SetOption(opt, value) : -1;
  }

  void OnConnectEvent(AsyncSocket* socket) {
    // If we're NATed, we need to send a message with the real addr to use.
    RTC_DCHECK(socket == socket_);
    if (server_addr_.IsNil()) {
      connected_ = true;
      SignalConnectEvent(this);
    } else {
      SendConnectRequest();
    }
  }

  void OnReadEvent(AsyncSocket* socket) {
    // If we're NATed, we need to process the connect reply.
    RTC_DCHECK(socket == socket_);
    if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
      HandleConnectReply();
    } else {
      SignalReadEvent(this);
    }
  }

  void OnWriteEvent(AsyncSocket* socket) {
    RTC_DCHECK(socket == socket_);
    SignalWriteEvent(this);
  }

  void OnCloseEvent(AsyncSocket* socket, int error) {
    RTC_DCHECK(socket == socket_);
    SignalCloseEvent(this, error);
  }

 private:
  // Creates the underlying socket through `sf_`, binds it to `addr` and
  // forwards its signals. On failure `socket_` and `server_addr_` are reset.
  int BindInternal(const SocketAddress& addr) {
    RTC_DCHECK(!socket_);
    int result;
    socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
    result = (socket_) ? socket_->Bind(addr) : -1;
    if (result >= 0) {
      socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
      socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
      socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
      socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
    } else {
      server_addr_.Clear();
      delete socket_;
      socket_ = nullptr;
    }
    return result;
  }

  // Makes sure the buffer is at least the given size. Existing contents are
  // discarded when it grows.
  void Grow(size_t new_size) {
    if (size_ < new_size) {
      delete[] buf_;
      size_ = new_size;
      buf_ = new char[size_];
    }
  }

  // Sends the destination address to the server to tell it to connect.
  void SendConnectRequest() {
    char buf[kNATEncodedIPv6AddressSize];
    size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
    socket_->Send(buf, length);
  }

  // Handles the byte sent back from the server and fires the appropriate event.
  void HandleConnectReply() {
    char code = 0;
    // If the reply byte could not be read, wait for the next read event
    // instead of acting on an unset `code`.
    if (socket_->Recv(&code, sizeof(code), nullptr) !=
        static_cast<int>(sizeof(code))) {
      return;
    }
    if (code == 0) {
      connected_ = true;
      SignalConnectEvent(this);
    } else {
      Close();
      SignalCloseEvent(this, code);
    }
  }

  NATInternalSocketFactory* sf_;
  int family_;
  int type_;
  bool connected_;
  SocketAddress remote_addr_;
  SocketAddress server_addr_;  // address of the NAT server
  AsyncSocket* socket_;
  // Need to hold error in case it occurs before the socket is created.
  int error_ = 0;
  char* buf_;
  size_t size_;
};
// NATSocketFactory
NATSocketFactory::NATSocketFactory(SocketFactory* factory,
const SocketAddress& nat_udp_addr,
const SocketAddress& nat_tcp_addr)
: factory_(factory),
nat_udp_addr_(nat_udp_addr),
nat_tcp_addr_(nat_tcp_addr) {}
// Returns a new NATSocket; its underlying socket is created later, when it is
// bound, through this factory's CreateInternalSocket().
Socket* NATSocketFactory::CreateSocket(int family, int type) {
  NATSocket* nat_socket = new NATSocket(this, family, type);
  return nat_socket;
}
// Returns a new NATSocket; its underlying socket is created later, when it is
// bound, through this factory's CreateInternalSocket().
AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
  NATSocket* nat_socket = new NATSocket(this, family, type);
  return nat_socket;
}
// Creates the real socket backing a NATSocket. Every socket from this factory
// uses the same NAT; only the server address differs between TCP and UDP.
// `local_addr` is not consulted here.
AsyncSocket* NATSocketFactory::CreateInternalSocket(
    int family,
    int type,
    const SocketAddress& local_addr,
    SocketAddress* nat_addr) {
  *nat_addr = (type == SOCK_STREAM) ? nat_tcp_addr_ : nat_udp_addr_;
  return factory_->CreateAsyncSocket(family, type);
}
// NATSocketServer
NATSocketServer::NATSocketServer(SocketServer* server)
: server_(server), msg_queue_(nullptr) {}
// Returns the top-level translator registered for `ext_ip`, or null.
NATSocketServer::Translator* NATSocketServer::GetTranslator(
    const SocketAddress& ext_ip) {
  Translator* found = nats_.Get(ext_ip);
  return found;
}
// Creates a top-level translator for `ext_ip` whose private network uses
// `int_ip`. Returns null if `ext_ip` is already taken.
NATSocketServer::Translator* NATSocketServer::AddTranslator(
    const SocketAddress& ext_ip,
    const SocketAddress& int_ip,
    NATType type) {
  // Fail if a translator already exists with this external address.
  if (nats_.Get(ext_ip) != nullptr) {
    return nullptr;
  }
  Translator* translator = new Translator(this, type, int_ip, server_, ext_ip);
  return nats_.Add(ext_ip, translator);
}
void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
nats_.Remove(ext_ip);
}
// Returns a new NATSocket; which network it lives on is decided when it binds,
// through CreateInternalSocket().
Socket* NATSocketServer::CreateSocket(int family, int type) {
  NATSocket* nat_socket = new NATSocket(this, family, type);
  return nat_socket;
}
// Returns a new NATSocket; which network it lives on is decided when it binds,
// through CreateInternalSocket().
AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
  NATSocket* nat_socket = new NATSocket(this, family, type);
  return nat_socket;
}
// Remembers `queue` and forwards it to the wrapped socket server.
void NATSocketServer::SetMessageQueue(Thread* queue) {
  this->msg_queue_ = queue;
  this->server_->SetMessageQueue(queue);
}
// Delegates waiting entirely to the wrapped socket server.
bool NATSocketServer::Wait(int cms, bool process_io) {
  const bool result = server_->Wait(cms, process_io);
  return result;
}
void NATSocketServer::WakeUp() {
server_->WakeUp();
}
// Creates the real socket backing a NATSocket. If `local_addr` is a client of
// one of the translators, the socket is created on that translator's private
// network and `nat_addr` receives its NAT server address; otherwise the socket
// comes from the wrapped server and `nat_addr` is left untouched.
AsyncSocket* NATSocketServer::CreateInternalSocket(
    int family,
    int type,
    const SocketAddress& local_addr,
    SocketAddress* nat_addr) {
  Translator* owner = nats_.FindClient(local_addr);
  if (owner == nullptr) {
    return server_->CreateAsyncSocket(family, type);
  }
  AsyncSocket* internal_socket =
      owner->internal_factory()->CreateAsyncSocket(family, type);
  *nat_addr = (type == SOCK_STREAM) ? owner->internal_tcp_address()
                                    : owner->internal_udp_address();
  return internal_socket;
}
// NATSocketServer::Translator
// Builds a private VirtualSocketServer network that shares the owning
// server's message queue, plus a NATServer bridging it to `ext_factory` at
// `ext_ip`.
NATSocketServer::Translator::Translator(NATSocketServer* server,
                                        NATType type,
                                        const SocketAddress& int_ip,
                                        SocketFactory* ext_factory,
                                        const SocketAddress& ext_ip)
    : server_(server) {
  internal_server_ = std::make_unique<VirtualSocketServer>();
  VirtualSocketServer* private_network = internal_server_.get();
  // The private network must share our message queue before the NAT server
  // starts creating sockets on it.
  private_network->SetMessageQueue(server_->queue());
  nat_server_ = std::make_unique<NATServer>(type, private_network, int_ip,
                                            int_ip, ext_factory, ext_ip);
}
// Detaches the private network from the shared message queue.
NATSocketServer::Translator::~Translator() {
  Thread* const no_queue = nullptr;
  internal_server_->SetMessageQueue(no_queue);
}
// Returns the nested translator registered for `ext_ip`, or null.
NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
    const SocketAddress& ext_ip) {
  Translator* nested = nats_.Get(ext_ip);
  return nested;
}
// Creates a translator nested behind this one, with `ext_ip` on this
// translator's private network. Returns null if `ext_ip` is already taken.
NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
    const SocketAddress& ext_ip,
    const SocketAddress& int_ip,
    NATType type) {
  // Fail if a translator already exists with this external address.
  if (nats_.Get(ext_ip) != nullptr) {
    return nullptr;
  }
  // Register `ext_ip` as our client first: the nested NAT's external sockets
  // are created through `server_`, which locates the owning network via
  // FindClient().
  AddClient(ext_ip);
  Translator* nested = new Translator(server_, type, int_ip, server_, ext_ip);
  return nats_.Add(ext_ip, nested);
}
void NATSocketServer::Translator::RemoveTranslator(
const SocketAddress& ext_ip) {
nats_.Remove(ext_ip);
RemoveClient(ext_ip);
}
// Registers `int_ip` as living behind this translator. Returns false if it was
// already registered.
bool NATSocketServer::Translator::AddClient(const SocketAddress& int_ip) {
  const bool inserted = clients_.insert(int_ip).second;
  return inserted;
}
// Unregisters `int_ip`; a no-op if it was not registered.
void NATSocketServer::Translator::RemoveClient(const SocketAddress& int_ip) {
  clients_.erase(int_ip);
}
// Returns the translator that owns `int_ip`: this one if it is a direct
// client, otherwise the first nested translator that owns it, or null.
NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
    const SocketAddress& int_ip) {
  if (clients_.count(int_ip) != 0) {
    return this;
  }
  return nats_.FindClient(int_ip);
}
// NATSocketServer::TranslatorMap
// The map owns its translators; destroy each one.
NATSocketServer::TranslatorMap::~TranslatorMap() {
  for (auto& entry : *this) {
    delete entry.second;
  }
}
// Returns the translator stored under `ext_ip`, or null if none.
NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
    const SocketAddress& ext_ip) {
  auto match = find(ext_ip);
  if (match == end()) {
    return nullptr;
  }
  return match->second;
}
// Stores `nat` under `ext_ip`, taking ownership of it, and returns `nat`.
// The map owns its translators (they are deleted by Remove() and by the
// destructor), so a translator already stored under `ext_ip` is destroyed
// when it is replaced; previously it was leaked.
NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
    const SocketAddress& ext_ip,
    Translator* nat) {
  // operator[] default-inserts a null pointer for a new key; deleting null is
  // a no-op.
  Translator*& slot = (*this)[ext_ip];
  if (slot != nat) {
    delete slot;
  }
  slot = nat;
  return nat;
}
void NATSocketServer::TranslatorMap::Remove(const SocketAddress& ext_ip) {
TranslatorMap::iterator it = find(ext_ip);
if (it != end()) {
delete it->second;
erase(it);
}
}
// Returns the first translator in the map (searching recursively through
// nested translators) that owns `int_ip`, or null if none does.
NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
    const SocketAddress& int_ip) {
  for (auto& entry : *this) {
    if (Translator* owner = entry.second->FindClient(int_ip)) {
      return owner;
    }
  }
  return nullptr;
}
} // namespace rtc