dtls1.3 - patch 6

Incrase test coverage of DtlsRestart by having different
dtls (1.2/1.3) settings on the resp peers.

BUG=webrtc:383141571

Change-Id: I8429f2481a4d7eee0e12c0a954879932060e4fac
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/380060
Auto-Submit: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Henrik Boström <hbos@webrtc.org>
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#44060}
diff --git a/pc/data_channel_integrationtest.cc b/pc/data_channel_integrationtest.cc
index cfd8463..52059be 100644
--- a/pc/data_channel_integrationtest.cc
+++ b/pc/data_channel_integrationtest.cc
@@ -59,6 +59,7 @@
 using ::testing::Eq;
 using ::testing::IsTrue;
 using ::testing::Ne;
+using ::testing::ValuesIn;
 
 // All tests in this file require SCTP support.
 #ifdef WEBRTC_HAVE_SCTP
@@ -132,6 +133,15 @@
   }
 }
 
+void MakeOfferHavePassiveDtlsRole(
+    std::unique_ptr<SessionDescriptionInterface>& desc) {
+  auto& transport_infos = desc->description()->transport_infos();
+  for (auto& transport_info : transport_infos) {
+    transport_info.description.connection_role =
+        cricket::CONNECTIONROLE_PASSIVE;
+  }
+}
+
 // This test causes a PeerConnection to enter Disconnected state, and
 // sends data on a DataChannel while disconnected.
 // The data should be surfaced when the connection reestablishes.
@@ -1476,20 +1486,40 @@
 
 class DataChannelIntegrationTestUnifiedPlanFieldTrials
     : public DataChannelIntegrationTestUnifiedPlan,
-      public ::testing::WithParamInterface<
-          std::tuple</* callee-DTLS-active=*/bool, std::string>> {
+      public ::testing::WithParamInterface<std::tuple<
+          /* callee-DTLS-active=*/bool,
+          /* caller-field-trials=*/const char*,
+          /* callee-field-trials=*/const char*,
+          /* callee2-field-trials=*/const char*>> {
  protected:
   DataChannelIntegrationTestUnifiedPlanFieldTrials() {
-    SetFieldTrials(std::get<1>(GetParam()));
+    const bool callee_active = std::get<0>(GetParam());
+    RTC_LOG(LS_INFO) << "dtls_active: " << (callee_active ? "callee" : "caller")
+                     << " field-trials: caller: " << std::get<1>(GetParam())
+                     << " callee: " << std::get<2>(GetParam())
+                     << " callee2: " << std::get<3>(GetParam());
+
+    SetFieldTrials(kCallerName, std::get<1>(GetParam()));
+    SetFieldTrials(kCalleeName, std::get<2>(GetParam()));
+    SetFieldTrials("Callee2", std::get<3>(GetParam()));
   }
 
  private:
 };
 
+// TODO(webrtc:367395350/jonaso): Add "WebRTC-IceHandshakeDtls/Enabled/"...
+// when it works for this testcase...
+static const char* kTrialsVariants[] = {
+    "",
+    "WebRTC-ForceDtls13/Enabled/",
+};
+
 INSTANTIATE_TEST_SUITE_P(DataChannelIntegrationTestUnifiedPlanFieldTrials,
                          DataChannelIntegrationTestUnifiedPlanFieldTrials,
                          Combine(testing::Bool(),
-                                 Values("", "WebRTC-ForceDtls13/Enabled/")));
+                                 ValuesIn(kTrialsVariants),
+                                 ValuesIn(kTrialsVariants),
+                                 ValuesIn(kTrialsVariants)));
 
 TEST_P(DataChannelIntegrationTestUnifiedPlanFieldTrials, DtlsRestart) {
   RTCConfiguration config;
@@ -1504,11 +1534,7 @@
                                              /*reset_encoder_factory=*/false,
                                              /*reset_decoder_factory=*/false);
 
-  if (std::get<0>(GetParam())) {
-    callee()->SetReceivedSdpMunger(MakeActiveSctpOffer);
-    callee2->SetReceivedSdpMunger(MakeActiveSctpOffer);
-  }
-
+  const bool callee_active = std::get<0>(GetParam());
   ConnectFakeSignaling();
 
   DataChannelInit dc_init;
@@ -1521,11 +1547,21 @@
   std::unique_ptr<SessionDescriptionInterface> offer;
   callee()->SetReceivedSdpMunger(
       [&](std::unique_ptr<SessionDescriptionInterface>& sdp) {
+        if (callee_active) {
+          MakeOfferHavePassiveDtlsRole(sdp);
+        } else {
+          MakeActiveSctpOffer(sdp);
+        }
         offer = sdp->Clone();
       });
   callee()->SetGeneratedSdpMunger(
-      [](std::unique_ptr<SessionDescriptionInterface>& sdp) {
+      [&](std::unique_ptr<SessionDescriptionInterface>& sdp) {
         SetSdpType(sdp, SdpType::kPrAnswer);
+        if (callee_active) {
+          MakeActiveSctpOffer(sdp);
+        } else {
+          MakeOfferHavePassiveDtlsRole(sdp);
+        }
       });
   std::unique_ptr<SessionDescriptionInterface> answer;
   caller()->SetReceivedSdpMunger(
@@ -1545,6 +1581,26 @@
                         Eq(DataChannelInterface::kOpen)),
               IsRtcOk());
 
+  ASSERT_THAT(
+      WaitUntil([&] { return caller()->pc()->peer_connection_state(); },
+                Eq(PeerConnectionInterface::PeerConnectionState::kConnected)),
+      IsRtcOk());
+  ASSERT_THAT(
+      WaitUntil([&] { return callee()->pc()->peer_connection_state(); },
+                Eq(PeerConnectionInterface::PeerConnectionState::kConnected)),
+      IsRtcOk());
+
+  if (callee_active) {
+    ASSERT_THAT(caller()->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kServer));
+    ASSERT_THAT(callee()->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kClient));
+  } else {
+    ASSERT_THAT(caller()->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kClient));
+    ASSERT_THAT(callee()->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kServer));
+  }
   callee2->set_signaling_message_receiver(caller());
 
   std::atomic<int> caller_sent_on_dc(0);
@@ -1612,6 +1668,18 @@
       WaitUntil([&] { return callee2->data_observer()->last_message(); },
                 Eq("KESO")),
       IsRtcOk());
+
+  if (callee_active) {
+    EXPECT_THAT(caller()->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kServer));
+    EXPECT_THAT(callee2->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kClient));
+  } else {
+    EXPECT_THAT(caller()->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kClient));
+    EXPECT_THAT(callee2->dtls_transport_role(),
+                Eq(DtlsTransportTlsRole::kServer));
+  }
 }
 
 #endif  // WEBRTC_HAVE_SCTP
diff --git a/pc/test/integration_test_helpers.h b/pc/test/integration_test_helpers.h
index b0b600b..70b5ae2 100644
--- a/pc/test/integration_test_helpers.h
+++ b/pc/test/integration_test_helpers.h
@@ -690,7 +690,12 @@
 
   bool SetRemoteDescription(std::unique_ptr<SessionDescriptionInterface> desc) {
     auto observer = rtc::make_ref_counted<FakeSetRemoteDescriptionObserver>();
-    RTC_LOG(LS_INFO) << debug_name_ << ": SetRemoteDescription SDP:" << desc;
+    std::string sdp;
+    EXPECT_TRUE(desc->ToString(&sdp));
+    RTC_LOG(LS_INFO) << debug_name_
+                     << ": SetRemoteDescription SDP: type=" << desc->type()
+                     << " contents=\n"
+                     << sdp;
     pc()->SetRemoteDescription(std::move(desc), observer);  // desc.release());
     RemoveUnusedVideoRenderers();
     EXPECT_THAT(
@@ -746,6 +751,12 @@
     });
   }
 
+  std::optional<DtlsTransportTlsRole> dtls_transport_role() {
+    return network_thread_->BlockingCall([&] {
+      return pc()->GetSctpTransport()->dtls_transport()->Information().role();
+    });
+  }
+
   // Setting the local description and sending the SDP message over the fake
   // signaling channel are combined into the same method because the SDP
   // message needs to be sent as soon as SetLocalDescription finishes, without
@@ -758,7 +769,9 @@
     SdpType type = desc->GetType();
     std::string sdp;
     EXPECT_TRUE(desc->ToString(&sdp));
-    RTC_LOG(LS_INFO) << debug_name_ << ": local SDP contents=\n" << sdp;
+    RTC_LOG(LS_INFO) << debug_name_ << ": local SDP type=" << desc->type()
+                     << " contents=\n"
+                     << sdp;
     pc()->SetLocalDescription(observer.get(), desc.release());
     RemoveUnusedVideoRenderers();
     // As mentioned above, we need to send the message immediately after
@@ -1347,6 +1360,9 @@
 // of everything else (including "PeerConnectionFactory"s).
 class PeerConnectionIntegrationBaseTest : public ::testing::Test {
  public:
+  static constexpr char kCallerName[] = "Caller";
+  static constexpr char kCalleeName[] = "Callee";
+
   explicit PeerConnectionIntegrationBaseTest(SdpSemantics sdp_semantics)
       : sdp_semantics_(sdp_semantics),
         ss_(new rtc::VirtualSocketServer()),
@@ -1411,6 +1427,15 @@
     field_trials_ = std::string(field_trials);
   }
 
+  // Sets field trials to pass to created PeerConnectionWrapper key:ed on
+  // debug_name. Must be called before PeerConnectionWrappers are created.
+  void SetFieldTrials(absl::string_view debug_name,
+                      absl::string_view field_trials) {
+    RTC_CHECK(caller_ == nullptr);
+    RTC_CHECK(callee_ == nullptr);
+    field_trials_overrides_[std::string(debug_name)] = field_trials;
+  }
+
   // When `event_log_factory` is null, the default implementation of the event
   // log factory will be used.
   std::unique_ptr<PeerConnectionIntegrationWrapper> CreatePeerConnectionWrapper(
@@ -1434,9 +1459,14 @@
     std::unique_ptr<PeerConnectionIntegrationWrapper> client(
         new PeerConnectionIntegrationWrapper(debug_name));
 
+    std::string field_trials = field_trials_;
+    auto it = field_trials_overrides_.find(debug_name);
+    if (it != field_trials_overrides_.end()) {
+      field_trials = it->second;
+    }
     if (!client->Init(options, &modified_config, std::move(dependencies),
                       fss_.get(), network_thread_.get(), worker_thread_.get(),
-                      FieldTrials::CreateNoGlobal(field_trials_),
+                      FieldTrials::CreateNoGlobal(field_trials),
                       std::move(event_log_factory), reset_encoder_factory,
                       reset_decoder_factory, create_media_engine)) {
       return nullptr;
@@ -1473,13 +1503,13 @@
     // callee PeerConnections.
     SdpSemantics original_semantics = sdp_semantics_;
     sdp_semantics_ = caller_semantics;
-    caller_ = CreatePeerConnectionWrapper("Caller", nullptr, nullptr,
+    caller_ = CreatePeerConnectionWrapper(kCallerName, nullptr, nullptr,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
                                           /*reset_decoder_factory=*/false);
     sdp_semantics_ = callee_semantics;
-    callee_ = CreatePeerConnectionWrapper("Callee", nullptr, nullptr,
+    callee_ = CreatePeerConnectionWrapper(kCalleeName, nullptr, nullptr,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
@@ -1491,12 +1521,12 @@
   bool CreatePeerConnectionWrappersWithConfig(
       const PeerConnectionInterface::RTCConfiguration& caller_config,
       const PeerConnectionInterface::RTCConfiguration& callee_config) {
-    caller_ = CreatePeerConnectionWrapper("Caller", nullptr, &caller_config,
+    caller_ = CreatePeerConnectionWrapper(kCallerName, nullptr, &caller_config,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
                                           /*reset_decoder_factory=*/false);
-    callee_ = CreatePeerConnectionWrapper("Callee", nullptr, &callee_config,
+    callee_ = CreatePeerConnectionWrapper(kCalleeName, nullptr, &callee_config,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
@@ -1510,12 +1540,12 @@
       const PeerConnectionInterface::RTCConfiguration& callee_config,
       PeerConnectionDependencies callee_dependencies) {
     caller_ =
-        CreatePeerConnectionWrapper("Caller", nullptr, &caller_config,
+        CreatePeerConnectionWrapper(kCallerName, nullptr, &caller_config,
                                     std::move(caller_dependencies), nullptr,
                                     /*reset_encoder_factory=*/false,
                                     /*reset_decoder_factory=*/false);
     callee_ =
-        CreatePeerConnectionWrapper("Callee", nullptr, &callee_config,
+        CreatePeerConnectionWrapper(kCalleeName, nullptr, &callee_config,
                                     std::move(callee_dependencies), nullptr,
                                     /*reset_encoder_factory=*/false,
                                     /*reset_decoder_factory=*/false);
@@ -1525,12 +1555,12 @@
   bool CreatePeerConnectionWrappersWithOptions(
       const PeerConnectionFactory::Options& caller_options,
       const PeerConnectionFactory::Options& callee_options) {
-    caller_ = CreatePeerConnectionWrapper("Caller", &caller_options, nullptr,
+    caller_ = CreatePeerConnectionWrapper(kCallerName, &caller_options, nullptr,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
                                           /*reset_decoder_factory=*/false);
-    callee_ = CreatePeerConnectionWrapper("Callee", &callee_options, nullptr,
+    callee_ = CreatePeerConnectionWrapper(kCalleeName, &callee_options, nullptr,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
@@ -1541,10 +1571,10 @@
   bool CreatePeerConnectionWrappersWithFakeRtcEventLog() {
     PeerConnectionInterface::RTCConfiguration default_config;
     caller_ = CreatePeerConnectionWrapperWithFakeRtcEventLog(
-        "Caller", nullptr, &default_config,
+        kCallerName, nullptr, &default_config,
         PeerConnectionDependencies(nullptr));
     callee_ = CreatePeerConnectionWrapperWithFakeRtcEventLog(
-        "Callee", nullptr, &default_config,
+        kCalleeName, nullptr, &default_config,
         PeerConnectionDependencies(nullptr));
     return caller_ && callee_;
   }
@@ -1565,12 +1595,12 @@
 
   bool CreateOneDirectionalPeerConnectionWrappers(bool caller_to_callee) {
     caller_ = CreatePeerConnectionWrapper(
-        "Caller", nullptr, nullptr, PeerConnectionDependencies(nullptr),
+        kCallerName, nullptr, nullptr, PeerConnectionDependencies(nullptr),
         nullptr,
         /*reset_encoder_factory=*/!caller_to_callee,
         /*reset_decoder_factory=*/caller_to_callee);
     callee_ = CreatePeerConnectionWrapper(
-        "Callee", nullptr, nullptr, PeerConnectionDependencies(nullptr),
+        kCalleeName, nullptr, nullptr, PeerConnectionDependencies(nullptr),
         nullptr,
         /*reset_encoder_factory=*/caller_to_callee,
         /*reset_decoder_factory=*/!caller_to_callee);
@@ -1578,13 +1608,13 @@
   }
 
   bool CreatePeerConnectionWrappersWithoutMediaEngine() {
-    caller_ = CreatePeerConnectionWrapper("Caller", nullptr, nullptr,
+    caller_ = CreatePeerConnectionWrapper(kCallerName, nullptr, nullptr,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
                                           /*reset_decoder_factory=*/false,
                                           /*create_media_engine=*/false);
-    callee_ = CreatePeerConnectionWrapper("Callee", nullptr, nullptr,
+    callee_ = CreatePeerConnectionWrapper(kCalleeName, nullptr, nullptr,
                                           PeerConnectionDependencies(nullptr),
                                           nullptr,
                                           /*reset_encoder_factory=*/false,
@@ -1886,6 +1916,7 @@
   std::unique_ptr<PeerConnectionIntegrationWrapper> caller_;
   std::unique_ptr<PeerConnectionIntegrationWrapper> callee_;
   std::string field_trials_;
+  std::map<std::string, std::string> field_trials_overrides_;
 };
 
 }  // namespace webrtc