dtls-in-stun: Only read IceConfig.dtls_handshake_in_stun in 1 place.

This patch fixes 2 problems, that both stem from the same root,
the the code had scattered checking IceConfig.dtls_handshake_in_stun
but that value was not updated when we discovered that remote peer
does not support piggy backing, "restart".

the patch modifies the code so that the IceConfig.dtls_handshake_in_stun is only checked during DtlsTransport::SetupDtls.

The problems fixed are:
1) P2PTransportChannel correctly (set/does not set) Connection::RegisterDtlsPiggyback based on the existing
dtls_stun_piggyback_callbacks_.empty() (that is reset when we detect
that peer does not support piggybacking) rather than on config which
is unchanged.

2) The timeout was not set properly during "restart",
properly == the value based upon ice rtt, but was still using the
"infinitely high" value for piggybacking.
This is tested with the DtlsTransportVersionTest which now runs
(optionally) with dtls-in-stun piggybacking.

BUG=webrtc:367395350

Change-Id: Ib511bcd1d3371a2132cefe26a3c49372208735ac
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/379760
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Auto-Submit: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#44044}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index 1277c71..368f9ba 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -911,6 +911,7 @@
       ":candidate_pair_interface",
       ":connection",
       ":connection_info",
+      ":dtls_stun_piggyback_controller",
       ":ice_transport_internal",
       ":port",
       ":transport_description",
diff --git a/p2p/base/p2p_transport_channel.cc b/p2p/base/p2p_transport_channel.cc
index cdb4359..9ff2512 100644
--- a/p2p/base/p2p_transport_channel.cc
+++ b/p2p/base/p2p_transport_channel.cc
@@ -313,7 +313,7 @@
       [this](webrtc::RTCErrorOr<const StunUInt64Attribute*> delta_ack) {
         GoogDeltaAckReceived(std::move(delta_ack));
       });
-  if (config_.dtls_handshake_in_stun) {
+  if (!dtls_stun_piggyback_callbacks_.empty()) {
     connection->RegisterDtlsPiggyback(DtlsStunPiggybackCallbacks(
         [&](auto request) {
           return dtls_stun_piggyback_callbacks_.send_data(request);
@@ -2246,8 +2246,9 @@
   }
   SignalWritableState(this);
 
-  if (config_.dtls_handshake_in_stun &&
+  if (writable_ && selected_connection_ &&
       !dtls_stun_piggyback_callbacks_.empty()) {
+    // TODO(webrtc:367395350): Move this into DtlsTransport somehow.
     // Need to STUN ping here to get the last bit of the DTLS handshake across
     // as quickly as possible. Only done when DTLS-in-STUN is configured
     // and the data callback has not been reset due to lack of support.
diff --git a/p2p/dtls/dtls_ice_integrationtest.cc b/p2p/dtls/dtls_ice_integrationtest.cc
index 15aff16..f9b6666 100644
--- a/p2p/dtls/dtls_ice_integrationtest.cc
+++ b/p2p/dtls/dtls_ice_integrationtest.cc
@@ -108,6 +108,18 @@
                                false),
         client_dtls_stun_piggyback_(std::get<0>(GetParam())),
         server_dtls_stun_piggyback_(std::get<1>(GetParam())) {
+    // Enable(or disable) the dtls_in_stun parameter before
+    // DTLS is negotiated.
+    cricket::IceConfig client_config;
+    client_config.continual_gathering_policy = GATHER_CONTINUALLY;
+    client_config.dtls_handshake_in_stun = client_dtls_stun_piggyback_;
+    client_ice_->SetIceConfig(client_config);
+
+    cricket::IceConfig server_config;
+    server_config.dtls_handshake_in_stun = server_dtls_stun_piggyback_;
+    server_config.continual_gathering_policy = GATHER_CONTINUALLY;
+    server_ice_->SetIceConfig(server_config);
+
     // Setup ICE.
     client_ice_->SetIceParameters(client_ice_parameters_);
     client_ice_->SetRemoteIceParameters(server_ice_parameters_);
@@ -141,6 +153,18 @@
 
   ~DtlsIceIntegrationTest() = default;
 
+  static int CountWritableConnections(IceTransportInternal* ice) {
+    IceTransportStats stats;
+    ice->GetStats(&stats);
+    int count = 0;
+    for (const auto& con : stats.connection_infos) {
+      if (con.writable) {
+        count++;
+      }
+    }
+    return count;
+  }
+
   rtc::FakeNetworkManager network_manager_;
   std::unique_ptr<rtc::VirtualSocketServer> ss_;
   std::unique_ptr<rtc::BasicPacketSocketFactory> socket_factory_;
@@ -165,14 +189,6 @@
 };
 
 TEST_P(DtlsIceIntegrationTest, SmokeTest) {
-  cricket::IceConfig client_config;
-  client_config.dtls_handshake_in_stun = client_dtls_stun_piggyback_;
-  client_ice_->SetIceConfig(client_config);
-
-  cricket::IceConfig server_config;
-  server_config.dtls_handshake_in_stun = server_dtls_stun_piggyback_;
-  server_ice_->SetIceConfig(server_config);
-
   client_ice_->MaybeStartGathering();
   server_ice_->MaybeStartGathering();
 
@@ -188,6 +204,18 @@
             client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_);
   EXPECT_EQ(server_dtls_.IsDtlsPiggybackSupportedByPeer(),
             client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_);
+
+  // Validate that we can add new Connections (that become writable).
+  network_manager_.AddInterface(rtc::SocketAddress("192.168.2.1", 0));
+  EXPECT_THAT(webrtc::WaitUntil(
+                  [&] {
+                    return CountWritableConnections(client_ice_.get()) > 1 &&
+                           CountWritableConnections(server_ice_.get()) > 1;
+                  },
+                  IsTrue(),
+                  {.timeout = webrtc::TimeDelta::Millis(kDefaultTimeout),
+                   .clock = &fake_clock_}),
+              webrtc::IsRtcOk());
 }
 
 // Test cases are parametrized by
@@ -197,12 +225,14 @@
 INSTANTIATE_TEST_SUITE_P(
     DtlsStunPiggybackingIntegrationTest,
     DtlsIceIntegrationTest,
-    ::testing::Values(
-        std::make_tuple(false, false, rtc::SSL_PROTOCOL_DTLS_12),
-        std::make_tuple(true, false, rtc::SSL_PROTOCOL_DTLS_12),
-        std::make_tuple(false, true, rtc::SSL_PROTOCOL_DTLS_12),
-        std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_12),
-        // Skip negative cases that are behaving similar for DTLS 1.3
-        std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_13)));
+    ::testing::Values(std::make_tuple(false, false, rtc::SSL_PROTOCOL_DTLS_12),
+                      std::make_tuple(true, false, rtc::SSL_PROTOCOL_DTLS_12),
+                      std::make_tuple(false, true, rtc::SSL_PROTOCOL_DTLS_12),
+                      std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_12),
+
+                      std::make_tuple(false, false, rtc::SSL_PROTOCOL_DTLS_13),
+                      std::make_tuple(true, false, rtc::SSL_PROTOCOL_DTLS_13),
+                      std::make_tuple(false, true, rtc::SSL_PROTOCOL_DTLS_13),
+                      std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_13)));
 
 }  // namespace cricket
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.cc b/p2p/dtls/dtls_stun_piggyback_controller.cc
index 602f623..9fedd89 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.cc
+++ b/p2p/dtls/dtls_stun_piggyback_controller.cc
@@ -175,4 +175,17 @@
   dtls_data_callback_(data->array_view());
 }
 
+void DtlsStunPiggybackController::SetEnabled(bool enabled) {
+  RTC_DCHECK_RUN_ON(&sequence_checker_);
+  enabled_ = enabled;
+  if (!enabled) {
+    state_ = State::OFF;
+  }
+}
+
+bool DtlsStunPiggybackController::enabled() const {
+  RTC_DCHECK_RUN_ON(&sequence_checker_);
+  return enabled_;
+}
+
 }  // namespace cricket
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.h b/p2p/dtls/dtls_stun_piggyback_controller.h
index c25d93f..ebf23c6d 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.h
+++ b/p2p/dtls/dtls_stun_piggyback_controller.h
@@ -38,6 +38,11 @@
           dtls_data_callback);
   ~DtlsStunPiggybackController();
 
+  // Initially set from IceConfig.dtls_handshake_in_stun
+  // but is also set to FALSE before restarting handshake.
+  void SetEnabled(bool enabled);
+  bool enabled() const;
+
   enum class State {
     // We don't know if peer support DTLS piggybacked in STUN.
     // We will piggyback DTLS until we get a piggybacked response
@@ -79,6 +84,7 @@
                              const StunByteStringAttribute* ack);
 
  private:
+  bool enabled_ RTC_GUARDED_BY(sequence_checker_) = false;
   State state_ RTC_GUARDED_BY(sequence_checker_) = State::TENTATIVE;
   rtc::Buffer pending_packet_ RTC_GUARDED_BY(sequence_checker_);
   absl::AnyInvocable<void(rtc::ArrayView<const uint8_t>)> dtls_data_callback_;
diff --git a/p2p/dtls/dtls_transport.cc b/p2p/dtls/dtls_transport.cc
index 219c91f..002aa83 100644
--- a/p2p/dtls/dtls_transport.cc
+++ b/p2p/dtls/dtls_transport.cc
@@ -109,8 +109,7 @@
 
   // If we try to use DTLS-in-STUN, DTLS packets will be sent as part of STUN
   // packets and are consumed here.
-  if (ice_transport_->config().dtls_handshake_in_stun &&
-      dtls_stun_piggyback_controller_ &&
+  if (dtls_stun_piggyback_controller_ &&
       dtls_stun_piggyback_controller_->MaybeConsumePacket(data)) {
     written = data.size();
     return rtc::SR_SUCCESS;
@@ -385,14 +384,20 @@
   return dtls_ ? dtls_->ExportSrtpKeyingMaterial(keying_material) : false;
 }
 
-bool DtlsTransport::SetupDtls() {
+bool DtlsTransport::SetupDtls(bool disable_piggybacking) {
   RTC_DCHECK(dtls_role_);
+  // Look at both config...and argument (used on restart).
+  const bool enable_piggybacking =
+      ice_transport_->config().dtls_handshake_in_stun && !disable_piggybacking;
+
   {
     auto downward = std::make_unique<StreamInterfaceChannel>(ice_transport_);
     StreamInterfaceChannel* downward_ptr = downward.get();
 
-    downward_ptr->SetDtlsStunPiggybackController(
-        &dtls_stun_piggyback_controller_);
+    if (enable_piggybacking) {
+      downward_ptr->SetDtlsStunPiggybackController(
+          &dtls_stun_piggyback_controller_);
+    }
     dtls_ = rtc::SSLStreamAdapter::Create(
         std::move(downward),
         [this](rtc::SSLHandshakeError error) { OnDtlsHandshakeError(error); },
@@ -404,6 +409,11 @@
     downward_ = downward_ptr;
   }
 
+  if (!enable_piggybacking) {
+    ice_transport_->ResetDtlsStunPiggybackCallbacks();
+  }
+  dtls_stun_piggyback_controller_.SetEnabled(enable_piggybacking);
+
   dtls_->SetIdentity(local_certificate_->identity()->Clone());
   dtls_->SetMaxProtocolVersion(ssl_max_version_);
   dtls_->SetServerRole(*dtls_role_);
@@ -617,18 +627,16 @@
   // Recreate the DTLS session. Note: this assumes we can consider
   // the previous DTLS session state beyond repair and no packet
   // reached the peer.
-  if (ice_transport_->config().dtls_handshake_in_stun && dtls_ &&
+  if (dtls_stun_piggyback_controller_.enabled() && dtls_ &&
       !was_ever_connected_ && !IsDtlsPiggybackSupportedByPeer() &&
       (dtls_state() == webrtc::DtlsTransportState::kConnecting ||
        dtls_state() == webrtc::DtlsTransportState::kNew)) {
     RTC_LOG(LS_INFO) << "DTLS piggybacking not supported, restarting...";
-    ice_transport_->ResetDtlsStunPiggybackCallbacks();
-    downward_->SetDtlsStunPiggybackController(nullptr);
     dtls_.reset(nullptr);
     set_dtls_state(webrtc::DtlsTransportState::kNew);
     set_writable(false);
 
-    if (!SetupDtls()) {
+    if (!SetupDtls(/* disable_piggybacking= */ true)) {
       RTC_LOG(LS_ERROR)
           << "Failed to setup DTLS again after attempted piggybacking.";
       set_dtls_state(webrtc::DtlsTransportState::kFailed);
@@ -856,10 +864,9 @@
   RTC_DCHECK(ice_transport_);
   //  When adding the DTLS handshake in STUN we want to call StartSSL even
   //  before the ICE transport is ready.
-  bool start_early_for_dtls_in_stun =
-      ice_transport_->config().dtls_handshake_in_stun;
-  if (dtls_ && (ice_transport_->writable() || start_early_for_dtls_in_stun)) {
-    ConfigureHandshakeTimeout(start_early_for_dtls_in_stun);
+  if (dtls_ && (ice_transport_->writable() ||
+                dtls_stun_piggyback_controller_.enabled())) {
+    ConfigureHandshakeTimeout();
 
     if (dtls_->StartSSL()) {
       // This should never fail:
@@ -947,8 +954,9 @@
   SendDtlsHandshakeError(error);
 }
 
-void DtlsTransport::ConfigureHandshakeTimeout(bool uses_dtls_in_stun) {
+void DtlsTransport::ConfigureHandshakeTimeout() {
   RTC_DCHECK(dtls_);
+  bool uses_dtls_in_stun = dtls_stun_piggyback_controller_.enabled();
   std::optional<int> rtt_ms = ice_transport_->GetRttEstimate();
   if (uses_dtls_in_stun) {
     // Configure a very high timeout to effectively disable the DTLS timeout
@@ -984,19 +992,17 @@
 bool DtlsTransport::IsDtlsPiggybackSupportedByPeer() {
   RTC_DCHECK_RUN_ON(&thread_checker_);
   RTC_DCHECK(ice_transport_);
-  return ice_transport_->config().dtls_handshake_in_stun &&
-         dtls_stun_piggyback_controller_.state() !=
-             DtlsStunPiggybackController::State::OFF;
+  return dtls_stun_piggyback_controller_.state() !=
+         DtlsStunPiggybackController::State::OFF;
 }
 
 bool DtlsTransport::IsDtlsPiggybackHandshaking() {
   RTC_DCHECK_RUN_ON(&thread_checker_);
   RTC_DCHECK(ice_transport_);
-  return ice_transport_->config().dtls_handshake_in_stun &&
-         (dtls_stun_piggyback_controller_.state() ==
-              DtlsStunPiggybackController::State::TENTATIVE ||
-          dtls_stun_piggyback_controller_.state() ==
-              DtlsStunPiggybackController::State::CONFIRMED);
+  return dtls_stun_piggyback_controller_.state() ==
+             DtlsStunPiggybackController::State::TENTATIVE ||
+         dtls_stun_piggyback_controller_.state() ==
+             DtlsStunPiggybackController::State::CONFIRMED;
 }
 
 }  // namespace cricket
diff --git a/p2p/dtls/dtls_transport.h b/p2p/dtls/dtls_transport.h
index b0953b5..9c7b0f9 100644
--- a/p2p/dtls/dtls_transport.h
+++ b/p2p/dtls/dtls_transport.h
@@ -78,8 +78,8 @@
 
  private:
   IceTransportInternal* const ice_transport_;  // owned by DtlsTransport
-  DtlsStunPiggybackController*
-      dtls_stun_piggyback_controller_;  // owned by DtlsTransport
+  DtlsStunPiggybackController* dtls_stun_piggyback_controller_ =
+      nullptr;  // owned by DtlsTransport
   rtc::StreamState state_ RTC_GUARDED_BY(callback_sequence_);
   rtc::BufferQueue packets_ RTC_GUARDED_BY(callback_sequence_);
 };
@@ -228,9 +228,6 @@
     return sb.Release();
   }
 
-  void SetPiggybackDtlsDataCallback(
-      absl::AnyInvocable<void(rtc::PacketTransportInternal* transport,
-                              const rtc::ReceivedPacket& packet)> callback);
   bool IsDtlsPiggybackSupportedByPeer();
 
  private:
@@ -245,16 +242,19 @@
   void OnReceivingState(rtc::PacketTransportInternal* transport);
   void OnDtlsEvent(int sig, int err);
   void OnNetworkRouteChanged(std::optional<rtc::NetworkRoute> network_route);
-  bool SetupDtls();
+  bool SetupDtls(bool disable_piggybacking = false);
   void MaybeStartDtls();
   bool HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload);
   void OnDtlsHandshakeError(rtc::SSLHandshakeError error);
-  void ConfigureHandshakeTimeout(bool uses_dtls_in_stun);
+  void ConfigureHandshakeTimeout();
 
   void set_receiving(bool receiving);
   void set_writable(bool writable);
   // Sets the DTLS state, signaling if necessary.
   void set_dtls_state(webrtc::DtlsTransportState state);
+  void SetPiggybackDtlsDataCallback(
+      absl::AnyInvocable<void(rtc::PacketTransportInternal* transport,
+                              const rtc::ReceivedPacket& packet)> callback);
 
   RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_;
 
diff --git a/p2p/dtls/dtls_transport_unittest.cc b/p2p/dtls/dtls_transport_unittest.cc
index 6740648..2485d5d 100644
--- a/p2p/dtls/dtls_transport_unittest.cc
+++ b/p2p/dtls/dtls_transport_unittest.cc
@@ -18,6 +18,7 @@
 #include <optional>
 #include <set>
 #include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 
@@ -52,10 +53,9 @@
 #include "test/gtest.h"
 #include "test/wait_until.h"
 
-#define MAYBE_SKIP_TEST(feature)                                  \
-  if (!(rtc::SSLStreamAdapter::feature())) {                      \
-    RTC_LOG(LS_INFO) << #feature " feature disabled... skipping"; \
-    return;                                                       \
+#define MAYBE_SKIP_TEST(feature)                              \
+  if (!(rtc::SSLStreamAdapter::feature())) {                  \
+    GTEST_SKIP() << #feature " feature disabled... skipping"; \
   }
 
 namespace cricket {
@@ -155,6 +155,16 @@
     return true;
   }
 
+  // Connect the fake ICE transports so that packets flows from one to other.
+  bool ConnectIceTransport(DtlsTestClient* peer) {
+    fake_ice_transport()->SetDestinationNotWritable(peer->fake_ice_transport());
+    return true;
+  }
+
+  bool SendIcePing() { return fake_ice_transport_->SendIcePing(); }
+
+  bool SendIcePingConf() { return fake_ice_transport_->SendIcePingConf(); }
+
   int received_dtls_client_hellos() const {
     return received_dtls_client_hellos_;
   }
@@ -576,15 +586,41 @@
     {rtc::kDtls13VersionBytes, dtls_13_handshake_events},
 };
 
+struct EndpointConfig {
+  rtc::SSLProtocolVersion max_protocol_version;
+  bool dtls_in_stun = false;
+};
+
 class DtlsTransportVersionTest
     : public DtlsTransportTestBase,
       public ::testing::TestWithParam<
-          ::testing::tuple<rtc::SSLProtocolVersion, rtc::SSLProtocolVersion>> {
+          std::tuple<EndpointConfig, EndpointConfig>> {
  public:
   void Prepare() {
     PrepareDtls(rtc::KT_DEFAULT);
-    SetMaxProtocolVersions(::testing::get<0>(GetParam()),
-                           ::testing::get<1>(GetParam()));
+    SetMaxProtocolVersions(std::get<0>(GetParam()).max_protocol_version,
+                           std::get<1>(GetParam()).max_protocol_version);
+
+    client1_.SetupTransports(ICEROLE_CONTROLLING);
+    client2_.SetupTransports(ICEROLE_CONTROLLED);
+    client1_.dtls_transport()->SetDtlsRole(rtc::SSL_CLIENT);
+    client2_.dtls_transport()->SetDtlsRole(rtc::SSL_SERVER);
+
+    if (std::get<0>(GetParam()).dtls_in_stun) {
+      auto config = client1_.fake_ice_transport()->config();
+      config.dtls_handshake_in_stun = true;
+      client1_.fake_ice_transport()->SetIceConfig(config);
+    }
+    if (std::get<1>(GetParam()).dtls_in_stun) {
+      auto config = client2_.fake_ice_transport()->config();
+      config.dtls_handshake_in_stun = true;
+      client2_.fake_ice_transport()->SetIceConfig(config);
+    }
+
+    SetRemoteFingerprintFromCert(client1_.dtls_transport(),
+                                 client2_.certificate());
+    SetRemoteFingerprintFromCert(client2_.dtls_transport(),
+                                 client1_.certificate());
   }
 
   // Run DTLS handshake.
@@ -592,8 +628,6 @@
   // - drop packets as specified in `packets_to_drop`
   std::pair</* dtls_version_bytes*/ int, std::vector<HandshakeTestEvent>>
   RunHandshake(std::set<unsigned> packets_to_drop) {
-    Negotiate(/* client1_server= */ false);
-
     std::vector<HandshakeTestEvent> events;
     auto start_time_ns = fake_clock_.TimeNanos();
     client1_.fake_ice_transport()->set_rtt_estimate(50, true);
@@ -642,7 +676,11 @@
           return LogSend("server", diff_ms, drop, data, len);
         });
 
-    EXPECT_TRUE(client1_.Connect(&client2_, false));
+    EXPECT_TRUE(client1_.ConnectIceTransport(&client2_));
+    client1_.SendIcePing();
+    client2_.SendIcePingConf();
+    client2_.SendIcePing();
+    client1_.SendIcePingConf();
 
     EXPECT_THAT(webrtc::WaitUntil(
                     [&] {
@@ -661,12 +699,13 @@
 
     auto dtls_version_bytes = client1_.GetVersionBytes();
     EXPECT_EQ(dtls_version_bytes, client2_.GetVersionBytes());
-    return std::make_pair(*dtls_version_bytes, std::move(events));
+    return std::make_pair(dtls_version_bytes.value_or(0), std::move(events));
   }
 
   int GetExpectedDtlsVersionBytes() {
-    int version = std::min(static_cast<int>(::testing::get<0>(GetParam())),
-                           static_cast<int>(::testing::get<1>(GetParam())));
+    int version = std::min(
+        static_cast<int>(std::get<0>(GetParam()).max_protocol_version),
+        static_cast<int>(std::get<1>(GetParam()).max_protocol_version));
     if (version == rtc::SSL_PROTOCOL_DTLS_13) {
       return rtc::kDtls13VersionBytes;
     } else {
@@ -711,17 +750,40 @@
   }
 };
 
+static const EndpointConfig kEndpointVariants[] = {
+    {
+        .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_10,
+        .dtls_in_stun = false,
+    },
+    {
+        .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_12,
+        .dtls_in_stun = false,
+    },
+    {
+        .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_13,
+        .dtls_in_stun = false,
+    },
+    {
+        .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_10,
+        .dtls_in_stun = true,
+    },
+    {
+        .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_12,
+        .dtls_in_stun = true,
+    },
+    {
+        .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_13,
+        .dtls_in_stun = true,
+    },
+};
+
 // Will test every combination of 1.0/1.2/1.3 on the client and server.
 // DTLS will negotiate an effective version (the min of client & sewrver).
 INSTANTIATE_TEST_SUITE_P(
     DtlsTransportVersionTest,
     DtlsTransportVersionTest,
-    ::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
-                                         rtc::SSL_PROTOCOL_DTLS_12,
-                                         rtc::SSL_PROTOCOL_DTLS_13),
-                       ::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
-                                         rtc::SSL_PROTOCOL_DTLS_12,
-                                         rtc::SSL_PROTOCOL_DTLS_13)));
+    ::testing::Combine(testing::ValuesIn(kEndpointVariants),
+                       testing::ValuesIn(kEndpointVariants)));
 
 // Test that an acceptable cipher suite is negotiated when different versions
 // of DTLS are supported. Note that it's IsAcceptableCipher that does the actual
@@ -732,6 +794,10 @@
 }
 
 TEST_P(DtlsTransportVersionTest, HandshakeFlights) {
+  if (std::get<0>(GetParam()).dtls_in_stun &&
+      std::get<1>(GetParam()).dtls_in_stun) {
+    GTEST_SKIP() << "This test does not support dtls in stun";
+  }
   Prepare();
   auto [dtls_version_bytes, events] = RunHandshake({});
 
@@ -743,6 +809,10 @@
 
 TEST_P(DtlsTransportVersionTest, HandshakeLoseFirstClientPacket) {
   MAYBE_SKIP_TEST(IsBoringSsl);
+  if (std::get<0>(GetParam()).dtls_in_stun &&
+      std::get<1>(GetParam()).dtls_in_stun) {
+    GTEST_SKIP() << "This test does not support dtls in stun";
+  }
 
   Prepare();
   auto [dtls_version_bytes, events] = RunHandshake({/* packet_num= */ 0});
@@ -758,6 +828,10 @@
 
 TEST_P(DtlsTransportVersionTest, HandshakeLoseSecondClientPacket) {
   MAYBE_SKIP_TEST(IsBoringSsl);
+  if (std::get<0>(GetParam()).dtls_in_stun &&
+      std::get<1>(GetParam()).dtls_in_stun) {
+    GTEST_SKIP() << "This test does not support dtls in stun";
+  }
 
   Prepare();
   auto [dtls_version_bytes, events] = RunHandshake({/* packet_num= */ 2});
@@ -817,7 +891,7 @@
       };
       break;
     default:
-      RTC_CHECK(false) << "Unknown dtls version bytes: " << dtls_version_bytes;
+      FAIL() << "Unknown dtls version bytes: " << dtls_version_bytes;
   }
   EXPECT_EQ(events, expect);
 }
diff --git a/p2p/test/fake_ice_transport.h b/p2p/test/fake_ice_transport.h
index 9b9df19..d365ddc 100644
--- a/p2p/test/fake_ice_transport.h
+++ b/p2p/test/fake_ice_transport.h
@@ -32,6 +32,7 @@
 #include "p2p/base/ice_transport_internal.h"
 #include "p2p/base/port.h"
 #include "p2p/base/transport_description.h"
+#include "p2p/dtls/dtls_stun_piggyback_callbacks.h"
 #include "rtc_base/async_packet_socket.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/copy_on_write_buffer.h"
@@ -118,6 +119,24 @@
     }
   }
 
+  void SetDestinationNotWritable(FakeIceTransport* dest) {
+    RTC_DCHECK_RUN_ON(network_thread_);
+    if (dest == dest_) {
+      return;
+    }
+    RTC_DCHECK(!dest || !dest_)
+        << "Changing fake destination from one to another is not supported.";
+
+    if (dest) {
+      RTC_DCHECK_RUN_ON(dest->network_thread_);
+      dest->dest_ = this;
+    } else if (dest_) {
+      RTC_DCHECK_RUN_ON(dest_->network_thread_);
+      dest_->dest_ = nullptr;
+    }
+    dest_ = dest;
+  }
+
   void SetTransportState(webrtc::IceTransportState state,
                          IceTransportState legacy_state) {
     RTC_DCHECK_RUN_ON(network_thread_);
@@ -415,6 +434,61 @@
     }
   }
 
+  void ResetDtlsStunPiggybackCallbacks() override {
+    dtls_stun_piggyback_callbacks_.reset();
+  }
+  void SetDtlsStunPiggybackCallbacks(
+      DtlsStunPiggybackCallbacks&& callbacks) override {
+    RTC_LOG(LS_INFO) << name_ << ": SetDtlsStunPiggybackCallbacks";
+    dtls_stun_piggyback_callbacks_ = std::move(callbacks);
+  }
+
+  bool SendIcePing() {
+    RTC_DCHECK_RUN_ON(network_thread_);
+    RTC_DLOG(LS_INFO) << name_ << ": SendIcePing()";
+    auto msg = std::make_unique<IceMessage>(STUN_BINDING_REQUEST);
+    MaybeAddDtlsPiggybackingAttributes(msg.get());
+    msg->AddFingerprint();
+    rtc::ByteBufferWriter buf;
+    msg->Write(&buf);
+    SendPacketInternal(rtc::CopyOnWriteBuffer(buf.DataView()));
+    return true;
+  }
+
+  void MaybeAddDtlsPiggybackingAttributes(StunMessage* msg) {
+    if (dtls_stun_piggyback_callbacks_.empty()) {
+      return;
+    }
+
+    const auto& [attr, ack] = dtls_stun_piggyback_callbacks_.send_data(
+        static_cast<StunMessageType>(msg->type()));
+
+    RTC_DLOG(LS_INFO) << name_ << ": Adding attr: " << attr.has_value()
+                      << " ack: " << ack.has_value() << " to stun message: "
+                      << StunMethodToString(msg->type());
+
+    if (attr) {
+      msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
+          STUN_ATTR_META_DTLS_IN_STUN, *attr));
+    }
+    if (ack) {
+      msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
+          STUN_ATTR_META_DTLS_IN_STUN_ACK, *ack));
+    }
+  }
+
+  bool SendIcePingConf() {
+    RTC_DCHECK_RUN_ON(network_thread_);
+    RTC_DLOG(LS_INFO) << name_ << ": SendIcePingConf()";
+    auto msg = std::make_unique<IceMessage>(STUN_BINDING_RESPONSE);
+    MaybeAddDtlsPiggybackingAttributes(msg.get());
+    msg->AddFingerprint();
+    rtc::ByteBufferWriter buf;
+    msg->Write(&buf);
+    SendPacketInternal(rtc::CopyOnWriteBuffer(buf.DataView()));
+    return true;
+  }
+
  private:
   void set_writable(bool writable)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) {
@@ -449,6 +523,25 @@
   void ReceivePacketInternal(const rtc::CopyOnWriteBuffer& packet) {
     RTC_DCHECK_RUN_ON(network_thread_);
     auto now = rtc::TimeMicros();
+    if (auto msg = GetStunMessage(packet)) {
+      const auto* dtls_piggyback_attr =
+          msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN);
+      const auto* dtls_piggyback_ack =
+          msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK);
+      RTC_DLOG(LS_INFO) << name_ << ": Got STUN message: "
+                        << StunMethodToString(msg->type())
+                        << " attr: " << (dtls_piggyback_attr != nullptr)
+                        << " ack: " << (dtls_piggyback_ack != nullptr);
+      if (!dtls_stun_piggyback_callbacks_.empty()) {
+        dtls_stun_piggyback_callbacks_.recv_data(dtls_piggyback_attr,
+                                                 dtls_piggyback_ack);
+      }
+
+      if (msg->type() == STUN_BINDING_RESPONSE) {
+        set_writable(true);
+      }
+      return;
+    }
     if (packet_recv_filter_func_ && packet_recv_filter_func_(packet, now)) {
       RTC_DLOG(LS_INFO) << name_
                         << ": dropping packet at receiver len=" << packet.size()
@@ -460,6 +553,18 @@
     }
   }
 
+  std::unique_ptr<IceMessage> GetStunMessage(
+      const rtc::CopyOnWriteBuffer& packet) {
+    if (!StunMessage::ValidateFingerprint(packet.data<char>(), packet.size())) {
+      return nullptr;
+    }
+
+    std::unique_ptr<IceMessage> stun_msg(new IceMessage());
+    rtc::ByteBufferReader buf(rtc::MakeArrayView(packet.data(), packet.size()));
+    RTC_CHECK(stun_msg->Read(&buf));
+    return stun_msg;
+  }
+
   const std::string name_;
   const int component_;
   FakeIceTransport* dest_ RTC_GUARDED_BY(network_thread_) = nullptr;
@@ -497,6 +602,7 @@
       packet_send_filter_func_ RTC_GUARDED_BY(network_thread_) = nullptr;
   absl::AnyInvocable<bool(const rtc::CopyOnWriteBuffer&, uint64_t)>
       packet_recv_filter_func_ RTC_GUARDED_BY(network_thread_) = nullptr;
+  DtlsStunPiggybackCallbacks dtls_stun_piggyback_callbacks_;
 };
 
 class FakeIceTransportWrapper : public webrtc::IceTransportInterface {