| /* | 
 |  *  Copyright 2014 The WebRTC Project Authors. All rights reserved. | 
 |  * | 
 |  *  Use of this source code is governed by a BSD-style license | 
 |  *  that can be found in the LICENSE file in the root of the source | 
 |  *  tree. An additional intellectual property rights grant can be found | 
 |  *  in the file PATENTS.  All contributing project authors may | 
 |  *  be found in the AUTHORS file in the root of the source tree. | 
 |  */ | 
 |  | 
 | #include "rtc_base/ssl_adapter.h" | 
 |  | 
 | #include <memory> | 
 | #include <string> | 
 | #include <utility> | 
 | #include <vector> | 
 |  | 
 | #include "absl/strings/str_cat.h" | 
 | #include "absl/strings/string_view.h" | 
 | #include "api/test/rtc_error_matchers.h" | 
 | #include "api/units/time_delta.h" | 
 | #include "rtc_base/ip_address.h" | 
 | #include "rtc_base/logging.h" | 
 | #include "rtc_base/socket.h" | 
 | #include "rtc_base/socket_address.h" | 
 | #include "rtc_base/ssl_certificate.h" | 
 | #include "rtc_base/ssl_identity.h" | 
 | #include "rtc_base/ssl_stream_adapter.h" | 
 | #include "rtc_base/third_party/sigslot/sigslot.h" | 
 | #include "rtc_base/thread.h" | 
 | #include "rtc_base/virtual_socket_server.h" | 
 | #include "test/gmock.h" | 
 | #include "test/gtest.h" | 
 | #include "test/wait_until.h" | 
 |  | 
 | using ::testing::_; | 
 | using ::testing::Return; | 
 |  | 
 | static const webrtc::TimeDelta kTimeout = webrtc::TimeDelta::Millis(5000); | 
 |  | 
 | static webrtc::Socket* CreateSocket() { | 
 |   webrtc::SocketAddress address(webrtc::IPAddress(INADDR_ANY), 0); | 
 |  | 
 |   webrtc::Socket* socket = | 
 |       webrtc::Thread::Current()->socketserver()->CreateSocket(address.family(), | 
 |                                                               SOCK_STREAM); | 
 |   socket->Bind(address); | 
 |  | 
 |   return socket; | 
 | } | 
 |  | 
 | // Simple mock for the certificate verifier. | 
 | class MockCertVerifier : public webrtc::SSLCertificateVerifier { | 
 |  public: | 
 |   ~MockCertVerifier() override = default; | 
 |   MOCK_METHOD(bool, Verify, (const webrtc::SSLCertificate&), (override)); | 
 | }; | 
 |  | 
 | // TODO(benwright) - Move to using INSTANTIATE_TEST_SUITE_P instead of using | 
 | // duplicate test cases for simple parameter changes. | 
 | class SSLAdapterTestDummy : public sigslot::has_slots<> { | 
 |  public: | 
 |   explicit SSLAdapterTestDummy() : socket_(CreateSocket()) {} | 
 |   ~SSLAdapterTestDummy() override = default; | 
 |  | 
 |   void CreateSSLAdapter(webrtc::Socket* socket, webrtc::SSLRole role) { | 
 |     ssl_adapter_.reset(webrtc::SSLAdapter::Create(socket)); | 
 |  | 
 |     // Ignore any certificate errors for the purpose of testing. | 
 |     // Note: We do this only because we don't have a real certificate. | 
 |     // NEVER USE THIS IN PRODUCTION CODE! | 
 |     ssl_adapter_->SetIgnoreBadCert(true); | 
 |  | 
 |     ssl_adapter_->SignalReadEvent.connect( | 
 |         this, &SSLAdapterTestDummy::OnSSLAdapterReadEvent); | 
 |     ssl_adapter_->SignalCloseEvent.connect( | 
 |         this, &SSLAdapterTestDummy::OnSSLAdapterCloseEvent); | 
 |     ssl_adapter_->SetRole(role); | 
 |   } | 
 |  | 
 |   void SetIgnoreBadCert(bool ignore_bad_cert) { | 
 |     ssl_adapter_->SetIgnoreBadCert(ignore_bad_cert); | 
 |   } | 
 |  | 
 |   void SetCertVerifier(webrtc::SSLCertificateVerifier* ssl_cert_verifier) { | 
 |     ssl_adapter_->SetCertVerifier(ssl_cert_verifier); | 
 |   } | 
 |  | 
 |   void SetAlpnProtocols(const std::vector<std::string>& protos) { | 
 |     ssl_adapter_->SetAlpnProtocols(protos); | 
 |   } | 
 |  | 
 |   void SetEllipticCurves(const std::vector<std::string>& curves) { | 
 |     ssl_adapter_->SetEllipticCurves(curves); | 
 |   } | 
 |  | 
 |   webrtc::SocketAddress GetAddress() const { | 
 |     return ssl_adapter_->GetLocalAddress(); | 
 |   } | 
 |  | 
 |   webrtc::Socket::ConnState GetState() const { | 
 |     return ssl_adapter_->GetState(); | 
 |   } | 
 |  | 
 |   const std::string& GetReceivedData() const { return data_; } | 
 |  | 
 |   int Close() { return ssl_adapter_->Close(); } | 
 |  | 
 |   int Send(absl::string_view message) { | 
 |     RTC_LOG(LS_INFO) << "Sending '" << message << "'"; | 
 |  | 
 |     return ssl_adapter_->Send(message.data(), message.length()); | 
 |   } | 
 |  | 
 |   void OnSSLAdapterReadEvent(webrtc::Socket* socket) { | 
 |     char buffer[4096] = ""; | 
 |  | 
 |     // Read data received from the server and store it in our internal buffer. | 
 |     int read = socket->Recv(buffer, sizeof(buffer) - 1, nullptr); | 
 |     if (read != -1) { | 
 |       buffer[read] = '\0'; | 
 |  | 
 |       RTC_LOG(LS_INFO) << "Received '" << buffer << "'"; | 
 |  | 
 |       data_ += buffer; | 
 |     } | 
 |   } | 
 |  | 
 |   void OnSSLAdapterCloseEvent(webrtc::Socket* socket, int error) { | 
 |     // OpenSSLAdapter signals handshake failure with a close event, but without | 
 |     // closing the socket! Let's close the socket here. This way GetState() can | 
 |     // return CS_CLOSED after failure. | 
 |     if (socket->GetState() != webrtc::Socket::CS_CLOSED) { | 
 |       socket->Close(); | 
 |     } | 
 |   } | 
 |  | 
 |  protected: | 
 |   std::unique_ptr<webrtc::SSLAdapter> ssl_adapter_; | 
 |   std::unique_ptr<webrtc::Socket> socket_; | 
 |  | 
 |  private: | 
 |   std::string data_; | 
 | }; | 
 |  | 
 | class SSLAdapterTestDummyClient : public SSLAdapterTestDummy { | 
 |  public: | 
 |   explicit SSLAdapterTestDummyClient() : SSLAdapterTestDummy() { | 
 |     CreateSSLAdapter(socket_.release(), webrtc::SSL_CLIENT); | 
 |   } | 
 |  | 
 |   int Connect(absl::string_view hostname, | 
 |               const webrtc::SocketAddress& address) { | 
 |     RTC_LOG(LS_INFO) << "Initiating connection with " << address.ToString(); | 
 |     int rv = ssl_adapter_->Connect(address); | 
 |  | 
 |     if (rv == 0) { | 
 |       RTC_LOG(LS_INFO) << "Starting TLS handshake with " << hostname; | 
 |  | 
 |       if (ssl_adapter_->StartSSL(hostname) != 0) { | 
 |         return -1; | 
 |       } | 
 |     } | 
 |  | 
 |     return rv; | 
 |   } | 
 | }; | 
 |  | 
 | class SSLAdapterTestDummyServer : public SSLAdapterTestDummy { | 
 |  public: | 
 |   explicit SSLAdapterTestDummyServer(const webrtc::KeyParams& key_params) | 
 |       : SSLAdapterTestDummy(), | 
 |         ssl_identity_(webrtc::SSLIdentity::Create(GetHostname(), key_params)) { | 
 |     socket_->Listen(1); | 
 |     socket_->SignalReadEvent.connect(this, | 
 |                                      &SSLAdapterTestDummyServer::OnReadEvent); | 
 |  | 
 |     RTC_LOG(LS_INFO) << "TCP server listening on " | 
 |                      << socket_->GetLocalAddress().ToString(); | 
 |   } | 
 |  | 
 |   webrtc::SocketAddress GetAddress() const { | 
 |     return socket_->GetLocalAddress(); | 
 |   } | 
 |  | 
 |   std::string GetHostname() const { | 
 |     // Since we don't have a real certificate anyway, the value here doesn't | 
 |     // really matter. | 
 |     return "example.com"; | 
 |   } | 
 |  | 
 |  protected: | 
 |   void OnReadEvent(webrtc::Socket* socket) { | 
 |     CreateSSLAdapter(socket_->Accept(nullptr), webrtc::SSL_SERVER); | 
 |     ssl_adapter_->SetIdentity(ssl_identity_->Clone()); | 
 |     if (ssl_adapter_->StartSSL(GetHostname()) != 0) { | 
 |       RTC_LOG(LS_ERROR) << "Starting SSL from server failed."; | 
 |     } | 
 |   } | 
 |  | 
 |  private: | 
 |   std::unique_ptr<webrtc::SSLIdentity> ssl_identity_; | 
 | }; | 
 |  | 
 | class SSLAdapterTestBase : public ::testing::Test, public sigslot::has_slots<> { | 
 |  public: | 
 |   explicit SSLAdapterTestBase(const webrtc::KeyParams& key_params) | 
 |       : vss_(new webrtc::VirtualSocketServer()), | 
 |         thread_(vss_.get()), | 
 |         server_(new SSLAdapterTestDummyServer(key_params)), | 
 |         client_(new SSLAdapterTestDummyClient()), | 
 |         handshake_wait_(webrtc::TimeDelta::Millis(kTimeout.ms())) {} | 
 |  | 
 |   void SetHandshakeWait(int wait) { | 
 |     handshake_wait_ = webrtc::TimeDelta::Millis(wait); | 
 |   } | 
 |  | 
 |   void SetIgnoreBadCert(bool ignore_bad_cert) { | 
 |     client_->SetIgnoreBadCert(ignore_bad_cert); | 
 |   } | 
 |  | 
 |   void SetCertVerifier(webrtc::SSLCertificateVerifier* ssl_cert_verifier) { | 
 |     client_->SetCertVerifier(ssl_cert_verifier); | 
 |   } | 
 |  | 
 |   void SetAlpnProtocols(const std::vector<std::string>& protos) { | 
 |     client_->SetAlpnProtocols(protos); | 
 |   } | 
 |  | 
 |   void SetEllipticCurves(const std::vector<std::string>& curves) { | 
 |     client_->SetEllipticCurves(curves); | 
 |   } | 
 |  | 
 |   void SetMockCertVerifier(bool return_value) { | 
 |     auto mock_verifier = std::make_unique<MockCertVerifier>(); | 
 |     EXPECT_CALL(*mock_verifier, Verify(_)).WillRepeatedly(Return(return_value)); | 
 |     cert_verifier_ = std::unique_ptr<webrtc::SSLCertificateVerifier>( | 
 |         std::move(mock_verifier)); | 
 |  | 
 |     SetIgnoreBadCert(false); | 
 |     SetCertVerifier(cert_verifier_.get()); | 
 |   } | 
 |  | 
 |   void TestHandshake(bool expect_success) { | 
 |     int rv; | 
 |  | 
 |     // The initial state is CS_CLOSED | 
 |     ASSERT_EQ(webrtc::Socket::CS_CLOSED, client_->GetState()); | 
 |  | 
 |     rv = client_->Connect(server_->GetHostname(), server_->GetAddress()); | 
 |     ASSERT_EQ(0, rv); | 
 |  | 
 |     // Now the state should be CS_CONNECTING | 
 |     ASSERT_EQ(webrtc::Socket::CS_CONNECTING, client_->GetState()); | 
 |  | 
 |     if (expect_success) { | 
 |       // If expecting success, the client should end up in the CS_CONNECTED | 
 |       // state after handshake. | 
 |       EXPECT_THAT(webrtc::WaitUntil([&] { return client_->GetState(); }, | 
 |                                     ::testing::Eq(webrtc::Socket::CS_CONNECTED), | 
 |                                     {.timeout = handshake_wait_}), | 
 |                   webrtc::IsRtcOk()); | 
 |  | 
 |       RTC_LOG(LS_INFO) << "TLS handshake complete."; | 
 |  | 
 |     } else { | 
 |       // On handshake failure the client should end up in the CS_CLOSED state. | 
 |       EXPECT_THAT(webrtc::WaitUntil([&] { return client_->GetState(); }, | 
 |                                     ::testing::Eq(webrtc::Socket::CS_CLOSED), | 
 |                                     {.timeout = handshake_wait_}), | 
 |                   webrtc::IsRtcOk()); | 
 |  | 
 |       RTC_LOG(LS_INFO) << "TLS handshake failed."; | 
 |     } | 
 |   } | 
 |  | 
 |   void TestTransfer(absl::string_view message) { | 
 |     int rv; | 
 |  | 
 |     rv = client_->Send(message); | 
 |     ASSERT_EQ(static_cast<int>(message.length()), rv); | 
 |  | 
 |     // The server should have received the client's message. | 
 |     EXPECT_THAT( | 
 |         webrtc::WaitUntil([&] { return server_->GetReceivedData(); }, | 
 |                           ::testing::Eq(message), {.timeout = kTimeout}), | 
 |         webrtc::IsRtcOk()); | 
 |  | 
 |     rv = server_->Send(message); | 
 |     ASSERT_EQ(static_cast<int>(message.length()), rv); | 
 |  | 
 |     // The client should have received the server's message. | 
 |     EXPECT_THAT( | 
 |         webrtc::WaitUntil([&] { return client_->GetReceivedData(); }, | 
 |                           ::testing::Eq(message), {.timeout = kTimeout}), | 
 |         webrtc::IsRtcOk()); | 
 |  | 
 |     RTC_LOG(LS_INFO) << "Transfer complete."; | 
 |   } | 
 |  | 
 |  protected: | 
 |   std::unique_ptr<webrtc::VirtualSocketServer> vss_; | 
 |   webrtc::AutoSocketServerThread thread_; | 
 |   std::unique_ptr<SSLAdapterTestDummyServer> server_; | 
 |   std::unique_ptr<SSLAdapterTestDummyClient> client_; | 
 |   std::unique_ptr<webrtc::SSLCertificateVerifier> cert_verifier_; | 
 |  | 
 |   webrtc::TimeDelta handshake_wait_; | 
 | }; | 
 |  | 
 | class SSLAdapterTestTLS_RSA : public SSLAdapterTestBase { | 
 |  public: | 
 |   SSLAdapterTestTLS_RSA() : SSLAdapterTestBase(webrtc::KeyParams::RSA()) {} | 
 | }; | 
 |  | 
 | class SSLAdapterTestTLS_ECDSA : public SSLAdapterTestBase { | 
 |  public: | 
 |   SSLAdapterTestTLS_ECDSA() : SSLAdapterTestBase(webrtc::KeyParams::ECDSA()) {} | 
 | }; | 
 |  | 
 | // Test that handshake works, using RSA | 
 | TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) { | 
 |   TestHandshake(true); | 
 | } | 
 |  | 
 | // Test that handshake works with a custom verifier that returns true. RSA. | 
 | TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnectCustomCertVerifierSucceeds) { | 
 |   SetMockCertVerifier(/*return_value=*/true); | 
 |   TestHandshake(/*expect_success=*/true); | 
 | } | 
 |  | 
 | // Test that handshake fails with a custom verifier that returns false. RSA. | 
 | TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnectCustomCertVerifierFails) { | 
 |   SetMockCertVerifier(/*return_value=*/false); | 
 |   TestHandshake(/*expect_success=*/false); | 
 | } | 
 |  | 
 | // Test that handshake works, using ECDSA | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnect) { | 
 |   SetMockCertVerifier(/*return_value=*/true); | 
 |   TestHandshake(/*expect_success=*/true); | 
 | } | 
 |  | 
 | // Test that handshake works with a custom verifier that returns true. ECDSA. | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnectCustomCertVerifierSucceeds) { | 
 |   SetMockCertVerifier(/*return_value=*/true); | 
 |   TestHandshake(/*expect_success=*/true); | 
 | } | 
 |  | 
 | // Test that handshake fails with a custom verifier that returns false. ECDSA. | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnectCustomCertVerifierFails) { | 
 |   SetMockCertVerifier(/*return_value=*/false); | 
 |   TestHandshake(/*expect_success=*/false); | 
 | } | 
 |  | 
 | // Test transfer between client and server, using RSA | 
 | TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransfer) { | 
 |   TestHandshake(true); | 
 |   TestTransfer("Hello, world!"); | 
 | } | 
 |  | 
 | // Test transfer between client and server, using RSA with custom cert verifier. | 
 | TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransferCustomCertVerifier) { | 
 |   SetMockCertVerifier(/*return_value=*/true); | 
 |   TestHandshake(/*expect_success=*/true); | 
 |   TestTransfer("Hello, world!"); | 
 | } | 
 |  | 
 | TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransferWithBlockedSocket) { | 
 |   TestHandshake(true); | 
 |  | 
 |   // Tell the underlying socket to simulate being blocked. | 
 |   vss_->SetSendingBlocked(true); | 
 |  | 
 |   std::string expected; | 
 |   int rv; | 
 |   // Send messages until the SSL socket adapter starts applying backpressure. | 
 |   // Note that this may not occur immediately since there may be some amount of | 
 |   // intermediate buffering (either in our code or in BoringSSL). | 
 |   for (int i = 0; i < 1024; ++i) { | 
 |     std::string message = "Hello, world: " + absl::StrCat(i); | 
 |     rv = client_->Send(message); | 
 |     if (rv != static_cast<int>(message.size())) { | 
 |       // This test assumes either the whole message or none of it is sent. | 
 |       ASSERT_EQ(-1, rv); | 
 |       break; | 
 |     } | 
 |     expected += message; | 
 |   } | 
 |   // Assert that the loop above exited due to Send returning -1. | 
 |   ASSERT_EQ(-1, rv); | 
 |  | 
 |   // Try sending another message while blocked. -1 should be returned again and | 
 |   // it shouldn't end up received by the server later. | 
 |   EXPECT_EQ(-1, client_->Send("Never sent")); | 
 |  | 
 |   // Unblock the underlying socket. All of the buffered messages should be sent | 
 |   // without any further action. | 
 |   vss_->SetSendingBlocked(false); | 
 |   EXPECT_THAT(webrtc::WaitUntil([&] { return server_->GetReceivedData(); }, | 
 |                                 ::testing::Eq(expected), {.timeout = kTimeout}), | 
 |               webrtc::IsRtcOk()); | 
 |  | 
 |   // Send another message. This previously wasn't working | 
 |   std::string final_message = "Fin."; | 
 |   expected += final_message; | 
 |   EXPECT_EQ(static_cast<int>(final_message.size()), | 
 |             client_->Send(final_message)); | 
 |   EXPECT_THAT(webrtc::WaitUntil([&] { return server_->GetReceivedData(); }, | 
 |                                 ::testing::Eq(expected), {.timeout = kTimeout}), | 
 |               webrtc::IsRtcOk()); | 
 | } | 
 |  | 
 | // Test transfer between client and server, using ECDSA | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) { | 
 |   TestHandshake(true); | 
 |   TestTransfer("Hello, world!"); | 
 | } | 
 |  | 
 | // Test transfer between client and server, using ECDSA with custom cert | 
 | // verifier. | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransferCustomCertVerifier) { | 
 |   SetMockCertVerifier(/*return_value=*/true); | 
 |   TestHandshake(/*expect_success=*/true); | 
 |   TestTransfer("Hello, world!"); | 
 | } | 
 |  | 
 | // Test transfer using ALPN with protos as h2 and http/1.1 | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSALPN) { | 
 |   std::vector<std::string> alpn_protos{"h2", "http/1.1"}; | 
 |   SetAlpnProtocols(alpn_protos); | 
 |   TestHandshake(true); | 
 |   TestTransfer("Hello, world!"); | 
 | } | 
 |  | 
 | // Test transfer with TLS Elliptic curves set to "X25519:P-256:P-384:P-521" | 
 | TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSEllipticCurves) { | 
 |   std::vector<std::string> elliptic_curves{"X25519", "P-256", "P-384", "P-521"}; | 
 |   SetEllipticCurves(elliptic_curves); | 
 |   TestHandshake(true); | 
 |   TestTransfer("Hello, world!"); | 
 | } |