Disable SSLAdapter methods Listen and Accept

Only affects turn server. Refactored to wrap sockets with SSLAdapter
after Accept, using the SSLAdapterFactory to hold needed configuration.

Bug: webrtc:13065
Change-Id: I5df65aad5728d8d40d95b22db6398a573ec7a36f
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/235823
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Niels Moller <nisse@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35258}
diff --git a/p2p/base/test_turn_server.h b/p2p/base/test_turn_server.h
index e1deb59..6cad135 100644
--- a/p2p/base/test_turn_server.h
+++ b/p2p/base/test_turn_server.h
@@ -11,7 +11,9 @@
 #ifndef P2P_BASE_TEST_TURN_SERVER_H_
 #define P2P_BASE_TEST_TURN_SERVER_H_
 
+#include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "api/sequence_checker.h"
@@ -104,21 +106,24 @@
       // new connections.
       rtc::Socket* socket =
           thread_->socketserver()->CreateSocket(AF_INET, SOCK_STREAM);
+      socket->Bind(int_addr);
+      socket->Listen(5);
       if (proto == cricket::PROTO_TLS) {
         // For TLS, wrap the TCP socket with an SSL adapter. The adapter must
         // be configured with a self-signed certificate for testing.
         // Additionally, the client will not present a valid certificate, so we
         // must not fail when checking the peer's identity.
-        rtc::SSLAdapter* adapter = rtc::SSLAdapter::Create(socket);
-        adapter->SetRole(rtc::SSL_SERVER);
-        adapter->SetIdentity(
+        std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory =
+            rtc::SSLAdapterFactory::Create();
+        ssl_adapter_factory->SetRole(rtc::SSL_SERVER);
+        ssl_adapter_factory->SetIdentity(
             rtc::SSLIdentity::Create(common_name, rtc::KeyParams()));
-        adapter->SetIgnoreBadCert(ignore_bad_cert);
-        socket = adapter;
+        ssl_adapter_factory->SetIgnoreBadCert(ignore_bad_cert);
+        server_.AddInternalServerSocket(socket, proto,
+                                        std::move(ssl_adapter_factory));
+      } else {
+        server_.AddInternalServerSocket(socket, proto);
       }
-      socket->Bind(int_addr);
-      socket->Listen(5);
-      server_.AddInternalServerSocket(socket, proto);
     } else {
       RTC_NOTREACHED() << "Unknown protocol type: " << proto;
     }
diff --git a/p2p/base/turn_server.cc b/p2p/base/turn_server.cc
index fd9cd16..5685e20 100644
--- a/p2p/base/turn_server.cc
+++ b/p2p/base/turn_server.cc
@@ -152,12 +152,15 @@
   socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket);
 }
 
-void TurnServer::AddInternalServerSocket(rtc::Socket* socket,
-                                         ProtocolType proto) {
+void TurnServer::AddInternalServerSocket(
+    rtc::Socket* socket,
+    ProtocolType proto,
+    std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory) {
   RTC_DCHECK_RUN_ON(thread_);
+
   RTC_DCHECK(server_listen_sockets_.end() ==
              server_listen_sockets_.find(socket));
-  server_listen_sockets_[socket] = proto;
+  server_listen_sockets_[socket] = {proto, std::move(ssl_adapter_factory)};
   socket->SignalReadEvent.connect(this, &TurnServer::OnNewInternalConnection);
 }
 
@@ -181,13 +184,19 @@
   rtc::SocketAddress accept_addr;
   rtc::Socket* accepted_socket = server_socket->Accept(&accept_addr);
   if (accepted_socket != NULL) {
-    ProtocolType proto = server_listen_sockets_[server_socket];
+    const ServerSocketInfo& info = server_listen_sockets_[server_socket];
+    if (info.ssl_adapter_factory) {
+      rtc::SSLAdapter* ssl_adapter =
+          info.ssl_adapter_factory->CreateAdapter(accepted_socket);
+      ssl_adapter->StartSSL("");
+      accepted_socket = ssl_adapter;
+    }
     cricket::AsyncStunTCPSocket* tcp_socket =
         new cricket::AsyncStunTCPSocket(accepted_socket);
 
     tcp_socket->SignalClose.connect(this, &TurnServer::OnInternalSocketClose);
     // Finally add the socket so it can start communicating with the client.
-    AddInternalSocket(tcp_socket, proto);
+    AddInternalSocket(tcp_socket, info.proto);
   }
 }
 
diff --git a/p2p/base/turn_server.h b/p2p/base/turn_server.h
index 7942c09..481b081 100644
--- a/p2p/base/turn_server.h
+++ b/p2p/base/turn_server.h
@@ -23,6 +23,7 @@
 #include "p2p/base/port_interface.h"
 #include "rtc_base/async_packet_socket.h"
 #include "rtc_base/socket_address.h"
+#include "rtc_base/ssl_adapter.h"
 #include "rtc_base/third_party/sigslot/sigslot.h"
 #include "rtc_base/thread.h"
 
@@ -237,7 +238,10 @@
   // Starts listening for the connections on this socket. When someone tries
   // to connect, the connection will be accepted and a new internal socket
   // will be added.
-  void AddInternalServerSocket(rtc::Socket* socket, ProtocolType proto);
+  void AddInternalServerSocket(
+      rtc::Socket* socket,
+      ProtocolType proto,
+      std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory = nullptr);
   // Specifies the factory to use for creating external sockets.
   void SetExternalSocketFactory(rtc::PacketSocketFactory* factory,
                                 const rtc::SocketAddress& address);
@@ -320,7 +324,12 @@
       RTC_RUN_ON(thread_);
 
   typedef std::map<rtc::AsyncPacketSocket*, ProtocolType> InternalSocketMap;
-  typedef std::map<rtc::Socket*, ProtocolType> ServerSocketMap;
+  struct ServerSocketInfo {
+    ProtocolType proto;
+    // If non-null, used to wrap accepted sockets.
+    std::unique_ptr<rtc::SSLAdapterFactory> ssl_adapter_factory;
+  };
+  typedef std::map<rtc::Socket*, ServerSocketInfo> ServerSocketMap;
 
   rtc::Thread* const thread_;
   const std::string nonce_key_;
diff --git a/rtc_base/openssl_adapter.cc b/rtc_base/openssl_adapter.cc
index 7489bc9..bc10e61 100644
--- a/rtc_base/openssl_adapter.cc
+++ b/rtc_base/openssl_adapter.cc
@@ -250,21 +250,6 @@
   role_ = role;
 }
 
-Socket* OpenSSLAdapter::Accept(SocketAddress* paddr) {
-  RTC_DCHECK(role_ == SSL_SERVER);
-  Socket* socket = SSLAdapter::Accept(paddr);
-  if (!socket) {
-    return nullptr;
-  }
-
-  SSLAdapter* adapter = SSLAdapter::Create(socket);
-  adapter->SetIdentity(identity_->Clone());
-  adapter->SetRole(rtc::SSL_SERVER);
-  adapter->SetIgnoreBadCert(ignore_bad_cert_);
-  adapter->StartSSL("");
-  return adapter;
-}
-
 int OpenSSLAdapter::StartSSL(const char* hostname) {
   if (state_ != SSL_NONE)
     return -1;
@@ -1038,6 +1023,21 @@
   ssl_cert_verifier_ = ssl_cert_verifier;
 }
 
+void OpenSSLAdapterFactory::SetIdentity(std::unique_ptr<SSLIdentity> identity) {
+  RTC_DCHECK(!ssl_session_cache_);
+  identity_ = std::move(identity);
+}
+
+void OpenSSLAdapterFactory::SetRole(SSLRole role) {
+  RTC_DCHECK(!ssl_session_cache_);
+  ssl_role_ = role;
+}
+
+void OpenSSLAdapterFactory::SetIgnoreBadCert(bool ignore) {
+  RTC_DCHECK(!ssl_session_cache_);
+  ignore_bad_cert_ = ignore;
+}
+
 OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(Socket* socket) {
   if (ssl_session_cache_ == nullptr) {
     SSL_CTX* ssl_ctx = OpenSSLAdapter::CreateContext(ssl_mode_, true);
@@ -1049,8 +1049,14 @@
         std::make_unique<OpenSSLSessionCache>(ssl_mode_, ssl_ctx);
     SSL_CTX_free(ssl_ctx);
   }
-  return new OpenSSLAdapter(socket, ssl_session_cache_.get(),
-                            ssl_cert_verifier_);
+  OpenSSLAdapter* ssl_adapter =
+      new OpenSSLAdapter(socket, ssl_session_cache_.get(), ssl_cert_verifier_);
+  ssl_adapter->SetRole(ssl_role_);
+  ssl_adapter->SetIgnoreBadCert(ignore_bad_cert_);
+  if (identity_) {
+    ssl_adapter->SetIdentity(identity_->Clone());
+  }
+  return ssl_adapter;
 }
 
 OpenSSLAdapter::EarlyExitCatcher::EarlyExitCatcher(OpenSSLAdapter& adapter_ptr)
diff --git a/rtc_base/openssl_adapter.h b/rtc_base/openssl_adapter.h
index 266ed35..7e1f87b 100644
--- a/rtc_base/openssl_adapter.h
+++ b/rtc_base/openssl_adapter.h
@@ -60,7 +60,6 @@
   void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
   void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
   void SetRole(SSLRole role) override;
-  Socket* Accept(SocketAddress* paddr) override;
   int StartSSL(const char* hostname) override;
   int Send(const void* pv, size_t cb) override;
   int SendTo(const void* pv, size_t cb, const SocketAddress& addr) override;
@@ -191,10 +190,21 @@
   // the first adapter is created with the factory. If it is called after it
   // will DCHECK.
   void SetMode(SSLMode mode) override;
+
   // Set a custom certificate verifier to be passed down to each instance
   // created with this factory. This should only ever be set before the first
   // call to the factory and cannot be changed after the fact.
   void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override;
+
+  void SetIdentity(std::unique_ptr<SSLIdentity> identity) override;
+
+  // Choose whether the socket acts as a server socket or client socket.
+  void SetRole(SSLRole role) override;
+
+  // Methods that control server certificate verification, used in unit tests.
+  // Do not call these methods in production code.
+  void SetIgnoreBadCert(bool ignore) override;
+
   // Constructs a new socket using the shared OpenSSLSessionCache. This means
   // existing SSLSessions already in the cache will be reused instead of
   // re-created for improved performance.
@@ -203,6 +213,11 @@
  private:
   // Holds the SSLMode (DTLS,TLS) that will be used to set the session cache.
   SSLMode ssl_mode_ = SSL_MODE_TLS;
+  SSLRole ssl_role_ = SSL_CLIENT;
+  bool ignore_bad_cert_ = false;
+
+  std::unique_ptr<SSLIdentity> identity_;
+
   // Holds a cache of existing SSL Sessions.
   std::unique_ptr<OpenSSLSessionCache> ssl_session_cache_;
   // Provides an optional custom callback for verifying SSL certificates, this
diff --git a/rtc_base/ssl_adapter.cc b/rtc_base/ssl_adapter.cc
index c9b54c4..ff936a7 100644
--- a/rtc_base/ssl_adapter.cc
+++ b/rtc_base/ssl_adapter.cc
@@ -16,8 +16,8 @@
 
 namespace rtc {
 
-SSLAdapterFactory* SSLAdapterFactory::Create() {
-  return new OpenSSLAdapterFactory();
+std::unique_ptr<SSLAdapterFactory> SSLAdapterFactory::Create() {
+  return std::make_unique<OpenSSLAdapterFactory>();
 }
 
 SSLAdapter* SSLAdapter::Create(Socket* socket) {
diff --git a/rtc_base/ssl_adapter.h b/rtc_base/ssl_adapter.h
index 1f0616b..8f98141 100644
--- a/rtc_base/ssl_adapter.h
+++ b/rtc_base/ssl_adapter.h
@@ -39,10 +39,21 @@
   // Specify a custom certificate verifier for SSL.
   virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0;
 
+  // Set the certificate this socket will present to incoming clients.
+  // Takes ownership of `identity`.
+  virtual void SetIdentity(std::unique_ptr<SSLIdentity> identity) = 0;
+
+  // Choose whether the socket acts as a server socket or client socket.
+  virtual void SetRole(SSLRole role) = 0;
+
+  // Methods that control server certificate verification, used in unit tests.
+  // Do not call these methods in production code.
+  virtual void SetIgnoreBadCert(bool ignore) = 0;
+
   // Creates a new SSL adapter, but from a shared context.
   virtual SSLAdapter* CreateAdapter(Socket* socket) = 0;
 
-  static SSLAdapterFactory* Create();
+  static std::unique_ptr<SSLAdapterFactory> Create();
 };
 
 // Class that abstracts a client-to-server SSL session. It can be created
@@ -91,6 +102,11 @@
   // and deletes `socket`. Otherwise, the returned SSLAdapter takes ownership
   // of `socket`.
   static SSLAdapter* Create(Socket* socket);
+
+ private:
+  // Not supported.
+  int Listen(int backlog) override { RTC_CHECK(false); }
+  Socket* Accept(SocketAddress* paddr) override { RTC_CHECK(false); }
 };
 
 ///////////////////////////////////////////////////////////////////////////////