dtls-in-stun: Prepare for multi packet handshakes
1) Change format of ACK attribute to be hash of packet (crc32).
this is good since we anyway can't retransmit parts of DTLS packet
and it also works for the encrypted handshake packages.
This change is "somewhat upgrade" compatible,
version prior to this does not actually check/use the content of
the ack attribute, only the presence.
This version will automatically clear pending packets
when a new flight begins => It works, but retransmit of 1
packet in a multi placket flight does not unless both
ens has this version....But multi packet handshake flight
are not yet "in play"
2) Keep track of individually acked packets.
Clear array of pending when getting an ACK. This is "already today"
and improvment of existing code.
3) Limit size of ACK to 4 packets,
this limit *should* never be reached...but there is a test for it :)
4) Readd restrictions removed in https://webrtc-review.googlesource.com/c/src/+/381102. It only worked "by accident",
and the use case is of zero value. Even hard to explain!
Maybe I'll take a stab after PQC is done.
---
With this, PQC handshake _almost_ work with dtls-in-stun.
BUG=webrtc:367395350
Change-Id: I582397babb099896ba56ef26e174f47f7e98d1d9
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/385401
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#44340}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index cf0d539..63331cf 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -707,6 +707,7 @@
"../api:array_view",
"../rtc_base:byte_buffer",
"../rtc_base:checks",
+ "../rtc_base:crc32",
]
}
@@ -730,7 +731,10 @@
"../rtc_base:macromagic",
"../rtc_base:stringutils",
"../rtc_base/system:no_unique_address",
+ "//third_party/abseil-cpp/absl/container:flat_hash_set",
"//third_party/abseil-cpp/absl/functional:any_invocable",
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/strings:str_format",
"//third_party/abseil-cpp/absl/strings:string_view",
]
}
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.cc b/p2p/dtls/dtls_stun_piggyback_controller.cc
index 7ca8094..138a12e 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.cc
+++ b/p2p/dtls/dtls_stun_piggyback_controller.cc
@@ -10,20 +10,25 @@
#include "p2p/dtls/dtls_stun_piggyback_controller.h"
+#include <algorithm>
#include <cstdint>
+#include <memory>
#include <optional>
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/functional/any_invocable.h"
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/sequence_checker.h"
#include "api/transport/stun.h"
#include "p2p/dtls/dtls_utils.h"
+#include "rtc_base/buffer.h"
+#include "rtc_base/byte_buffer.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
-#include "rtc_base/string_encode.h"
+#include "rtc_base/strings/str_join.h"
namespace webrtc {
@@ -44,7 +49,7 @@
// the last flight from the server.
// For DTLS 1.3 this is reversed since the handshake has one round trip less.
if ((is_dtls_client && !is_dtls13) || (!is_dtls_client && is_dtls13)) {
- pending_packet_.Clear();
+ pending_packets_.clear();
}
// Peer does not support this so fallback to a normal DTLS handshake
@@ -66,18 +71,26 @@
// is made for 1-packet at a time. Use the writing_packets_ variable to keep
// track of a full batch. The writing_packets_ is reset in Flush.
if (!writing_packets_) {
- pending_packet_.Clear();
+ pending_packets_.clear();
writing_packets_ = true;
}
- // Note: this overwrites the existing packets which is an issue
- // if this gets called with fragmented DTLS flights.
- pending_packet_.SetData(data);
+ // BoringSSL writes burst of packets...but the interface
+ // is made for 1-packet at a time. Use the writing_packets_ variable to keep
+ // track of a full batch. The writing_packets_ is reset in
+ // GetDataToPiggyback().
+ if (!writing_packets_) {
+ pending_packets_.clear();
+ writing_packets_ = true;
+ }
+ pending_packets_.push_back(std::make_pair(
+ ComputeDtlsPacketHash(data),
+ std::make_unique<webrtc::Buffer>(data.data(), data.size())));
}
void DtlsStunPiggybackController::ClearCachedPacketForTesting() {
RTC_DCHECK_RUN_ON(&sequence_checker_);
- pending_packet_.Clear();
+ pending_packets_.clear();
}
void DtlsStunPiggybackController::Flush() {
@@ -108,15 +121,17 @@
return std::nullopt;
}
- if (pending_packet_.size() == 0) {
+ if (pending_packets_.size() == 0) {
return std::nullopt;
}
- return absl::string_view(pending_packet_);
+
+ return absl::string_view(*pending_packets_.back().second.get());
}
std::optional<absl::string_view> DtlsStunPiggybackController::GetAckToPiggyback(
StunMessageType stun_message_type) {
RTC_DCHECK_RUN_ON(&sequence_checker_);
+
if (state_ == State::OFF || state_ == State::COMPLETE) {
return std::nullopt;
}
@@ -150,7 +165,7 @@
if (state_ == State::PENDING && data == nullptr && ack == nullptr) {
RTC_LOG(LS_INFO) << "DTLS-STUN piggybacking complete.";
state_ = State::COMPLETE;
- pending_packet_.Clear();
+ pending_packets_.clear();
handshake_ack_writer_.Clear();
handshake_messages_received_.clear();
return;
@@ -161,10 +176,32 @@
state_ = State::CONFIRMED;
}
- if (ack != nullptr && !ack->string_view().empty()) {
- RTC_LOG(LS_VERBOSE) << "DTLS-STUN piggybacking ACK: "
- << webrtc::hex_encode(ack->string_view());
+ if (ack != nullptr) {
+ if (!pending_packets_.empty()) {
+ // Unpack the ACK attribute (a list of uint32_t)
+ absl::flat_hash_set<uint32_t> acked_packets;
+ {
+ webrtc::ByteBufferReader ack_reader(ack->array_view());
+ uint32_t packet_hash;
+ while (ack_reader.ReadUInt32(&packet_hash)) {
+ acked_packets.insert(packet_hash);
+ }
+ }
+ RTC_LOG(LS_VERBOSE) << "DTLS-STUN piggybacking ACK: "
+ << webrtc::StrJoin(acked_packets, ",");
+
+ // Remove all acked packets from pending_packets_.
+ if (!acked_packets.empty()) {
+ pending_packets_.erase(
+ std::remove_if(pending_packets_.begin(), pending_packets_.end(),
+ [&](const auto& val) {
+ return acked_packets.contains(val.first);
+ }),
+ pending_packets_.end());
+ }
+ }
}
+
// The response to the final flight of the handshake will not contain
// the DTLS data but will contain an ack.
// Must not happen on the initial server to client packet which
@@ -172,31 +209,37 @@
if (data == nullptr && ack != nullptr && state_ == State::PENDING) {
RTC_LOG(LS_INFO) << "DTLS-STUN piggybacking complete.";
state_ = State::COMPLETE;
- pending_packet_.Clear();
+ pending_packets_.clear();
handshake_ack_writer_.Clear();
handshake_messages_received_.clear();
return;
}
+
if (!data || data->length() == 0) {
return;
}
- // Extract the received message sequence numbers of the handshake
+ // Extract the received message id of the handshake
// from the packet and prepare the ack to be sent.
- std::optional<std::vector<uint16_t>> new_message_sequences =
- webrtc::GetDtlsHandshakeAcks(data->array_view());
- if (!new_message_sequences) {
- RTC_LOG(LS_ERROR) << "DTLS-STUN piggybacking failed to parse DTLS packet.";
- return;
- }
- if (!new_message_sequences->empty()) {
- for (const auto& message_seq : *new_message_sequences) {
- handshake_messages_received_.insert(message_seq);
+ uint32_t hash = ComputeDtlsPacketHash(data->array_view());
+
+ // Check if we already received this packet.
+ if (std::find(handshake_messages_received_.begin(),
+ handshake_messages_received_.end(),
+ hash) == handshake_messages_received_.end()) {
+ handshake_messages_received_.push_back(hash);
+ handshake_ack_writer_.WriteUInt32(hash);
+
+ if (handshake_ack_writer_.Length() > kMaxAckSize) {
+ // If needed, limit size of ack attribute...by removing oldest ack.
+ handshake_messages_received_.erase(handshake_messages_received_.begin());
+ handshake_ack_writer_.Clear();
+ for (const auto& val : handshake_messages_received_) {
+ handshake_ack_writer_.WriteUInt32(val);
+ }
}
- handshake_ack_writer_.Clear();
- for (const auto& message_seq : handshake_messages_received_) {
- handshake_ack_writer_.WriteUInt16(message_seq);
- }
+
+ RTC_DCHECK(handshake_ack_writer_.Length() <= kMaxAckSize);
}
dtls_data_callback_(data->array_view());
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.h b/p2p/dtls/dtls_stun_piggyback_controller.h
index fd500e4..b44a964 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.h
+++ b/p2p/dtls/dtls_stun_piggyback_controller.h
@@ -12,8 +12,10 @@
#define P2P_DTLS_DTLS_STUN_PIGGYBACK_CONTROLLER_H_
#include <cstdint>
+#include <memory>
#include <optional>
-#include <set>
+#include <utility>
+#include <vector>
#include "absl/functional/any_invocable.h"
#include "absl/strings/string_view.h"
@@ -31,6 +33,9 @@
// as the constructor.
class DtlsStunPiggybackController {
public:
+ // Never ack more than 4 packets.
+ static constexpr unsigned kMaxAckSize = 16;
+
// dtls_data_callback will be called with any DTLS packets received
// piggybacked.
DtlsStunPiggybackController(
@@ -86,11 +91,12 @@
private:
State state_ RTC_GUARDED_BY(sequence_checker_) = State::TENTATIVE;
bool writing_packets_ RTC_GUARDED_BY(sequence_checker_) = false;
- Buffer pending_packet_ RTC_GUARDED_BY(sequence_checker_);
+ std::vector<std::pair<uint32_t, std::unique_ptr<rtc::Buffer>>>
+ pending_packets_ RTC_GUARDED_BY(sequence_checker_);
absl::AnyInvocable<void(rtc::ArrayView<const uint8_t>)> dtls_data_callback_;
absl::AnyInvocable<void()> disable_piggybacking_callback_;
- std::set<uint16_t> handshake_messages_received_
+ std::vector<uint32_t> handshake_messages_received_
RTC_GUARDED_BY(sequence_checker_);
ByteBufferWriter handshake_ack_writer_ RTC_GUARDED_BY(sequence_checker_);
diff --git a/p2p/dtls/dtls_stun_piggyback_controller_unittest.cc b/p2p/dtls/dtls_stun_piggyback_controller_unittest.cc
index dda5c81..efbe980 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller_unittest.cc
+++ b/p2p/dtls/dtls_stun_piggyback_controller_unittest.cc
@@ -16,8 +16,11 @@
#include <string>
#include <vector>
+#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/transport/stun.h"
+#include "p2p/dtls/dtls_utils.h"
+#include "rtc_base/byte_buffer.h"
#include "test/gmock.h"
#include "test/gtest.h"
@@ -52,6 +55,22 @@
0x00, 0x00, 0x00, 0x00, 0x00};
const std::vector<uint8_t> empty = {};
+
+std::string AsAckAttribute(const std::vector<uint32_t>& list) {
+ webrtc::ByteBufferWriter writer;
+ for (const auto& val : list) {
+ writer.WriteUInt32(val);
+ }
+ return std::string(writer.DataAsStringView());
+}
+
+std::vector<uint8_t> FakeDtlsPacket(uint16_t packet_number) {
+ auto packet = dtls_flight1;
+ packet[17] = static_cast<uint8_t>(packet_number >> 8);
+ packet[18] = static_cast<uint8_t>(packet_number & 255);
+ return packet;
+}
+
} // namespace
namespace webrtc {
@@ -74,14 +93,12 @@
client_.ClearCachedPacketForTesting();
}
std::unique_ptr<StunByteStringAttribute> attr_data;
- if (client_.GetDataToPiggyback(type)) {
- attr_data = std::make_unique<StunByteStringAttribute>(
- STUN_ATTR_META_DTLS_IN_STUN, *client_.GetDataToPiggyback(type));
+ if (auto data = client_.GetDataToPiggyback(type)) {
+ attr_data = WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, *data);
}
std::unique_ptr<StunByteStringAttribute> attr_ack;
- if (client_.GetAckToPiggyback(type)) {
- attr_ack = std::make_unique<StunByteStringAttribute>(
- STUN_ATTR_META_DTLS_IN_STUN_ACK, *client_.GetAckToPiggyback(type));
+ if (auto data = client_.GetAckToPiggyback(type)) {
+ attr_ack = WrapInStun(STUN_ATTR_META_DTLS_IN_STUN_ACK, *data);
}
server_.ReportDataPiggybacked(attr_data.get(), attr_ack.get());
}
@@ -94,14 +111,12 @@
server_.ClearCachedPacketForTesting();
}
std::unique_ptr<StunByteStringAttribute> attr_data;
- if (server_.GetDataToPiggyback(type)) {
- attr_data = std::make_unique<StunByteStringAttribute>(
- STUN_ATTR_META_DTLS_IN_STUN, *server_.GetDataToPiggyback(type));
+ if (auto data = server_.GetDataToPiggyback(type)) {
+ attr_data = WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, *data);
}
std::unique_ptr<StunByteStringAttribute> attr_ack;
- if (server_.GetAckToPiggyback(type)) {
- attr_ack = std::make_unique<StunByteStringAttribute>(
- STUN_ATTR_META_DTLS_IN_STUN_ACK, *server_.GetAckToPiggyback(type));
+ if (auto data = server_.GetAckToPiggyback(type)) {
+ attr_ack = WrapInStun(STUN_ATTR_META_DTLS_IN_STUN_ACK, *data);
}
client_.ReportDataPiggybacked(attr_data.get(), attr_ack.get());
if (data == dtls_flight4) {
@@ -113,6 +128,19 @@
}
}
+ std::unique_ptr<StunByteStringAttribute> WrapInStun(
+ cricket::IceAttributeType type,
+ absl::string_view data) {
+ return std::make_unique<StunByteStringAttribute>(type, data);
+ }
+
+ std::unique_ptr<StunByteStringAttribute> WrapInStun(
+ cricket::IceAttributeType type,
+ const std::vector<uint8_t>& data) {
+ return std::make_unique<StunByteStringAttribute>(type, data.data(),
+ data.size());
+ }
+
void DisableSupport(DtlsStunPiggybackController& client_or_server) {
ASSERT_EQ(client_or_server.state(), State::TENTATIVE);
client_or_server.ReportDataPiggybacked(nullptr, nullptr);
@@ -260,17 +288,23 @@
SendClientToServer(dtls_flight1, STUN_BINDING_REQUEST);
SendServerToClient(dtls_flight2, STUN_BINDING_RESPONSE);
EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST),
- std::string("\x12\x34", 2));
+ AsAckAttribute({ComputeDtlsPacketHash(dtls_flight1)}));
EXPECT_EQ(client_.GetAckToPiggyback(STUN_BINDING_RESPONSE),
- std::string("\x43\x21", 2));
+ AsAckAttribute({ComputeDtlsPacketHash(dtls_flight2)}));
// Flight 3+4
SendClientToServer(dtls_flight3, STUN_BINDING_REQUEST);
SendServerToClient(dtls_flight4, STUN_BINDING_RESPONSE);
EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_RESPONSE),
- std::string("\x12\x34\x44\x44", 4));
+ AsAckAttribute({
+ ComputeDtlsPacketHash(dtls_flight1),
+ ComputeDtlsPacketHash(dtls_flight3),
+ }));
EXPECT_EQ(client_.GetAckToPiggyback(STUN_BINDING_REQUEST),
- std::string("\x43\x21\x54\x86", 4));
+ AsAckAttribute({
+ ComputeDtlsPacketHash(dtls_flight2),
+ ComputeDtlsPacketHash(dtls_flight4),
+ }));
// Post-handshake ACK
SendServerToClient(empty, STUN_BINDING_REQUEST);
@@ -285,15 +319,63 @@
// Flight 1+2
SendClientToServer(dtls_flight1, STUN_BINDING_REQUEST);
EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST),
- std::string("\x12\x34", 2));
+ AsAckAttribute({ComputeDtlsPacketHash(dtls_flight1)}));
SendClientToServer(dtls_flight3, STUN_BINDING_REQUEST);
EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST),
- std::string("\x12\x34\x44\x44", 4));
+ AsAckAttribute({
+ ComputeDtlsPacketHash(dtls_flight1),
+ ComputeDtlsPacketHash(dtls_flight3),
+ }));
// Receive Flight 1 again, no change expected.
SendClientToServer(dtls_flight1, STUN_BINDING_REQUEST);
EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST),
- std::string("\x12\x34\x44\x44", 4));
+ AsAckAttribute({
+ ComputeDtlsPacketHash(dtls_flight1),
+ ComputeDtlsPacketHash(dtls_flight3),
+ }));
+}
+
+TEST_F(DtlsStunPiggybackControllerTest, DontSendAckedPackets) {
+ server_.CapturePacket(dtls_flight1);
+ server_.Flush();
+ EXPECT_TRUE(server_.GetDataToPiggyback(STUN_BINDING_REQUEST).has_value());
+ server_.ReportDataPiggybacked(
+ nullptr, WrapInStun(STUN_ATTR_META_DTLS_IN_STUN_ACK,
+ AsAckAttribute({ComputeDtlsPacketHash(dtls_flight1)}))
+ .get());
+ // No unacked packet exists.
+ EXPECT_FALSE(server_.GetDataToPiggyback(STUN_BINDING_REQUEST).has_value());
+}
+
+TEST_F(DtlsStunPiggybackControllerTest, LimitAckSize) {
+ std::vector<uint8_t> dtls_flight5 = FakeDtlsPacket(0x5487);
+
+ server_.ReportDataPiggybacked(
+ WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, dtls_flight1).get(), nullptr);
+ EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST)->size(), 4u);
+ server_.ReportDataPiggybacked(
+ WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, dtls_flight2).get(), nullptr);
+ EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST)->size(), 8u);
+ server_.ReportDataPiggybacked(
+ WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, dtls_flight3).get(), nullptr);
+ EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST)->size(), 12u);
+ server_.ReportDataPiggybacked(
+ WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, dtls_flight4).get(), nullptr);
+ EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST)->size(), 16u);
+
+ // Limit size of ack so that it does not grow unbounded.
+ server_.ReportDataPiggybacked(
+ WrapInStun(STUN_ATTR_META_DTLS_IN_STUN, dtls_flight5).get(), nullptr);
+ EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST)->size(),
+ DtlsStunPiggybackController::kMaxAckSize);
+ EXPECT_EQ(server_.GetAckToPiggyback(STUN_BINDING_REQUEST),
+ AsAckAttribute({
+ ComputeDtlsPacketHash(dtls_flight2),
+ ComputeDtlsPacketHash(dtls_flight3),
+ ComputeDtlsPacketHash(dtls_flight4),
+ ComputeDtlsPacketHash(dtls_flight5),
+ }));
}
} // namespace webrtc
diff --git a/p2p/dtls/dtls_utils.cc b/p2p/dtls/dtls_utils.cc
index 1ab2a01..227a183 100644
--- a/p2p/dtls/dtls_utils.cc
+++ b/p2p/dtls/dtls_utils.cc
@@ -17,6 +17,7 @@
#include "api/array_view.h"
#include "rtc_base/byte_buffer.h"
#include "rtc_base/checks.h"
+#include "rtc_base/crc32.h"
namespace {
// https://datatracker.ietf.org/doc/html/rfc5246#appendix-A.1
@@ -133,6 +134,7 @@
}
RTC_DCHECK(handshake_buf.Length() == 0);
}
+
// Should have consumed everything.
if (record_buf.Length() != 0) {
return std::nullopt;
@@ -140,4 +142,8 @@
return acks;
}
+uint32_t ComputeDtlsPacketHash(rtc::ArrayView<const uint8_t> dtls_packet) {
+ return webrtc::ComputeCrc32(dtls_packet.data(), dtls_packet.size());
+}
+
} // namespace webrtc
diff --git a/p2p/dtls/dtls_utils.h b/p2p/dtls/dtls_utils.h
index d50c514..f23e578 100644
--- a/p2p/dtls/dtls_utils.h
+++ b/p2p/dtls/dtls_utils.h
@@ -30,6 +30,8 @@
std::optional<std::vector<uint16_t>> GetDtlsHandshakeAcks(
rtc::ArrayView<const uint8_t> dtls_packet);
+uint32_t ComputeDtlsPacketHash(rtc::ArrayView<const uint8_t> dtls_packet);
+
} // namespace webrtc
// Re-export symbols from the webrtc namespace for backwards compatibility.
diff --git a/pc/data_channel_integrationtest.cc b/pc/data_channel_integrationtest.cc
index c18c1a2..1668211 100644
--- a/pc/data_channel_integrationtest.cc
+++ b/pc/data_channel_integrationtest.cc
@@ -21,6 +21,7 @@
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
#include "api/data_channel_interface.h"
#include "api/dtls_transport_interface.h"
#include "api/jsep.h"
@@ -1626,6 +1627,15 @@
}
const char* CheckSupported() {
+ const bool callee_active = std::get<0>(GetParam());
+ const bool callee_has_dtls_in_stun = absl::StrContains(
+ std::get<2>(GetParam()), "WebRTC-IceHandshakeDtls/Enabled/");
+ const bool callee2_has_dtls_in_stun = absl::StrContains(
+ std::get<3>(GetParam()), "WebRTC-IceHandshakeDtls/Enabled/");
+ if (callee_active &&
+ (callee_has_dtls_in_stun || callee2_has_dtls_in_stun)) {
+ return "dtls-in-stun when callee(s) are dtls clients";
+ }
return nullptr;
}
};