Disable SSLAdapter methods Listen and Accept
Only affects turn server. Refactored to wrap sockets with SSLAdapter
after Accept, using the SSLAdapterFactory to hold needed configuration.
Bug: webrtc:13065
Change-Id: I5df65aad5728d8d40d95b22db6398a573ec7a36f
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/235823
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Niels Moller <nisse@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35258}
diff --git a/p2p/base/test_turn_server.h b/p2p/base/test_turn_server.h
index e1deb59..6cad135 100644
--- a/p2p/base/test_turn_server.h
+++ b/p2p/base/test_turn_server.h
@@ -11,7 +11,9 @@
#ifndef P2P_BASE_TEST_TURN_SERVER_H_
#define P2P_BASE_TEST_TURN_SERVER_H_
+#include <memory>
#include <string>
+#include <utility>
#include <vector>
#include "api/sequence_checker.h"
@@ -104,21 +106,24 @@
// new connections.
rtc::Socket* socket =
thread_->socketserver()->CreateSocket(AF_INET, SOCK_STREAM);
+ socket->Bind(int_addr);
+ socket->Listen(5);
if (proto == cricket::PROTO_TLS) {
// For TLS, wrap the TCP socket with an SSL adapter. The adapter must
// be configured with a self-signed certificate for testing.
// Additionally, the client will not present a valid certificate, so we
// must not fail when checking the peer's identity.
- rtc::SSLAdapter* adapter = rtc::SSLAdapter::Create(socket);
- adapter->SetRole(rtc::SSL_SERVER);
- adapter->SetIdentity(
+ std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory =
+ rtc::SSLAdapterFactory::Create();
+ ssl_adapter_factory->SetRole(rtc::SSL_SERVER);
+ ssl_adapter_factory->SetIdentity(
rtc::SSLIdentity::Create(common_name, rtc::KeyParams()));
- adapter->SetIgnoreBadCert(ignore_bad_cert);
- socket = adapter;
+ ssl_adapter_factory->SetIgnoreBadCert(ignore_bad_cert);
+ server_.AddInternalServerSocket(socket, proto,
+ std::move(ssl_adapter_factory));
+ } else {
+ server_.AddInternalServerSocket(socket, proto);
}
- socket->Bind(int_addr);
- socket->Listen(5);
- server_.AddInternalServerSocket(socket, proto);
} else {
RTC_NOTREACHED() << "Unknown protocol type: " << proto;
}
diff --git a/p2p/base/turn_server.cc b/p2p/base/turn_server.cc
index fd9cd16..5685e20 100644
--- a/p2p/base/turn_server.cc
+++ b/p2p/base/turn_server.cc
@@ -152,12 +152,15 @@
socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket);
}
-void TurnServer::AddInternalServerSocket(rtc::Socket* socket,
- ProtocolType proto) {
+void TurnServer::AddInternalServerSocket(
+ rtc::Socket* socket,
+ ProtocolType proto,
+ std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory) {
RTC_DCHECK_RUN_ON(thread_);
+
RTC_DCHECK(server_listen_sockets_.end() ==
server_listen_sockets_.find(socket));
- server_listen_sockets_[socket] = proto;
+ server_listen_sockets_[socket] = {proto, std::move(ssl_adapter_factory)};
socket->SignalReadEvent.connect(this, &TurnServer::OnNewInternalConnection);
}
@@ -181,13 +184,19 @@
rtc::SocketAddress accept_addr;
rtc::Socket* accepted_socket = server_socket->Accept(&accept_addr);
if (accepted_socket != NULL) {
- ProtocolType proto = server_listen_sockets_[server_socket];
+ const ServerSocketInfo& info = server_listen_sockets_[server_socket];
+ if (info.ssl_adapter_factory) {
+ rtc::SSLAdapter* ssl_adapter =
+ info.ssl_adapter_factory->CreateAdapter(accepted_socket);
+ ssl_adapter->StartSSL("");
+ accepted_socket = ssl_adapter;
+ }
cricket::AsyncStunTCPSocket* tcp_socket =
new cricket::AsyncStunTCPSocket(accepted_socket);
tcp_socket->SignalClose.connect(this, &TurnServer::OnInternalSocketClose);
// Finally add the socket so it can start communicating with the client.
- AddInternalSocket(tcp_socket, proto);
+ AddInternalSocket(tcp_socket, info.proto);
}
}
diff --git a/p2p/base/turn_server.h b/p2p/base/turn_server.h
index 7942c09..481b081 100644
--- a/p2p/base/turn_server.h
+++ b/p2p/base/turn_server.h
@@ -23,6 +23,7 @@
#include "p2p/base/port_interface.h"
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/socket_address.h"
+#include "rtc_base/ssl_adapter.h"
#include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h"
@@ -237,7 +238,10 @@
// Starts listening for the connections on this socket. When someone tries
// to connect, the connection will be accepted and a new internal socket
// will be added.
- void AddInternalServerSocket(rtc::Socket* socket, ProtocolType proto);
+ void AddInternalServerSocket(
+ rtc::Socket* socket,
+ ProtocolType proto,
+ std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory = nullptr);
// Specifies the factory to use for creating external sockets.
void SetExternalSocketFactory(rtc::PacketSocketFactory* factory,
const rtc::SocketAddress& address);
@@ -320,7 +324,12 @@
RTC_RUN_ON(thread_);
typedef std::map<rtc::AsyncPacketSocket*, ProtocolType> InternalSocketMap;
- typedef std::map<rtc::Socket*, ProtocolType> ServerSocketMap;
+ struct ServerSocketInfo {
+ ProtocolType proto;
+ // If non-null, used to wrap accepted sockets.
+ std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory;
+ };
+ typedef std::map<rtc::Socket*, ServerSocketInfo> ServerSocketMap;
rtc::Thread* const thread_;
const std::string nonce_key_;
diff --git a/rtc_base/openssl_adapter.cc b/rtc_base/openssl_adapter.cc
index 7489bc9..bc10e61 100644
--- a/rtc_base/openssl_adapter.cc
+++ b/rtc_base/openssl_adapter.cc
@@ -250,21 +250,6 @@
role_ = role;
}
-Socket* OpenSSLAdapter::Accept(SocketAddress* paddr) {
- RTC_DCHECK(role_ == SSL_SERVER);
- Socket* socket = SSLAdapter::Accept(paddr);
- if (!socket) {
- return nullptr;
- }
-
- SSLAdapter* adapter = SSLAdapter::Create(socket);
- adapter->SetIdentity(identity_->Clone());
- adapter->SetRole(rtc::SSL_SERVER);
- adapter->SetIgnoreBadCert(ignore_bad_cert_);
- adapter->StartSSL("");
- return adapter;
-}
-
int OpenSSLAdapter::StartSSL(const char* hostname) {
if (state_ != SSL_NONE)
return -1;
@@ -1038,6 +1023,21 @@
ssl_cert_verifier_ = ssl_cert_verifier;
}
+void OpenSSLAdapterFactory::SetIdentity(std::unique_ptr<SSLIdentity> identity) {
+ RTC_DCHECK(!ssl_session_cache_);
+ identity_ = std::move(identity);
+}
+
+void OpenSSLAdapterFactory::SetRole(SSLRole role) {
+ RTC_DCHECK(!ssl_session_cache_);
+ ssl_role_ = role;
+}
+
+void OpenSSLAdapterFactory::SetIgnoreBadCert(bool ignore) {
+ RTC_DCHECK(!ssl_session_cache_);
+ ignore_bad_cert_ = ignore;
+}
+
OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(Socket* socket) {
if (ssl_session_cache_ == nullptr) {
SSL_CTX* ssl_ctx = OpenSSLAdapter::CreateContext(ssl_mode_, true);
@@ -1049,8 +1049,14 @@
std::make_unique<OpenSSLSessionCache>(ssl_mode_, ssl_ctx);
SSL_CTX_free(ssl_ctx);
}
- return new OpenSSLAdapter(socket, ssl_session_cache_.get(),
- ssl_cert_verifier_);
+ OpenSSLAdapter* ssl_adapter =
+ new OpenSSLAdapter(socket, ssl_session_cache_.get(), ssl_cert_verifier_);
+ ssl_adapter->SetRole(ssl_role_);
+ ssl_adapter->SetIgnoreBadCert(ignore_bad_cert_);
+ if (identity_) {
+ ssl_adapter->SetIdentity(identity_->Clone());
+ }
+ return ssl_adapter;
}
OpenSSLAdapter::EarlyExitCatcher::EarlyExitCatcher(OpenSSLAdapter& adapter_ptr)
diff --git a/rtc_base/openssl_adapter.h b/rtc_base/openssl_adapter.h
index 266ed35..7e1f87b 100644
--- a/rtc_base/openssl_adapter.h
+++ b/rtc_base/openssl_adapter.h
@@ -60,7 +60,6 @@
void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
void SetRole(SSLRole role) override;
- Socket* Accept(SocketAddress* paddr) override;
int StartSSL(const char* hostname) override;
int Send(const void* pv, size_t cb) override;
int SendTo(const void* pv, size_t cb, const SocketAddress& addr) override;
@@ -191,10 +190,21 @@
// the first adapter is created with the factory. If it is called after it
// will DCHECK.
void SetMode(SSLMode mode) override;
+
// Set a custom certificate verifier to be passed down to each instance
// created with this factory. This should only ever be set before the first
// call to the factory and cannot be changed after the fact.
void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
+
+ void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
+
+ // Choose whether the socket acts as a server socket or client socket.
+ void SetRole(SSLRole role) override;
+
+ // Methods that control server certificate verification, used in unit tests.
+ // Do not call these methods in production code.
+ void SetIgnoreBadCert(bool ignore) override;
+
// Constructs a new socket using the shared OpenSSLSessionCache. This means
// existing SSLSessions already in the cache will be reused instead of
// re-created for improved performance.
@@ -203,6 +213,11 @@
private:
// Holds the SSLMode (DTLS,TLS) that will be used to set the session cache.
SSLMode ssl_mode_ = SSL_MODE_TLS;
+ SSLRole ssl_role_ = SSL_CLIENT;
+ bool ignore_bad_cert_ = false;
+
+ std::unique_ptr<SSLIdentity> identity_;
+
// Holds a cache of existing SSL Sessions.
std::unique_ptr<OpenSSLSessionCache> ssl_session_cache_;
// Provides an optional custom callback for verifying SSL certificates, this
diff --git a/rtc_base/ssl_adapter.cc b/rtc_base/ssl_adapter.cc
index c9b54c4..ff936a7 100644
--- a/rtc_base/ssl_adapter.cc
+++ b/rtc_base/ssl_adapter.cc
@@ -16,8 +16,8 @@
namespace rtc {
-SSLAdapterFactory* SSLAdapterFactory::Create() {
- return new OpenSSLAdapterFactory();
+std::unique_ptr<SSLAdapterFactory> SSLAdapterFactory::Create() {
+ return std::make_unique<OpenSSLAdapterFactory>();
}
SSLAdapter* SSLAdapter::Create(Socket* socket) {
diff --git a/rtc_base/ssl_adapter.h b/rtc_base/ssl_adapter.h
index 1f0616b..8f98141 100644
--- a/rtc_base/ssl_adapter.h
+++ b/rtc_base/ssl_adapter.h
@@ -39,10 +39,21 @@
// Specify a custom certificate verifier for SSL.
virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0;
+ // Set the certificate this socket will present to incoming clients.
+ // Takes ownership of `identity`.
+ virtual void SetIdentity(std::unique_ptr<SSLIdentity> identity) = 0;
+
+ // Choose whether the socket acts as a server socket or client socket.
+ virtual void SetRole(SSLRole role) = 0;
+
+ // Methods that control server certificate verification, used in unit tests.
+ // Do not call these methods in production code.
+ virtual void SetIgnoreBadCert(bool ignore) = 0;
+
// Creates a new SSL adapter, but from a shared context.
virtual SSLAdapter* CreateAdapter(Socket* socket) = 0;
- static SSLAdapterFactory* Create();
+ static std::unique_ptr<SSLAdapterFactory> Create();
};
// Class that abstracts a client-to-server SSL session. It can be created
@@ -91,6 +102,11 @@
// and deletes `socket`. Otherwise, the returned SSLAdapter takes ownership
// of `socket`.
static SSLAdapter* Create(Socket* socket);
+
+ private:
+ // Not supported.
+ int Listen(int backlog) override { RTC_CHECK(false); }
+ Socket* Accept(SocketAddress* paddr) override { RTC_CHECK(false); }
};
///////////////////////////////////////////////////////////////////////////////