Support DNS resolution matching a specified IP family.

The input SocketAddress for STUN host lookup is constructed with just
the hostname, so the family is AF_UNSPEC. So added an overload with a
target family to distinguish this from the family of the input addr.

Bug: webrtc:14319, webrtc:14131
Change-Id: Ia5ac5aa2e894e0c4dfb4417e3e8a76a6cec3ea71
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/270624
Reviewed-by: Tomas Gunnarsson <tommi@webrtc.org>
Commit-Queue: Sameer Vijaykar <samvi@google.com>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@google.com>
Cr-Commit-Position: refs/heads/main@{#37750}
diff --git a/api/BUILD.gn b/api/BUILD.gn
index 5300a6a..eca66e5 100644
--- a/api/BUILD.gn
+++ b/api/BUILD.gn
@@ -345,6 +345,7 @@
   visibility = [ "*" ]
   sources = [ "async_dns_resolver.h" ]
   deps = [
+    "../rtc_base:checks",
     "../rtc_base:socket_address",
     "../rtc_base/system:rtc_export",
   ]
diff --git a/api/async_dns_resolver.h b/api/async_dns_resolver.h
index 138503b..82d80de 100644
--- a/api/async_dns_resolver.h
+++ b/api/async_dns_resolver.h
@@ -14,6 +14,7 @@
 #include <functional>
 #include <memory>
 
+#include "rtc_base/checks.h"
 #include "rtc_base/socket_address.h"
 #include "rtc_base/system/rtc_export.h"
 
@@ -63,6 +64,10 @@
   // Start address resolution of the hostname in `addr`.
   virtual void Start(const rtc::SocketAddress& addr,
                      std::function<void()> callback) = 0;
+  // Start address resolution of the hostname in `addr` matching `family`.
+  virtual void Start(const rtc::SocketAddress& addr,
+                     int family,
+                     std::function<void()> callback) = 0;
   virtual const AsyncDnsResolverResult& result() const = 0;
 };
 
@@ -79,6 +84,14 @@
   virtual std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve(
       const rtc::SocketAddress& addr,
       std::function<void()> callback) = 0;
+  // Creates an AsyncDnsResolver and starts resolving the name to an address
+  // matching the specified family. The callback will be called when resolution
+  // is finished. The callback will be called on the sequence that the caller
+  // runs on.
+  virtual std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve(
+      const rtc::SocketAddress& addr,
+      int family,
+      std::function<void()> callback) = 0;
   // Creates an AsyncDnsResolver and does not start it.
   // For backwards compatibility, will be deprecated and removed.
   // One has to do a separate Start() call on the
diff --git a/api/test/mock_async_dns_resolver.h b/api/test/mock_async_dns_resolver.h
index 7cc17a8..81132c9 100644
--- a/api/test/mock_async_dns_resolver.h
+++ b/api/test/mock_async_dns_resolver.h
@@ -34,6 +34,10 @@
               Start,
               (const rtc::SocketAddress&, std::function<void()>),
               (override));
+  MOCK_METHOD(void,
+              Start,
+              (const rtc::SocketAddress&, int family, std::function<void()>),
+              (override));
   MOCK_METHOD(AsyncDnsResolverResult&, result, (), (const, override));
 };
 
@@ -44,6 +48,10 @@
               (const rtc::SocketAddress&, std::function<void()>),
               (override));
   MOCK_METHOD(std::unique_ptr<webrtc::AsyncDnsResolverInterface>,
+              CreateAndResolve,
+              (const rtc::SocketAddress&, int, std::function<void()>),
+              (override));
+  MOCK_METHOD(std::unique_ptr<webrtc::AsyncDnsResolverInterface>,
               Create,
               (),
               (override));
diff --git a/api/wrapping_async_dns_resolver.h b/api/wrapping_async_dns_resolver.h
index 80da206..5155b0f 100644
--- a/api/wrapping_async_dns_resolver.h
+++ b/api/wrapping_async_dns_resolver.h
@@ -13,6 +13,7 @@
 
 #include <functional>
 #include <memory>
+#include <utility>
 
 #include "absl/memory/memory.h"
 #include "api/async_dns_resolver.h"
@@ -68,14 +69,18 @@
   void Start(const rtc::SocketAddress& addr,
              std::function<void()> callback) override {
     RTC_DCHECK_RUN_ON(&sequence_checker_);
-    RTC_DCHECK_EQ(State::kNotStarted, state_);
-    state_ = State::kStarted;
-    callback_ = callback;
-    wrapped_->SignalDone.connect(this,
-                                 &WrappingAsyncDnsResolver::OnResolveResult);
+    PrepareToResolve(std::move(callback));
     wrapped_->Start(addr);
   }
 
+  void Start(const rtc::SocketAddress& addr,
+             int family,
+             std::function<void()> callback) override {
+    RTC_DCHECK_RUN_ON(&sequence_checker_);
+    PrepareToResolve(std::move(callback));
+    wrapped_->Start(addr, family);
+  }
+
   const AsyncDnsResolverResult& result() const override {
     RTC_DCHECK_RUN_ON(&sequence_checker_);
     RTC_DCHECK_EQ(State::kResolved, state_);
@@ -92,6 +97,15 @@
     return wrapped_.get();
   }
 
+  void PrepareToResolve(std::function<void()> callback) {
+    RTC_DCHECK_RUN_ON(&sequence_checker_);
+    RTC_DCHECK_EQ(State::kNotStarted, state_);
+    state_ = State::kStarted;
+    callback_ = std::move(callback);
+    wrapped_->SignalDone.connect(this,
+                                 &WrappingAsyncDnsResolver::OnResolveResult);
+  }
+
   void OnResolveResult(rtc::AsyncResolverInterface* ref) {
     RTC_DCHECK_RUN_ON(&sequence_checker_);
     RTC_DCHECK(state_ == State::kStarted);
diff --git a/p2p/base/basic_async_resolver_factory.cc b/p2p/base/basic_async_resolver_factory.cc
index 6824357..3fdf75b 100644
--- a/p2p/base/basic_async_resolver_factory.cc
+++ b/p2p/base/basic_async_resolver_factory.cc
@@ -36,7 +36,17 @@
     const rtc::SocketAddress& addr,
     std::function<void()> callback) {
   std::unique_ptr<webrtc::AsyncDnsResolverInterface> resolver = Create();
-  resolver->Start(addr, callback);
+  resolver->Start(addr, std::move(callback));
+  return resolver;
+}
+
+std::unique_ptr<webrtc::AsyncDnsResolverInterface>
+WrappingAsyncDnsResolverFactory::CreateAndResolve(
+    const rtc::SocketAddress& addr,
+    int family,
+    std::function<void()> callback) {
+  std::unique_ptr<webrtc::AsyncDnsResolverInterface> resolver = Create();
+  resolver->Start(addr, family, std::move(callback));
   return resolver;
 }
 
diff --git a/p2p/base/basic_async_resolver_factory.h b/p2p/base/basic_async_resolver_factory.h
index c988913..9a0ba1a 100644
--- a/p2p/base/basic_async_resolver_factory.h
+++ b/p2p/base/basic_async_resolver_factory.h
@@ -45,6 +45,11 @@
       const rtc::SocketAddress& addr,
       std::function<void()> callback) override;
 
+  std::unique_ptr<webrtc::AsyncDnsResolverInterface> CreateAndResolve(
+      const rtc::SocketAddress& addr,
+      int family,
+      std::function<void()> callback) override;
+
   std::unique_ptr<webrtc::AsyncDnsResolverInterface> Create() override;
 
  private:
diff --git a/p2p/base/mock_async_resolver.h b/p2p/base/mock_async_resolver.h
index 8bc0eb9..4416471 100644
--- a/p2p/base/mock_async_resolver.h
+++ b/p2p/base/mock_async_resolver.h
@@ -30,6 +30,7 @@
   ~MockAsyncResolver() = default;
 
   MOCK_METHOD(void, Start, (const rtc::SocketAddress&), (override));
+  MOCK_METHOD(void, Start, (const rtc::SocketAddress&, int family), (override));
   MOCK_METHOD(bool,
               GetResolvedAddress,
               (int family, SocketAddress* addr),
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index ee8d6cd..2476bca 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -845,6 +845,7 @@
     "async_resolver_interface.h",
   ]
   deps = [
+    ":checks",
     ":socket_address",
     "system:rtc_export",
     "third_party/sigslot",
diff --git a/rtc_base/async_resolver.cc b/rtc_base/async_resolver.cc
index 198013c..7c1a6fe 100644
--- a/rtc_base/async_resolver.cc
+++ b/rtc_base/async_resolver.cc
@@ -145,14 +145,18 @@
 }
 
 void AsyncResolver::Start(const SocketAddress& addr) {
+  Start(addr, addr.family());
+}
+
+void AsyncResolver::Start(const SocketAddress& addr, int family) {
   RTC_DCHECK_RUN_ON(&sequence_checker_);
   RTC_DCHECK(!destroy_called_);
   addr_ = addr;
   auto thread_function =
-      [this, addr, caller_task_queue = webrtc::TaskQueueBase::Current(),
+      [this, addr, family, caller_task_queue = webrtc::TaskQueueBase::Current(),
        state = state_] {
         std::vector<IPAddress> addresses;
-        int error = ResolveHostname(addr.hostname(), addr.family(), &addresses);
+        int error = ResolveHostname(addr.hostname(), family, &addresses);
         webrtc::MutexLock lock(&state->mutex);
         if (state->status == State::Status::kLive) {
           caller_task_queue->PostTask(
diff --git a/rtc_base/async_resolver.h b/rtc_base/async_resolver.h
index b7125ba..46be438 100644
--- a/rtc_base/async_resolver.h
+++ b/rtc_base/async_resolver.h
@@ -45,6 +45,7 @@
   ~AsyncResolver() override;
 
   void Start(const SocketAddress& addr) override;
+  void Start(const SocketAddress& addr, int family) override;
   bool GetResolvedAddress(int family, SocketAddress* addr) const override;
   int GetError() const override;
   void Destroy(bool wait) override;
diff --git a/rtc_base/async_resolver_interface.h b/rtc_base/async_resolver_interface.h
index 6916ea4..998ebd8 100644
--- a/rtc_base/async_resolver_interface.h
+++ b/rtc_base/async_resolver_interface.h
@@ -11,6 +11,7 @@
 #ifndef RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_
 #define RTC_BASE_ASYNC_RESOLVER_INTERFACE_H_
 
+#include "rtc_base/checks.h"
 #include "rtc_base/socket_address.h"
 #include "rtc_base/system/rtc_export.h"
 #include "rtc_base/third_party/sigslot/sigslot.h"
@@ -25,6 +26,12 @@
 
   // Start address resolution of the hostname in `addr`.
   virtual void Start(const SocketAddress& addr) = 0;
+  // Start address resolution of the hostname in `addr` matching `family`.
+  virtual void Start(const SocketAddress& addr, int family) {
+    // TODO(webrtc:14319) make pure virtual when all subclasses have been
+    // updated.
+    RTC_DCHECK_NOTREACHED();
+  }
   // Returns true iff the address from `Start` was successfully resolved.
   // If the address was successfully resolved, sets `addr` to a copy of the
   // address from `Start` with the IP address set to the top most resolved