Implementation of SSL session caching; tests in a separate CL.
This CL adds the ability for an SSLAdapter to resume a previous session, saving a round trip and significantly reducing the number of bytes needed to bring up the new session.
To do this, the sessions need to share state. This is addressed by introducing the SSLAdapterFactory object, which can maintain an SSL_CTX and a session cache shared across multiple sessions.
This CL does not include unit tests, in order to minimize the change size (i.e., to reduce the size of the cherry-pick (CP)). CL https://chromium-review.googlesource.com/c/558612 builds on this CL and adds tests, but makes some nontrivial changes to SSLStreamAdapter in order to get the test server to share an SSL_CTX across sessions.
Bug: 7936
Change-Id: I677b73453d981d5b3a2e66ea9a5be722acd59475
Reviewed-on: https://chromium-review.googlesource.com/575910
Commit-Queue: Justin Uberti <juberti@webrtc.org>
Reviewed-by: Emad Omara <emadomara@webrtc.org>
Reviewed-by: Taylor Brandstetter <deadbeef@webrtc.org>
Reviewed-by: Peter Thatcher <pthatcher@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#19342}
diff --git a/webrtc/rtc_base/openssladapter.cc b/webrtc/rtc_base/openssladapter.cc
index eec8021..11473ac 100644
--- a/webrtc/rtc_base/openssladapter.cc
+++ b/webrtc/rtc_base/openssladapter.cc
@@ -274,8 +274,10 @@
return true;
}
-OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket)
+OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket,
+ OpenSSLAdapterFactory* factory)
: SSLAdapter(socket),
+ factory_(factory),
state_(SSL_NONE),
ssl_read_needs_write_(false),
ssl_write_needs_read_(false),
@@ -283,20 +285,29 @@
ssl_(nullptr),
ssl_ctx_(nullptr),
ssl_mode_(SSL_MODE_TLS),
- custom_verification_succeeded_(false) {}
+ custom_verification_succeeded_(false) {
+ // If a factory is used, take a reference on the factory's SSL_CTX.
+ // Otherwise, we'll create our own later.
+ // Either way, we'll release our reference via SSL_CTX_free() in Cleanup().
+ if (factory_) {
+ ssl_ctx_ = factory_->ssl_ctx();
+ RTC_DCHECK(ssl_ctx_);
+ // Note: if using OpenSSL, requires version 1.1.0 or later.
+ SSL_CTX_up_ref(ssl_ctx_);
+ }
+}
OpenSSLAdapter::~OpenSSLAdapter() {
Cleanup();
}
-void
-OpenSSLAdapter::SetMode(SSLMode mode) {
+// Selects TLS or DTLS for a standalone adapter. Must be called while state_ is
+// still SSL_NONE and before any SSL_CTX is attached. Adapters created by an
+// OpenSSLAdapterFactory already hold the factory's SSL_CTX (taken in the
+// constructor), so their mode must be set via OpenSSLAdapterFactory::SetMode().
+void OpenSSLAdapter::SetMode(SSLMode mode) {
+ RTC_DCHECK(!ssl_ctx_);
RTC_DCHECK(state_ == SSL_NONE);
ssl_mode_ = mode;
}
-int
-OpenSSLAdapter::StartSSL(const char* hostname, bool restartable) {
+int OpenSSLAdapter::StartSSL(const char* hostname, bool restartable) {
if (state_ != SSL_NONE)
return -1;
@@ -317,18 +328,20 @@
return 0;
}
-int
-OpenSSLAdapter::BeginSSL() {
- LOG(LS_INFO) << "BeginSSL: " << ssl_host_name_;
+int OpenSSLAdapter::BeginSSL() {
+ LOG(LS_INFO) << "OpenSSLAdapter::BeginSSL: " << ssl_host_name_;
RTC_DCHECK(state_ == SSL_CONNECTING);
int err = 0;
BIO* bio = nullptr;
- // First set up the context
- if (!ssl_ctx_)
- ssl_ctx_ = SetupSSLContext();
-
+ // First set up the context. We should either have a factory, with its own
+ // pre-existing context, or be running standalone, in which case we will
+ // need to create one, and specify |false| to disable session caching.
+ if (!factory_) {
+ RTC_DCHECK(!ssl_ctx_);
+ ssl_ctx_ = CreateContext(ssl_mode_, false);
+ }
if (!ssl_ctx_) {
err = -1;
goto ssl_error;
@@ -348,7 +361,6 @@
SSL_set_app_data(ssl_, this);
- SSL_set_bio(ssl_, bio, bio);
// SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER allows different buffers to be passed
// into SSL_write when a record could only be partially transmitted (and thus
// requires another call to SSL_write to finish transmission). This allows us
@@ -360,9 +372,24 @@
SSL_set_mode(ssl_, SSL_MODE_ENABLE_PARTIAL_WRITE |
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
- // Enable SNI.
+ // Enable SNI, if a hostname is supplied.
if (!ssl_host_name_.empty()) {
SSL_set_tlsext_host_name(ssl_, ssl_host_name_.c_str());
+
+ // Enable session caching, if configured and a hostname is supplied.
+ if (factory_) {
+ SSL_SESSION* cached = factory_->LookupSession(ssl_host_name_);
+ if (cached) {
+ if (SSL_set_session(ssl_, cached) == 0) {
+ LOG(LS_WARNING) << "Failed to apply SSL session from cache";
+ err = -1;
+ goto ssl_error;
+ }
+
+ LOG(LS_INFO) << "Attempting to resume SSL session to "
+ << ssl_host_name_;
+ }
+ }
}
// Set a couple common TLS extensions; even though we don't use them yet.
@@ -370,10 +397,12 @@
SSL_enable_ocsp_stapling(ssl_);
SSL_enable_signed_cert_timestamps(ssl_);
- // the SSL object owns the bio now
+ // Now that the initial config is done, transfer ownership of |bio| to the
+ // SSL object. If ContinueSSL() fails, the bio will be freed in Cleanup().
+ SSL_set_bio(ssl_, bio, bio);
bio = nullptr;
- // Do the connect
+ // Do the connect.
err = ContinueSSL();
if (err != 0)
goto ssl_error;
@@ -388,8 +417,7 @@
return err;
}
-int
-OpenSSLAdapter::ContinueSSL() {
+int OpenSSLAdapter::ContinueSSL() {
RTC_DCHECK(state_ == SSL_CONNECTING);
// Clear the DTLS timer
@@ -441,8 +469,7 @@
return 0;
}
-void
-OpenSSLAdapter::Error(const char* context, int err, bool signal) {
+void OpenSSLAdapter::Error(const char* context, int err, bool signal) {
LOG(LS_WARNING) << "OpenSSLAdapter::Error("
<< context << ", " << err << ")";
state_ = SSL_ERROR;
@@ -451,9 +478,8 @@
AsyncSocketAdapter::OnCloseEvent(this, err);
}
-void
-OpenSSLAdapter::Cleanup() {
- LOG(LS_INFO) << "Cleanup";
+void OpenSSLAdapter::Cleanup() {
+ LOG(LS_INFO) << "OpenSSLAdapter::Cleanup";
state_ = SSL_NONE;
ssl_read_needs_write_ = false;
@@ -519,8 +545,7 @@
// AsyncSocket Implementation
//
-int
-OpenSSLAdapter::Send(const void* pv, size_t cb) {
+int OpenSSLAdapter::Send(const void* pv, size_t cb) {
//LOG(LS_INFO) << "OpenSSLAdapter::Send(" << cb << ")";
switch (state_) {
@@ -589,8 +614,9 @@
return ret;
}
-int
-OpenSSLAdapter::SendTo(const void* pv, size_t cb, const SocketAddress& addr) {
+int OpenSSLAdapter::SendTo(const void* pv,
+ size_t cb,
+ const SocketAddress& addr) {
if (socket_->GetState() == Socket::CS_CONNECTED &&
addr == socket_->GetRemoteAddress()) {
return Send(pv, cb);
@@ -677,15 +703,13 @@
return SOCKET_ERROR;
}
-int
-OpenSSLAdapter::Close() {
+// Tears down the SSL state via Cleanup(), then closes the wrapped socket.
+// A restartable adapter is left in SSL_WAIT instead of SSL_NONE; presumably
+// this lets OnConnectEvent() restart SSL on the next connect (verify there).
+int OpenSSLAdapter::Close() {
Cleanup();
state_ = restartable_ ? SSL_WAIT : SSL_NONE;
return AsyncSocketAdapter::Close();
}
-Socket::ConnState
-OpenSSLAdapter::GetState() const {
+Socket::ConnState OpenSSLAdapter::GetState() const {
//if (signal_close_)
// return CS_CONNECTED;
ConnState state = socket_->GetState();
@@ -695,8 +719,11 @@
return state;
}
-void
-OpenSSLAdapter::OnMessage(Message* msg) {
+// Reports whether the handshake on the current SSL object resumed a cached
+// session. Returns false when no SSL object exists (SSL not yet started, or
+// already cleaned up).
+bool OpenSSLAdapter::IsResumedSession() {
+  if (ssl_ == nullptr) {
+    return false;
+  }
+  return SSL_session_reused(ssl_) == 1;
+}
+
+void OpenSSLAdapter::OnMessage(Message* msg) {
if (MSG_TIMEOUT == msg->message_id) {
LOG(LS_INFO) << "DTLS timeout expired";
DTLSv1_handle_timeout(ssl_);
@@ -704,8 +731,7 @@
}
}
-void
-OpenSSLAdapter::OnConnectEvent(AsyncSocket* socket) {
+void OpenSSLAdapter::OnConnectEvent(AsyncSocket* socket) {
LOG(LS_INFO) << "OpenSSLAdapter::OnConnectEvent";
if (state_ != SSL_WAIT) {
RTC_DCHECK(state_ == SSL_NONE);
@@ -719,8 +745,7 @@
}
}
-void
-OpenSSLAdapter::OnReadEvent(AsyncSocket* socket) {
+void OpenSSLAdapter::OnReadEvent(AsyncSocket* socket) {
//LOG(LS_INFO) << "OpenSSLAdapter::OnReadEvent";
if (state_ == SSL_NONE) {
@@ -749,8 +774,7 @@
AsyncSocketAdapter::OnReadEvent(socket);
}
-void
-OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) {
+void OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) {
//LOG(LS_INFO) << "OpenSSLAdapter::OnWriteEvent";
if (state_ == SSL_NONE) {
@@ -790,8 +814,7 @@
AsyncSocketAdapter::OnWriteEvent(socket);
}
-void
-OpenSSLAdapter::OnCloseEvent(AsyncSocket* socket, int err) {
+// Logs the close notification and forwards it unchanged to the base
+// AsyncSocketAdapter; no SSL state is modified here.
+void OpenSSLAdapter::OnCloseEvent(AsyncSocket* socket, int err) {
LOG(LS_INFO) << "OpenSSLAdapter::OnCloseEvent(" << err << ")";
AsyncSocketAdapter::OnCloseEvent(socket, err);
}
@@ -891,8 +914,7 @@
// We only use this for tracing and so it is only needed in debug mode
-void
-OpenSSLAdapter::SSLInfoCallback(const SSL* s, int where, int ret) {
+void OpenSSLAdapter::SSLInfoCallback(const SSL* s, int where, int ret) {
const char* str = "undefined";
int w = where & ~SSL_ST_MASK;
if (w & SSL_ST_CONNECT) {
@@ -918,8 +940,7 @@
#endif
-int
-OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) {
+int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) {
#if !defined(NDEBUG)
if (!ok) {
char data[256];
@@ -964,6 +985,15 @@
return ok;
}
+// New-session callback, registered by CreateContext() only when session
+// caching is enabled (i.e. for contexts owned by an OpenSSLAdapterFactory).
+// Hands |session| to the factory's cache, keyed by the adapter's hostname.
+//
+// Returns 1 when the factory has taken ownership of |session|, or 0 to let the
+// SSL library release it. Sessions for adapters without a hostname are not
+// cached: BeginSSL() only looks up the cache when a hostname (SNI) is set, so
+// such entries could never be resumed and would only pin memory.
+int OpenSSLAdapter::NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session) {
+  // SSL_get_app_data() returns the void* installed by SSL_set_app_data() in
+  // BeginSSL(); static_cast is the correct cast back from void*.
+  OpenSSLAdapter* stream =
+      static_cast<OpenSSLAdapter*>(SSL_get_app_data(ssl));
+  RTC_DCHECK(stream);
+  RTC_DCHECK(stream->factory_);
+  if (!stream || !stream->factory_ || stream->ssl_host_name_.empty()) {
+    return 0;  // Not cached; the SSL library releases |session|.
+  }
+  LOG(LS_INFO) << "Caching SSL session for " << stream->ssl_host_name_;
+  stream->factory_->AddSession(stream->ssl_host_name_, session);
+  return 1;  // We've taken ownership of the session; OpenSSL shouldn't free it.
+}
+
bool OpenSSLAdapter::ConfigureTrustedRootCertificates(SSL_CTX* ctx) {
// Add the root cert that we care about to the SSL context
int count_of_added_certs = 0;
@@ -985,18 +1015,17 @@
return count_of_added_certs > 0;
}
-SSL_CTX*
-OpenSSLAdapter::SetupSSLContext() {
+SSL_CTX* OpenSSLAdapter::CreateContext(SSLMode mode, bool enable_cache) {
// Use (D)TLS 1.2.
// Note: BoringSSL supports a range of versions by setting max/min version
// (Default V1.0 to V1.2). However (D)TLSv1_2_client_method functions used
// below in OpenSSL only support V1.2.
SSL_CTX* ctx = nullptr;
#ifdef OPENSSL_IS_BORINGSSL
- ctx = SSL_CTX_new(ssl_mode_ == SSL_MODE_DTLS ? DTLS_method() : TLS_method());
+ ctx = SSL_CTX_new(mode == SSL_MODE_DTLS ? DTLS_method() : TLS_method());
#else
- ctx = SSL_CTX_new(ssl_mode_ == SSL_MODE_DTLS ? DTLSv1_2_client_method()
- : TLSv1_2_client_method());
+ ctx = SSL_CTX_new(mode == SSL_MODE_DTLS ? DTLSv1_2_client_method()
+ : TLSv1_2_client_method());
#endif // OPENSSL_IS_BORINGSSL
if (ctx == nullptr) {
unsigned long error = ERR_get_error(); // NOLINT: type used by OpenSSL.
@@ -1023,11 +1052,59 @@
SSL_CTX_set_cipher_list(
ctx, "ALL:!SHA256:!SHA384:!aPSK:!ECDSA+SHA1:!ADH:!LOW:!EXP:!MD5");
- if (ssl_mode_ == SSL_MODE_DTLS) {
+ if (mode == SSL_MODE_DTLS) {
SSL_CTX_set_read_ahead(ctx, 1);
}
+ if (enable_cache) {
+ SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_CLIENT);
+ SSL_CTX_sess_set_new_cb(ctx, &OpenSSLAdapter::NewSSLSessionCallback);
+ }
+
return ctx;
}
+//////////////////////////////////////////////////////////////////////
+// OpenSSLAdapterFactory
+//////////////////////////////////////////////////////////////////////
+
+// Defaults to TLS. The shared SSL_CTX is not built here; the first successful
+// CreateAdapter() call creates it, so SetMode() may still be called until then.
+OpenSSLAdapterFactory::OpenSSLAdapterFactory()
+    : ssl_mode_(SSL_MODE_TLS),
+      ssl_ctx_(nullptr) {}
+
+// Releases the factory's reference on every cached session and on the shared
+// SSL_CTX. Adapters created by this factory took their own reference on the
+// context (SSL_CTX_up_ref in their constructor) and release it in Cleanup().
+OpenSSLAdapterFactory::~OpenSSLAdapterFactory() {
+  // Bind by const reference: iterating by value would copy each key/value
+  // pair (including the hostname string) just to read the session pointer.
+  for (const auto& entry : sessions_) {
+    SSL_SESSION_free(entry.second);
+  }
+  SSL_CTX_free(ssl_ctx_);  // Safe when null (no adapter was ever created).
+}
+
+// Chooses TLS or DTLS for every adapter this factory creates. Must be called
+// before the shared SSL_CTX exists (i.e. before the first successful
+// CreateAdapter()), because the context is built with the mode in effect then.
+void OpenSSLAdapterFactory::SetMode(SSLMode mode) {
+  RTC_DCHECK(ssl_ctx_ == nullptr);
+  ssl_mode_ = mode;
+}
+
+// Returns a new adapter wrapping |socket| that shares this factory's SSL_CTX
+// and session cache. The context (with session caching enabled) is created on
+// first use; returns nullptr if it cannot be created. The caller owns the
+// returned adapter. The factory must outlive it, since the adapter keeps a
+// pointer back to the factory for session lookup/caching.
+OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) {
+  if (ssl_ctx_ == nullptr) {
+    ssl_ctx_ = OpenSSLAdapter::CreateContext(ssl_mode_, /*enable_cache=*/true);
+  }
+  if (ssl_ctx_ == nullptr) {
+    return nullptr;
+  }
+  return new OpenSSLAdapter(socket, this);
+}
+
+// Returns the session cached for |hostname|, or nullptr if there is none.
+// The factory keeps ownership; callers must not free the returned session.
+SSL_SESSION* OpenSSLAdapterFactory::LookupSession(const std::string& hostname) {
+  const auto found = sessions_.find(hostname);
+  if (found == sessions_.end()) {
+    return nullptr;
+  }
+  return found->second;
+}
+
+// Caches |new_session| for |hostname|, taking ownership of the reference
+// handed over by NewSSLSessionCallback(). Any session previously cached for
+// that hostname has its reference released.
+void OpenSSLAdapterFactory::AddSession(const std::string& hostname,
+                                       SSL_SESSION* new_session) {
+  auto existing = sessions_.find(hostname);
+  if (existing != sessions_.end()) {
+    SSL_SESSION_free(existing->second);
+    existing->second = new_session;
+    return;
+  }
+  sessions_[hostname] = new_session;
+}
+
} // namespace rtc