dtls-in-stun: Only read IceConfig.dtls_handshake_in_stun in 1 place.
This patch fixes 2 problems, that both stem from the same root,
the the code had scattered checking IceConfig.dtls_handshake_in_stun
but that value was not updated when we discovered that remote peer
does not support piggy backing, "restart".
the patch modifies the code so that the IceConfig.dtls_handshake_in_stun is only checked during DtlsTransport::SetupDtls.
The problems fixed are:
1) P2PTransportChannel correctly (set/does not set) Connection::RegisterDtlsPiggyback based on the existing
dtls_stun_piggyback_callbacks_.empty() (that is reset when we detect
that peer does not support piggybacking) rather than on config which
is unchanged.
2) The timeout was not set properly during "restart",
properly == the value based upon ice rtt, but was still using the
"infinitely high" value for piggybacking.
This is tested with the DtlsTransportVersionTest which now runs
(optionally) with dtls-in-stun piggybacking.
BUG=webrtc:367395350
Change-Id: Ib511bcd1d3371a2132cefe26a3c49372208735ac
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/379760
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Auto-Submit: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#44044}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index 1277c71..368f9ba 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -911,6 +911,7 @@
":candidate_pair_interface",
":connection",
":connection_info",
+ ":dtls_stun_piggyback_controller",
":ice_transport_internal",
":port",
":transport_description",
diff --git a/p2p/base/p2p_transport_channel.cc b/p2p/base/p2p_transport_channel.cc
index cdb4359..9ff2512 100644
--- a/p2p/base/p2p_transport_channel.cc
+++ b/p2p/base/p2p_transport_channel.cc
@@ -313,7 +313,7 @@
[this](webrtc::RTCErrorOr<const StunUInt64Attribute*> delta_ack) {
GoogDeltaAckReceived(std::move(delta_ack));
});
- if (config_.dtls_handshake_in_stun) {
+ if (!dtls_stun_piggyback_callbacks_.empty()) {
connection->RegisterDtlsPiggyback(DtlsStunPiggybackCallbacks(
[&](auto request) {
return dtls_stun_piggyback_callbacks_.send_data(request);
@@ -2246,8 +2246,9 @@
}
SignalWritableState(this);
- if (config_.dtls_handshake_in_stun &&
+ if (writable_ && selected_connection_ &&
!dtls_stun_piggyback_callbacks_.empty()) {
+ // TODO(webrtc:367395350): Move this into DtlsTransport somehow.
// Need to STUN ping here to get the last bit of the DTLS handshake across
// as quickly as possible. Only done when DTLS-in-STUN is configured
// and the data callback has not been reset due to lack of support.
diff --git a/p2p/dtls/dtls_ice_integrationtest.cc b/p2p/dtls/dtls_ice_integrationtest.cc
index 15aff16..f9b6666 100644
--- a/p2p/dtls/dtls_ice_integrationtest.cc
+++ b/p2p/dtls/dtls_ice_integrationtest.cc
@@ -108,6 +108,18 @@
false),
client_dtls_stun_piggyback_(std::get<0>(GetParam())),
server_dtls_stun_piggyback_(std::get<1>(GetParam())) {
+ // Enable(or disable) the dtls_in_stun parameter before
+ // DTLS is negotiated.
+ cricket::IceConfig client_config;
+ client_config.continual_gathering_policy = GATHER_CONTINUALLY;
+ client_config.dtls_handshake_in_stun = client_dtls_stun_piggyback_;
+ client_ice_->SetIceConfig(client_config);
+
+ cricket::IceConfig server_config;
+ server_config.dtls_handshake_in_stun = server_dtls_stun_piggyback_;
+ server_config.continual_gathering_policy = GATHER_CONTINUALLY;
+ server_ice_->SetIceConfig(server_config);
+
// Setup ICE.
client_ice_->SetIceParameters(client_ice_parameters_);
client_ice_->SetRemoteIceParameters(server_ice_parameters_);
@@ -141,6 +153,18 @@
~DtlsIceIntegrationTest() = default;
+ static int CountWritableConnections(IceTransportInternal* ice) {
+ IceTransportStats stats;
+ ice->GetStats(&stats);
+ int count = 0;
+ for (const auto& con : stats.connection_infos) {
+ if (con.writable) {
+ count++;
+ }
+ }
+ return count;
+ }
+
rtc::FakeNetworkManager network_manager_;
std::unique_ptr<rtc::VirtualSocketServer> ss_;
std::unique_ptr<rtc::BasicPacketSocketFactory> socket_factory_;
@@ -165,14 +189,6 @@
};
TEST_P(DtlsIceIntegrationTest, SmokeTest) {
- cricket::IceConfig client_config;
- client_config.dtls_handshake_in_stun = client_dtls_stun_piggyback_;
- client_ice_->SetIceConfig(client_config);
-
- cricket::IceConfig server_config;
- server_config.dtls_handshake_in_stun = server_dtls_stun_piggyback_;
- server_ice_->SetIceConfig(server_config);
-
client_ice_->MaybeStartGathering();
server_ice_->MaybeStartGathering();
@@ -188,6 +204,18 @@
client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_);
EXPECT_EQ(server_dtls_.IsDtlsPiggybackSupportedByPeer(),
client_dtls_stun_piggyback_ && server_dtls_stun_piggyback_);
+
+ // Validate that we can add new Connections (that become writable).
+ network_manager_.AddInterface(rtc::SocketAddress("192.168.2.1", 0));
+ EXPECT_THAT(webrtc::WaitUntil(
+ [&] {
+ return CountWritableConnections(client_ice_.get()) > 1 &&
+ CountWritableConnections(server_ice_.get()) > 1;
+ },
+ IsTrue(),
+ {.timeout = webrtc::TimeDelta::Millis(kDefaultTimeout),
+ .clock = &fake_clock_}),
+ webrtc::IsRtcOk());
}
// Test cases are parametrized by
@@ -197,12 +225,14 @@
INSTANTIATE_TEST_SUITE_P(
DtlsStunPiggybackingIntegrationTest,
DtlsIceIntegrationTest,
- ::testing::Values(
- std::make_tuple(false, false, rtc::SSL_PROTOCOL_DTLS_12),
- std::make_tuple(true, false, rtc::SSL_PROTOCOL_DTLS_12),
- std::make_tuple(false, true, rtc::SSL_PROTOCOL_DTLS_12),
- std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_12),
- // Skip negative cases that are behaving similar for DTLS 1.3
- std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_13)));
+ ::testing::Values(std::make_tuple(false, false, rtc::SSL_PROTOCOL_DTLS_12),
+ std::make_tuple(true, false, rtc::SSL_PROTOCOL_DTLS_12),
+ std::make_tuple(false, true, rtc::SSL_PROTOCOL_DTLS_12),
+ std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_12),
+
+ std::make_tuple(false, false, rtc::SSL_PROTOCOL_DTLS_13),
+ std::make_tuple(true, false, rtc::SSL_PROTOCOL_DTLS_13),
+ std::make_tuple(false, true, rtc::SSL_PROTOCOL_DTLS_13),
+ std::make_tuple(true, true, rtc::SSL_PROTOCOL_DTLS_13)));
} // namespace cricket
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.cc b/p2p/dtls/dtls_stun_piggyback_controller.cc
index 602f623..9fedd89 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.cc
+++ b/p2p/dtls/dtls_stun_piggyback_controller.cc
@@ -175,4 +175,17 @@
dtls_data_callback_(data->array_view());
}
+void DtlsStunPiggybackController::SetEnabled(bool enabled) {
+ RTC_DCHECK_RUN_ON(&sequence_checker_);
+ enabled_ = enabled;
+ if (!enabled) {
+ state_ = State::OFF;
+ }
+}
+
+bool DtlsStunPiggybackController::enabled() const {
+ RTC_DCHECK_RUN_ON(&sequence_checker_);
+ return enabled_;
+}
+
} // namespace cricket
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.h b/p2p/dtls/dtls_stun_piggyback_controller.h
index c25d93f..ebf23c6d 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.h
+++ b/p2p/dtls/dtls_stun_piggyback_controller.h
@@ -38,6 +38,11 @@
dtls_data_callback);
~DtlsStunPiggybackController();
+ // Initially set from IceConfig.dtls_handshake_in_stun
+ // but is also set to FALSE before restarting handshake.
+ void SetEnabled(bool enabled);
+ bool enabled() const;
+
enum class State {
// We don't know if peer support DTLS piggybacked in STUN.
// We will piggyback DTLS until we get a piggybacked response
@@ -79,6 +84,7 @@
const StunByteStringAttribute* ack);
private:
+ bool enabled_ RTC_GUARDED_BY(sequence_checker_) = false;
State state_ RTC_GUARDED_BY(sequence_checker_) = State::TENTATIVE;
rtc::Buffer pending_packet_ RTC_GUARDED_BY(sequence_checker_);
absl::AnyInvocable<void(rtc::ArrayView<const uint8_t>)> dtls_data_callback_;
diff --git a/p2p/dtls/dtls_transport.cc b/p2p/dtls/dtls_transport.cc
index 219c91f..002aa83 100644
--- a/p2p/dtls/dtls_transport.cc
+++ b/p2p/dtls/dtls_transport.cc
@@ -109,8 +109,7 @@
// If we try to use DTLS-in-STUN, DTLS packets will be sent as part of STUN
// packets and are consumed here.
- if (ice_transport_->config().dtls_handshake_in_stun &&
- dtls_stun_piggyback_controller_ &&
+ if (dtls_stun_piggyback_controller_ &&
dtls_stun_piggyback_controller_->MaybeConsumePacket(data)) {
written = data.size();
return rtc::SR_SUCCESS;
@@ -385,14 +384,20 @@
return dtls_ ? dtls_->ExportSrtpKeyingMaterial(keying_material) : false;
}
-bool DtlsTransport::SetupDtls() {
+bool DtlsTransport::SetupDtls(bool disable_piggybacking) {
RTC_DCHECK(dtls_role_);
+ // Look at both config...and argument (used on restart).
+ const bool enable_piggybacking =
+ ice_transport_->config().dtls_handshake_in_stun && !disable_piggybacking;
+
{
auto downward = std::make_unique<StreamInterfaceChannel>(ice_transport_);
StreamInterfaceChannel* downward_ptr = downward.get();
- downward_ptr->SetDtlsStunPiggybackController(
- &dtls_stun_piggyback_controller_);
+ if (enable_piggybacking) {
+ downward_ptr->SetDtlsStunPiggybackController(
+ &dtls_stun_piggyback_controller_);
+ }
dtls_ = rtc::SSLStreamAdapter::Create(
std::move(downward),
[this](rtc::SSLHandshakeError error) { OnDtlsHandshakeError(error); },
@@ -404,6 +409,11 @@
downward_ = downward_ptr;
}
+ if (!enable_piggybacking) {
+ ice_transport_->ResetDtlsStunPiggybackCallbacks();
+ }
+ dtls_stun_piggyback_controller_.SetEnabled(enable_piggybacking);
+
dtls_->SetIdentity(local_certificate_->identity()->Clone());
dtls_->SetMaxProtocolVersion(ssl_max_version_);
dtls_->SetServerRole(*dtls_role_);
@@ -617,18 +627,16 @@
// Recreate the DTLS session. Note: this assumes we can consider
// the previous DTLS session state beyond repair and no packet
// reached the peer.
- if (ice_transport_->config().dtls_handshake_in_stun && dtls_ &&
+ if (dtls_stun_piggyback_controller_.enabled() && dtls_ &&
!was_ever_connected_ && !IsDtlsPiggybackSupportedByPeer() &&
(dtls_state() == webrtc::DtlsTransportState::kConnecting ||
dtls_state() == webrtc::DtlsTransportState::kNew)) {
RTC_LOG(LS_INFO) << "DTLS piggybacking not supported, restarting...";
- ice_transport_->ResetDtlsStunPiggybackCallbacks();
- downward_->SetDtlsStunPiggybackController(nullptr);
dtls_.reset(nullptr);
set_dtls_state(webrtc::DtlsTransportState::kNew);
set_writable(false);
- if (!SetupDtls()) {
+ if (!SetupDtls(/* disable_piggybacking= */ true)) {
RTC_LOG(LS_ERROR)
<< "Failed to setup DTLS again after attempted piggybacking.";
set_dtls_state(webrtc::DtlsTransportState::kFailed);
@@ -856,10 +864,9 @@
RTC_DCHECK(ice_transport_);
// When adding the DTLS handshake in STUN we want to call StartSSL even
// before the ICE transport is ready.
- bool start_early_for_dtls_in_stun =
- ice_transport_->config().dtls_handshake_in_stun;
- if (dtls_ && (ice_transport_->writable() || start_early_for_dtls_in_stun)) {
- ConfigureHandshakeTimeout(start_early_for_dtls_in_stun);
+ if (dtls_ && (ice_transport_->writable() ||
+ dtls_stun_piggyback_controller_.enabled())) {
+ ConfigureHandshakeTimeout();
if (dtls_->StartSSL()) {
// This should never fail:
@@ -947,8 +954,9 @@
SendDtlsHandshakeError(error);
}
-void DtlsTransport::ConfigureHandshakeTimeout(bool uses_dtls_in_stun) {
+void DtlsTransport::ConfigureHandshakeTimeout() {
RTC_DCHECK(dtls_);
+ bool uses_dtls_in_stun = dtls_stun_piggyback_controller_.enabled();
std::optional<int> rtt_ms = ice_transport_->GetRttEstimate();
if (uses_dtls_in_stun) {
// Configure a very high timeout to effectively disable the DTLS timeout
@@ -984,19 +992,17 @@
bool DtlsTransport::IsDtlsPiggybackSupportedByPeer() {
RTC_DCHECK_RUN_ON(&thread_checker_);
RTC_DCHECK(ice_transport_);
- return ice_transport_->config().dtls_handshake_in_stun &&
- dtls_stun_piggyback_controller_.state() !=
- DtlsStunPiggybackController::State::OFF;
+ return dtls_stun_piggyback_controller_.state() !=
+ DtlsStunPiggybackController::State::OFF;
}
bool DtlsTransport::IsDtlsPiggybackHandshaking() {
RTC_DCHECK_RUN_ON(&thread_checker_);
RTC_DCHECK(ice_transport_);
- return ice_transport_->config().dtls_handshake_in_stun &&
- (dtls_stun_piggyback_controller_.state() ==
- DtlsStunPiggybackController::State::TENTATIVE ||
- dtls_stun_piggyback_controller_.state() ==
- DtlsStunPiggybackController::State::CONFIRMED);
+ return dtls_stun_piggyback_controller_.state() ==
+ DtlsStunPiggybackController::State::TENTATIVE ||
+ dtls_stun_piggyback_controller_.state() ==
+ DtlsStunPiggybackController::State::CONFIRMED;
}
} // namespace cricket
diff --git a/p2p/dtls/dtls_transport.h b/p2p/dtls/dtls_transport.h
index b0953b5..9c7b0f9 100644
--- a/p2p/dtls/dtls_transport.h
+++ b/p2p/dtls/dtls_transport.h
@@ -78,8 +78,8 @@
private:
IceTransportInternal* const ice_transport_; // owned by DtlsTransport
- DtlsStunPiggybackController*
- dtls_stun_piggyback_controller_; // owned by DtlsTransport
+ DtlsStunPiggybackController* dtls_stun_piggyback_controller_ =
+ nullptr; // owned by DtlsTransport
rtc::StreamState state_ RTC_GUARDED_BY(callback_sequence_);
rtc::BufferQueue packets_ RTC_GUARDED_BY(callback_sequence_);
};
@@ -228,9 +228,6 @@
return sb.Release();
}
- void SetPiggybackDtlsDataCallback(
- absl::AnyInvocable<void(rtc::PacketTransportInternal* transport,
- const rtc::ReceivedPacket& packet)> callback);
bool IsDtlsPiggybackSupportedByPeer();
private:
@@ -245,16 +242,19 @@
void OnReceivingState(rtc::PacketTransportInternal* transport);
void OnDtlsEvent(int sig, int err);
void OnNetworkRouteChanged(std::optional<rtc::NetworkRoute> network_route);
- bool SetupDtls();
+ bool SetupDtls(bool disable_piggybacking = false);
void MaybeStartDtls();
bool HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload);
void OnDtlsHandshakeError(rtc::SSLHandshakeError error);
- void ConfigureHandshakeTimeout(bool uses_dtls_in_stun);
+ void ConfigureHandshakeTimeout();
void set_receiving(bool receiving);
void set_writable(bool writable);
// Sets the DTLS state, signaling if necessary.
void set_dtls_state(webrtc::DtlsTransportState state);
+ void SetPiggybackDtlsDataCallback(
+ absl::AnyInvocable<void(rtc::PacketTransportInternal* transport,
+ const rtc::ReceivedPacket& packet)> callback);
RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_;
diff --git a/p2p/dtls/dtls_transport_unittest.cc b/p2p/dtls/dtls_transport_unittest.cc
index 6740648..2485d5d 100644
--- a/p2p/dtls/dtls_transport_unittest.cc
+++ b/p2p/dtls/dtls_transport_unittest.cc
@@ -18,6 +18,7 @@
#include <optional>
#include <set>
#include <string>
+#include <tuple>
#include <utility>
#include <vector>
@@ -52,10 +53,9 @@
#include "test/gtest.h"
#include "test/wait_until.h"
-#define MAYBE_SKIP_TEST(feature) \
- if (!(rtc::SSLStreamAdapter::feature())) { \
- RTC_LOG(LS_INFO) << #feature " feature disabled... skipping"; \
- return; \
+#define MAYBE_SKIP_TEST(feature) \
+ if (!(rtc::SSLStreamAdapter::feature())) { \
+ GTEST_SKIP() << #feature " feature disabled... skipping"; \
}
namespace cricket {
@@ -155,6 +155,16 @@
return true;
}
+ // Connect the fake ICE transports so that packets flows from one to other.
+ bool ConnectIceTransport(DtlsTestClient* peer) {
+ fake_ice_transport()->SetDestinationNotWritable(peer->fake_ice_transport());
+ return true;
+ }
+
+ bool SendIcePing() { return fake_ice_transport_->SendIcePing(); }
+
+ bool SendIcePingConf() { return fake_ice_transport_->SendIcePingConf(); }
+
int received_dtls_client_hellos() const {
return received_dtls_client_hellos_;
}
@@ -576,15 +586,41 @@
{rtc::kDtls13VersionBytes, dtls_13_handshake_events},
};
+struct EndpointConfig {
+ rtc::SSLProtocolVersion max_protocol_version;
+ bool dtls_in_stun = false;
+};
+
class DtlsTransportVersionTest
: public DtlsTransportTestBase,
public ::testing::TestWithParam<
- ::testing::tuple<rtc::SSLProtocolVersion, rtc::SSLProtocolVersion>> {
+ std::tuple<EndpointConfig, EndpointConfig>> {
public:
void Prepare() {
PrepareDtls(rtc::KT_DEFAULT);
- SetMaxProtocolVersions(::testing::get<0>(GetParam()),
- ::testing::get<1>(GetParam()));
+ SetMaxProtocolVersions(std::get<0>(GetParam()).max_protocol_version,
+ std::get<1>(GetParam()).max_protocol_version);
+
+ client1_.SetupTransports(ICEROLE_CONTROLLING);
+ client2_.SetupTransports(ICEROLE_CONTROLLED);
+ client1_.dtls_transport()->SetDtlsRole(rtc::SSL_CLIENT);
+ client2_.dtls_transport()->SetDtlsRole(rtc::SSL_SERVER);
+
+ if (std::get<0>(GetParam()).dtls_in_stun) {
+ auto config = client1_.fake_ice_transport()->config();
+ config.dtls_handshake_in_stun = true;
+ client1_.fake_ice_transport()->SetIceConfig(config);
+ }
+ if (std::get<1>(GetParam()).dtls_in_stun) {
+ auto config = client2_.fake_ice_transport()->config();
+ config.dtls_handshake_in_stun = true;
+ client2_.fake_ice_transport()->SetIceConfig(config);
+ }
+
+ SetRemoteFingerprintFromCert(client1_.dtls_transport(),
+ client2_.certificate());
+ SetRemoteFingerprintFromCert(client2_.dtls_transport(),
+ client1_.certificate());
}
// Run DTLS handshake.
@@ -592,8 +628,6 @@
// - drop packets as specified in `packets_to_drop`
std::pair</* dtls_version_bytes*/ int, std::vector<HandshakeTestEvent>>
RunHandshake(std::set<unsigned> packets_to_drop) {
- Negotiate(/* client1_server= */ false);
-
std::vector<HandshakeTestEvent> events;
auto start_time_ns = fake_clock_.TimeNanos();
client1_.fake_ice_transport()->set_rtt_estimate(50, true);
@@ -642,7 +676,11 @@
return LogSend("server", diff_ms, drop, data, len);
});
- EXPECT_TRUE(client1_.Connect(&client2_, false));
+ EXPECT_TRUE(client1_.ConnectIceTransport(&client2_));
+ client1_.SendIcePing();
+ client2_.SendIcePingConf();
+ client2_.SendIcePing();
+ client1_.SendIcePingConf();
EXPECT_THAT(webrtc::WaitUntil(
[&] {
@@ -661,12 +699,13 @@
auto dtls_version_bytes = client1_.GetVersionBytes();
EXPECT_EQ(dtls_version_bytes, client2_.GetVersionBytes());
- return std::make_pair(*dtls_version_bytes, std::move(events));
+ return std::make_pair(dtls_version_bytes.value_or(0), std::move(events));
}
int GetExpectedDtlsVersionBytes() {
- int version = std::min(static_cast<int>(::testing::get<0>(GetParam())),
- static_cast<int>(::testing::get<1>(GetParam())));
+ int version = std::min(
+ static_cast<int>(std::get<0>(GetParam()).max_protocol_version),
+ static_cast<int>(std::get<1>(GetParam()).max_protocol_version));
if (version == rtc::SSL_PROTOCOL_DTLS_13) {
return rtc::kDtls13VersionBytes;
} else {
@@ -711,17 +750,40 @@
}
};
+static const EndpointConfig kEndpointVariants[] = {
+ {
+ .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_10,
+ .dtls_in_stun = false,
+ },
+ {
+ .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_12,
+ .dtls_in_stun = false,
+ },
+ {
+ .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_13,
+ .dtls_in_stun = false,
+ },
+ {
+ .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_10,
+ .dtls_in_stun = true,
+ },
+ {
+ .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_12,
+ .dtls_in_stun = true,
+ },
+ {
+ .max_protocol_version = rtc::SSL_PROTOCOL_DTLS_13,
+ .dtls_in_stun = true,
+ },
+};
+
// Will test every combination of 1.0/1.2/1.3 on the client and server.
// DTLS will negotiate an effective version (the min of client & sewrver).
INSTANTIATE_TEST_SUITE_P(
DtlsTransportVersionTest,
DtlsTransportVersionTest,
- ::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
- rtc::SSL_PROTOCOL_DTLS_12,
- rtc::SSL_PROTOCOL_DTLS_13),
- ::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
- rtc::SSL_PROTOCOL_DTLS_12,
- rtc::SSL_PROTOCOL_DTLS_13)));
+ ::testing::Combine(testing::ValuesIn(kEndpointVariants),
+ testing::ValuesIn(kEndpointVariants)));
// Test that an acceptable cipher suite is negotiated when different versions
// of DTLS are supported. Note that it's IsAcceptableCipher that does the actual
@@ -732,6 +794,10 @@
}
TEST_P(DtlsTransportVersionTest, HandshakeFlights) {
+ if (std::get<0>(GetParam()).dtls_in_stun &&
+ std::get<1>(GetParam()).dtls_in_stun) {
+ GTEST_SKIP() << "This test does not support dtls in stun";
+ }
Prepare();
auto [dtls_version_bytes, events] = RunHandshake({});
@@ -743,6 +809,10 @@
TEST_P(DtlsTransportVersionTest, HandshakeLoseFirstClientPacket) {
MAYBE_SKIP_TEST(IsBoringSsl);
+ if (std::get<0>(GetParam()).dtls_in_stun &&
+ std::get<1>(GetParam()).dtls_in_stun) {
+ GTEST_SKIP() << "This test does not support dtls in stun";
+ }
Prepare();
auto [dtls_version_bytes, events] = RunHandshake({/* packet_num= */ 0});
@@ -758,6 +828,10 @@
TEST_P(DtlsTransportVersionTest, HandshakeLoseSecondClientPacket) {
MAYBE_SKIP_TEST(IsBoringSsl);
+ if (std::get<0>(GetParam()).dtls_in_stun &&
+ std::get<1>(GetParam()).dtls_in_stun) {
+ GTEST_SKIP() << "This test does not support dtls in stun";
+ }
Prepare();
auto [dtls_version_bytes, events] = RunHandshake({/* packet_num= */ 2});
@@ -817,7 +891,7 @@
};
break;
default:
- RTC_CHECK(false) << "Unknown dtls version bytes: " << dtls_version_bytes;
+ FAIL() << "Unknown dtls version bytes: " << dtls_version_bytes;
}
EXPECT_EQ(events, expect);
}
diff --git a/p2p/test/fake_ice_transport.h b/p2p/test/fake_ice_transport.h
index 9b9df19..d365ddc 100644
--- a/p2p/test/fake_ice_transport.h
+++ b/p2p/test/fake_ice_transport.h
@@ -32,6 +32,7 @@
#include "p2p/base/ice_transport_internal.h"
#include "p2p/base/port.h"
#include "p2p/base/transport_description.h"
+#include "p2p/dtls/dtls_stun_piggyback_callbacks.h"
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/checks.h"
#include "rtc_base/copy_on_write_buffer.h"
@@ -118,6 +119,24 @@
}
}
+ void SetDestinationNotWritable(FakeIceTransport* dest) {
+ RTC_DCHECK_RUN_ON(network_thread_);
+ if (dest == dest_) {
+ return;
+ }
+ RTC_DCHECK(!dest || !dest_)
+ << "Changing fake destination from one to another is not supported.";
+
+ if (dest) {
+ RTC_DCHECK_RUN_ON(dest->network_thread_);
+ dest->dest_ = this;
+ } else if (dest_) {
+ RTC_DCHECK_RUN_ON(dest_->network_thread_);
+ dest_->dest_ = nullptr;
+ }
+ dest_ = dest;
+ }
+
void SetTransportState(webrtc::IceTransportState state,
IceTransportState legacy_state) {
RTC_DCHECK_RUN_ON(network_thread_);
@@ -415,6 +434,61 @@
}
}
+ void ResetDtlsStunPiggybackCallbacks() override {
+ dtls_stun_piggyback_callbacks_.reset();
+ }
+ void SetDtlsStunPiggybackCallbacks(
+ DtlsStunPiggybackCallbacks&& callbacks) override {
+ RTC_LOG(LS_INFO) << name_ << ": SetDtlsStunPiggybackCallbacks";
+ dtls_stun_piggyback_callbacks_ = std::move(callbacks);
+ }
+
+ bool SendIcePing() {
+ RTC_DCHECK_RUN_ON(network_thread_);
+ RTC_DLOG(LS_INFO) << name_ << ": SendIcePing()";
+ auto msg = std::make_unique<IceMessage>(STUN_BINDING_REQUEST);
+ MaybeAddDtlsPiggybackingAttributes(msg.get());
+ msg->AddFingerprint();
+ rtc::ByteBufferWriter buf;
+ msg->Write(&buf);
+ SendPacketInternal(rtc::CopyOnWriteBuffer(buf.DataView()));
+ return true;
+ }
+
+ void MaybeAddDtlsPiggybackingAttributes(StunMessage* msg) {
+ if (dtls_stun_piggyback_callbacks_.empty()) {
+ return;
+ }
+
+ const auto& [attr, ack] = dtls_stun_piggyback_callbacks_.send_data(
+ static_cast<StunMessageType>(msg->type()));
+
+ RTC_DLOG(LS_INFO) << name_ << ": Adding attr: " << attr.has_value()
+ << " ack: " << ack.has_value() << " to stun message: "
+ << StunMethodToString(msg->type());
+
+ if (attr) {
+ msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
+ STUN_ATTR_META_DTLS_IN_STUN, *attr));
+ }
+ if (ack) {
+ msg->AddAttribute(std::make_unique<StunByteStringAttribute>(
+ STUN_ATTR_META_DTLS_IN_STUN_ACK, *ack));
+ }
+ }
+
+ bool SendIcePingConf() {
+ RTC_DCHECK_RUN_ON(network_thread_);
+ RTC_DLOG(LS_INFO) << name_ << ": SendIcePingConf()";
+ auto msg = std::make_unique<IceMessage>(STUN_BINDING_RESPONSE);
+ MaybeAddDtlsPiggybackingAttributes(msg.get());
+ msg->AddFingerprint();
+ rtc::ByteBufferWriter buf;
+ msg->Write(&buf);
+ SendPacketInternal(rtc::CopyOnWriteBuffer(buf.DataView()));
+ return true;
+ }
+
private:
void set_writable(bool writable)
RTC_EXCLUSIVE_LOCKS_REQUIRED(network_thread_) {
@@ -449,6 +523,25 @@
void ReceivePacketInternal(const rtc::CopyOnWriteBuffer& packet) {
RTC_DCHECK_RUN_ON(network_thread_);
auto now = rtc::TimeMicros();
+ if (auto msg = GetStunMessage(packet)) {
+ const auto* dtls_piggyback_attr =
+ msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN);
+ const auto* dtls_piggyback_ack =
+ msg->GetByteString(STUN_ATTR_META_DTLS_IN_STUN_ACK);
+ RTC_DLOG(LS_INFO) << name_ << ": Got STUN message: "
+ << StunMethodToString(msg->type())
+ << " attr: " << (dtls_piggyback_attr != nullptr)
+ << " ack: " << (dtls_piggyback_ack != nullptr);
+ if (!dtls_stun_piggyback_callbacks_.empty()) {
+ dtls_stun_piggyback_callbacks_.recv_data(dtls_piggyback_attr,
+ dtls_piggyback_ack);
+ }
+
+ if (msg->type() == STUN_BINDING_RESPONSE) {
+ set_writable(true);
+ }
+ return;
+ }
if (packet_recv_filter_func_ && packet_recv_filter_func_(packet, now)) {
RTC_DLOG(LS_INFO) << name_
<< ": dropping packet at receiver len=" << packet.size()
@@ -460,6 +553,18 @@
}
}
+ std::unique_ptr<IceMessage> GetStunMessage(
+ const rtc::CopyOnWriteBuffer& packet) {
+ if (!StunMessage::ValidateFingerprint(packet.data<char>(), packet.size())) {
+ return nullptr;
+ }
+
+ std::unique_ptr<IceMessage> stun_msg(new IceMessage());
+ rtc::ByteBufferReader buf(rtc::MakeArrayView(packet.data(), packet.size()));
+ RTC_CHECK(stun_msg->Read(&buf));
+ return stun_msg;
+ }
+
const std::string name_;
const int component_;
FakeIceTransport* dest_ RTC_GUARDED_BY(network_thread_) = nullptr;
@@ -497,6 +602,7 @@
packet_send_filter_func_ RTC_GUARDED_BY(network_thread_) = nullptr;
absl::AnyInvocable<bool(const rtc::CopyOnWriteBuffer&, uint64_t)>
packet_recv_filter_func_ RTC_GUARDED_BY(network_thread_) = nullptr;
+ DtlsStunPiggybackCallbacks dtls_stun_piggyback_callbacks_;
};
class FakeIceTransportWrapper : public webrtc::IceTransportInterface {