Add support for caching more than 1 client hello packets.
This is "needed" for PQC (i.e. to avoid timeout/retransmit).
- move (inlined) code for PacketStash (from StunPiggyBackController) into class that is unit tested separately and also used for the cached
client hellos.
- Extend TestEventOrdering to cover dtls1.3 and PQC.
BUG=webrtc:404763475
Change-Id: I3a05f7685578d3e1de5bdd5e8992a0f60182b263
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/386901
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Sameer Vijaykar <samvi@google.com>
Cr-Commit-Position: refs/heads/main@{#44411}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index 63331cf..e03f295 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -287,6 +287,7 @@
"../api/crypto:options",
"../api/rtc_event_log",
"../api/task_queue:pending_task_safety_flag",
+ "../api/transport:ecn_marking",
"../api/transport:stun_types",
"../api/units:time_delta",
"../api/units:timestamp",
@@ -705,9 +706,11 @@
]
deps = [
"../api:array_view",
+ "../rtc_base:buffer",
"../rtc_base:byte_buffer",
"../rtc_base:checks",
"../rtc_base:crc32",
+ "//third_party/abseil-cpp/absl/container:flat_hash_set",
]
}
@@ -1284,6 +1287,7 @@
"../test:test_support",
"../test:wait_until",
"//third_party/abseil-cpp/absl/algorithm:container",
+ "//third_party/abseil-cpp/absl/container:flat_hash_set",
"//third_party/abseil-cpp/absl/functional:any_invocable",
"//third_party/abseil-cpp/absl/memory",
"//third_party/abseil-cpp/absl/strings",
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.cc b/p2p/dtls/dtls_stun_piggyback_controller.cc
index 92247c5..db7f484 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.cc
+++ b/p2p/dtls/dtls_stun_piggyback_controller.cc
@@ -12,7 +12,6 @@
#include <algorithm>
#include <cstdint>
-#include <memory>
#include <optional>
#include <utility>
#include <vector>
@@ -24,7 +23,6 @@
#include "api/sequence_checker.h"
#include "api/transport/stun.h"
#include "p2p/dtls/dtls_utils.h"
-#include "rtc_base/buffer.h"
#include "rtc_base/byte_buffer.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
@@ -50,7 +48,6 @@
// the last flight from the server.
// For DTLS 1.3 this is reversed since the handshake has one round trip less.
if ((is_dtls_client && !is_dtls13) || (!is_dtls_client && is_dtls13)) {
- pending_packet_pos_ = 0;
pending_packets_.clear();
}
@@ -72,19 +69,15 @@
// is made for 1-packet at a time. Use the writing_packets_ variable to keep
// track of a full batch. The writing_packets_ is reset in Flush.
if (!writing_packets_) {
- pending_packet_pos_ = 0;
pending_packets_.clear();
writing_packets_ = true;
}
- pending_packets_.push_back(std::make_pair(
- ComputeDtlsPacketHash(data),
- std::make_unique<webrtc::Buffer>(data.data(), data.size())));
+ pending_packets_.Add(data);
}
void DtlsStunPiggybackController::ClearCachedPacketForTesting() {
RTC_DCHECK_RUN_ON(&sequence_checker_);
- pending_packet_pos_ = 0;
pending_packets_.clear();
}
@@ -116,13 +109,13 @@
return std::nullopt;
}
- if (pending_packets_.size() == 0) {
+ if (pending_packets_.empty()) {
return std::nullopt;
}
- auto pos = pending_packet_pos_;
- pending_packet_pos_ = (pos + 1) % pending_packets_.size();
- return absl::string_view(*pending_packets_[pos].second.get());
+ const auto packet = pending_packets_.GetNext();
+ return absl::string_view(reinterpret_cast<const char*>(packet.data()),
+ packet.size());
}
std::optional<absl::string_view> DtlsStunPiggybackController::GetAckToPiggyback(
@@ -162,7 +155,6 @@
if (state_ == State::PENDING && data == nullptr && ack == nullptr) {
RTC_LOG(LS_INFO) << "DTLS-STUN piggybacking complete.";
state_ = State::COMPLETE;
- pending_packet_pos_ = 0;
pending_packets_.clear();
handshake_ack_writer_.Clear();
handshake_messages_received_.clear();
@@ -189,20 +181,7 @@
<< webrtc::StrJoin(acked_packets, ",");
// Remove all acked packets from pending_packets_.
- if (!acked_packets.empty()) {
- uint32_t before = pending_packets_.size();
- pending_packets_.erase(
- std::remove_if(pending_packets_.begin(), pending_packets_.end(),
- [&](const auto& val) {
- return acked_packets.contains(val.first);
- }),
- pending_packets_.end());
- uint32_t after = pending_packets_.size();
- uint32_t removed = before - after;
- if (pending_packet_pos_ >= removed) {
- pending_packet_pos_ -= removed;
- }
- }
+ pending_packets_.Prune(acked_packets);
}
}
@@ -213,7 +192,6 @@
if (data == nullptr && ack != nullptr && state_ == State::PENDING) {
RTC_LOG(LS_INFO) << "DTLS-STUN piggybacking complete.";
state_ = State::COMPLETE;
- pending_packet_pos_ = 0;
pending_packets_.clear();
handshake_ack_writer_.Clear();
handshake_messages_received_.clear();
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.h b/p2p/dtls/dtls_stun_piggyback_controller.h
index 71379d2..02aa71c 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.h
+++ b/p2p/dtls/dtls_stun_piggyback_controller.h
@@ -12,9 +12,7 @@
#define P2P_DTLS_DTLS_STUN_PIGGYBACK_CONTROLLER_H_
#include <cstdint>
-#include <memory>
#include <optional>
-#include <utility>
#include <vector>
#include "absl/functional/any_invocable.h"
@@ -22,7 +20,7 @@
#include "api/array_view.h"
#include "api/sequence_checker.h"
#include "api/transport/stun.h"
-#include "rtc_base/buffer.h"
+#include "p2p/dtls/dtls_utils.h"
#include "rtc_base/byte_buffer.h"
#include "rtc_base/system/no_unique_address.h"
#include "rtc_base/thread_annotations.h"
@@ -91,11 +89,8 @@
private:
State state_ RTC_GUARDED_BY(sequence_checker_) = State::TENTATIVE;
bool writing_packets_ RTC_GUARDED_BY(sequence_checker_) = false;
- uint32_t pending_packet_pos_ RTC_GUARDED_BY(sequence_checker_) = 0;
- std::vector<std::pair<uint32_t, std::unique_ptr<Buffer>>> pending_packets_
- RTC_GUARDED_BY(sequence_checker_);
- absl::AnyInvocable<void(webrtc::ArrayView<const uint8_t>)>
- dtls_data_callback_;
+ PacketStash pending_packets_ RTC_GUARDED_BY(sequence_checker_);
+ absl::AnyInvocable<void(ArrayView<const uint8_t>)> dtls_data_callback_;
absl::AnyInvocable<void()> disable_piggybacking_callback_;
std::vector<uint32_t> handshake_messages_received_
diff --git a/p2p/dtls/dtls_transport.cc b/p2p/dtls/dtls_transport.cc
index 38588df..f8d118a 100644
--- a/p2p/dtls/dtls_transport.cc
+++ b/p2p/dtls/dtls_transport.cc
@@ -28,6 +28,7 @@
#include "api/scoped_refptr.h"
#include "api/sequence_checker.h"
#include "api/task_queue/pending_task_safety_flag.h"
+#include "api/transport/ecn_marking.h"
#include "api/transport/stun.h"
#include "api/units/time_delta.h"
#include "api/units/timestamp.h"
@@ -39,11 +40,12 @@
#include "p2p/dtls/dtls_stun_piggyback_controller.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "p2p/dtls/dtls_utils.h"
+#include "rtc_base/async_packet_socket.h"
#include "rtc_base/buffer.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
-#include "rtc_base/network/ecn_marking.h"
#include "rtc_base/network/received_packet.h"
+#include "rtc_base/network/sent_packet.h"
#include "rtc_base/network_route.h"
#include "rtc_base/rtc_certificate.h"
#include "rtc_base/socket.h"
@@ -97,6 +99,8 @@
// This effectively disables the handshake timeout.
constexpr int kDisabledHandshakeTimeoutMs = 3600 * 1000 * 24;
+constexpr uint32_t kMaxCachedClientHello = 4;
+
static bool IsRtpPacket(ArrayView<const uint8_t> payload) {
const uint8_t* u = payload.data();
return (payload.size() >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
@@ -773,7 +777,8 @@
RTC_LOG(LS_INFO) << ToString()
<< ": Caching DTLS ClientHello packet until DTLS is "
"started.";
- cached_client_hello_.SetData(packet.payload());
+ cached_client_hello_.AddIfUnique(packet.payload());
+ cached_client_hello_.Prune(kMaxCachedClientHello);
// If we haven't started setting up DTLS yet (because we don't have a
// remote fingerprint/role), we can use the client hello as a clue that
// the peer has chosen the client role, and proceed with the handshake.
@@ -946,19 +951,24 @@
set_dtls_state(webrtc::DtlsTransportState::kConnecting);
// Now that the handshake has started, we can process a cached ClientHello
// (if one exists).
- if (cached_client_hello_.size()) {
+ if (!cached_client_hello_.empty()) {
if (*dtls_role_ == webrtc::SSL_SERVER) {
- RTC_LOG(LS_INFO) << ToString()
- << ": Handling cached DTLS ClientHello packet.";
- if (!HandleDtlsPacket(cached_client_hello_)) {
- RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet.";
+ int size = cached_client_hello_.size();
+ RTC_LOG(LS_INFO) << ToString() << ": Handling #" << size
+ << " cached DTLS ClientHello packet(s).";
+ for (int i = 0; i < size; i++) {
+ if (!HandleDtlsPacket(cached_client_hello_.GetNext())) {
+ RTC_LOG(LS_ERROR)
+ << ToString() << ": Failed to handle DTLS packet.";
+ break;
+ }
}
} else {
RTC_LOG(LS_WARNING) << ToString()
<< ": Discarding cached DTLS ClientHello packet "
"because we don't have the server role.";
}
- cached_client_hello_.Clear();
+ cached_client_hello_.clear();
}
}
}
diff --git a/p2p/dtls/dtls_transport.h b/p2p/dtls/dtls_transport.h
index 70a4a7f..e39508a 100644
--- a/p2p/dtls/dtls_transport.h
+++ b/p2p/dtls/dtls_transport.h
@@ -32,6 +32,7 @@
#include "p2p/base/packet_transport_internal.h"
#include "p2p/dtls/dtls_stun_piggyback_controller.h"
#include "p2p/dtls/dtls_transport_internal.h"
+#include "p2p/dtls/dtls_utils.h"
#include "rtc_base/async_packet_socket.h"
#include "rtc_base/buffer.h"
#include "rtc_base/buffer_queue.h"
@@ -289,7 +290,7 @@
// Cached DTLS ClientHello packet that was received before we started the
// DTLS handshake. This could happen if the hello was received before the
// ice transport became writable, or before a remote fingerprint was received.
- Buffer cached_client_hello_;
+ PacketStash cached_client_hello_;
bool receiving_ = false;
bool writable_ = false;
diff --git a/p2p/dtls/dtls_transport_unittest.cc b/p2p/dtls/dtls_transport_unittest.cc
index 0ef049c..a32c477 100644
--- a/p2p/dtls/dtls_transport_unittest.cc
+++ b/p2p/dtls/dtls_transport_unittest.cc
@@ -108,6 +108,7 @@
void SetupMaxProtocolVersion(SSLProtocolVersion version) {
ssl_max_version_ = version;
}
+ void SetPqc(bool value) { pqc_ = value; }
void set_async_delay(int async_delay_ms) { async_delay_ms_ = async_delay_ms; }
// Set up fake ICE transport and real DTLS transport under test.
@@ -117,6 +118,10 @@
dtls_transport_ = nullptr;
fake_ice_transport_ = nullptr;
+ if (field_trials_string.empty() && pqc_) {
+ field_trials_string = "WebRTC-EnableDtlsPqc/Enabled/";
+ }
+
fake_ice_transport_.reset(new FakeIceTransport(
absl::StrCat("fake-", name_), 0,
/* network_thread= */ nullptr, field_trials_string));
@@ -390,6 +395,7 @@
SentPacketInfo sent_packet_;
absl::AnyInvocable<void()> writable_func_;
int async_delay_ms_ = 100;
+ bool pqc_ = false;
};
// Base class for DtlsTransportInternalImplTest and DtlsEventOrderingTest, which
@@ -404,10 +410,16 @@
start_time_ns_ = fake_clock_.TimeNanos();
}
+ void SetPqc(bool value) {
+ client1_.SetPqc(value);
+ client2_.SetPqc(value);
+ }
+
void SetMaxProtocolVersions(SSLProtocolVersion c1, SSLProtocolVersion c2) {
client1_.SetupMaxProtocolVersion(c1);
client2_.SetupMaxProtocolVersion(c2);
}
+
// If not called, DtlsTransportInternalImpl will be used in SRTP bypass mode.
void PrepareDtls(KeyType key_type) {
client1_.CreateCertificate(key_type);
@@ -555,6 +567,7 @@
DtlsTestClient client1_;
DtlsTestClient client2_;
bool use_dtls_;
+ bool pqc_ = false;
uint64_t start_time_ns_;
SSLProtocolVersion ssl_expected_version_;
};
@@ -1362,13 +1375,25 @@
class DtlsEventOrderingTest
: public DtlsTransportInternalImplTestBase,
public ::testing::TestWithParam<
- ::testing::tuple<std::vector<DtlsTransportInternalImplEvent>, bool>> {
+ ::testing::tuple<std::vector<DtlsTransportInternalImplEvent>,
+ bool /* valid_fingerprint */,
+ SSLProtocolVersion,
+ bool /* pqc */>> {
protected:
// If `valid_fingerprint` is false, the caller will receive a fingerprint
// that doesn't match the callee's certificate, so the handshake should fail.
void TestEventOrdering(
const std::vector<DtlsTransportInternalImplEvent>& events,
bool valid_fingerprint) {
+ bool pqc = ::testing::get<3>(GetParam());
+ if (pqc && ::testing::get<2>(GetParam()) != SSL_PROTOCOL_DTLS_13) {
+ GTEST_SKIP() << "PQC requires DTLS1.3";
+ }
+
+ SetPqc(::testing::get<3>(GetParam()));
+ SetMaxProtocolVersions(::testing::get<2>(GetParam()),
+ ::testing::get<2>(GetParam()));
+
// Pre-setup: Set local certificate on both caller and callee, and
// remote fingerprint on callee, but neither is writable and the caller
// doesn't have the callee's fingerprint.
@@ -1406,7 +1431,7 @@
EXPECT_TRUE(WaitUntil(
[&] { return client2_.fake_ice_transport()->writable(); }));
EXPECT_TRUE(WaitUntil(
- [&] { return client1_.received_dtls_client_hellos() == 1; }));
+ [&] { return client1_.received_dtls_client_hellos() >= 1; }));
break;
case HANDSHAKE_FINISHES:
// Sanity check that the handshake hasn't already finished.
@@ -1442,8 +1467,9 @@
EXPECT_EQ(valid_fingerprint, client1_.dtls_transport()->writable());
EXPECT_EQ(valid_fingerprint, client2_.dtls_transport()->writable());
+ int count = pqc ? 2 : 1;
// Check that no hello needed to be retransmitted.
- EXPECT_EQ(1, client1_.received_dtls_client_hellos());
+ EXPECT_EQ(count, client1_.received_dtls_client_hellos());
EXPECT_EQ(1, client2_.received_dtls_server_hellos());
if (valid_fingerprint) {
@@ -1486,6 +1512,8 @@
std::vector<DtlsTransportInternalImplEvent>{
CALLER_RECEIVES_CLIENTHELLO, CALLER_WRITABLE,
HANDSHAKE_FINISHES, CALLER_RECEIVES_FINGERPRINT}),
+ ::testing::Bool(),
+ ::testing::Values(SSL_PROTOCOL_DTLS_12, SSL_PROTOCOL_DTLS_13),
::testing::Bool()));
class DtlsTransportInternalImplDtlsInStunTest
diff --git a/p2p/dtls/dtls_utils.cc b/p2p/dtls/dtls_utils.cc
index 69353c2..6a32d4c 100644
--- a/p2p/dtls/dtls_utils.cc
+++ b/p2p/dtls/dtls_utils.cc
@@ -10,11 +10,15 @@
#include "p2p/dtls/dtls_utils.h"
+#include <algorithm>
#include <cstdint>
+#include <memory>
#include <optional>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "api/array_view.h"
+#include "rtc_base/buffer.h"
#include "rtc_base/byte_buffer.h"
#include "rtc_base/checks.h"
#include "rtc_base/crc32.h"
@@ -146,4 +150,62 @@
return webrtc::ComputeCrc32(dtls_packet.data(), dtls_packet.size());
}
+bool PacketStash::AddIfUnique(rtc::ArrayView<const uint8_t> packet) {
+ uint32_t h = ComputeDtlsPacketHash(packet);
+ for (const auto& [hash, p] : packets_) {
+ if (h == hash) {
+ return false;
+ }
+ }
+ packets_.push_back({.hash = h,
+ .buffer = std::make_unique<webrtc::Buffer>(
+ packet.data(), packet.size())});
+ return true;
+}
+
+void PacketStash::Add(rtc::ArrayView<const uint8_t> packet) {
+ packets_.push_back({.hash = ComputeDtlsPacketHash(packet),
+ .buffer = std::make_unique<webrtc::Buffer>(
+ packet.data(), packet.size())});
+}
+
+void PacketStash::Prune(const absl::flat_hash_set<uint32_t>& hashes) {
+ if (hashes.empty()) {
+ return;
+ }
+ uint32_t before = packets_.size();
+ packets_.erase(std::remove_if(packets_.begin(), packets_.end(),
+ [&](const auto& val) {
+ return hashes.contains(val.hash);
+ }),
+ packets_.end());
+ uint32_t after = packets_.size();
+ uint32_t removed = before - after;
+ if (pos_ >= removed) {
+ pos_ -= removed;
+ }
+}
+
+void PacketStash::Prune(uint32_t max_size) {
+ auto size = packets_.size();
+ if (size <= max_size) {
+ return;
+ }
+ auto removed = size - max_size;
+ packets_.erase(packets_.begin(), packets_.begin() + removed);
+ if (pos_ <= removed) {
+ pos_ = 0;
+ } else {
+ pos_ -= removed;
+ }
+}
+
+rtc::ArrayView<const uint8_t> PacketStash::GetNext() {
+ RTC_DCHECK(!packets_.empty());
+ auto pos = pos_;
+ pos_ = (pos + 1) % packets_.size();
+ const auto& buffer = packets_[pos].buffer;
+ return rtc::ArrayView<const uint8_t>(buffer->data(), buffer->size());
+}
+
} // namespace webrtc
diff --git a/p2p/dtls/dtls_utils.h b/p2p/dtls/dtls_utils.h
index e9af4b9..773df9b 100644
--- a/p2p/dtls/dtls_utils.h
+++ b/p2p/dtls/dtls_utils.h
@@ -13,10 +13,13 @@
#include <cstddef>
#include <cstdint>
+#include <memory>
#include <optional>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "api/array_view.h"
+#include "rtc_base/buffer.h"
namespace webrtc {
@@ -32,6 +35,40 @@
uint32_t ComputeDtlsPacketHash(ArrayView<const uint8_t> dtls_packet);
+class PacketStash {
+ public:
+ PacketStash() {}
+
+ void Add(rtc::ArrayView<const uint8_t> packet);
+ bool AddIfUnique(rtc::ArrayView<const uint8_t> packet);
+ void Prune(const absl::flat_hash_set<uint32_t>& packet_hashes);
+ void Prune(uint32_t max_size);
+ rtc::ArrayView<const uint8_t> GetNext();
+
+ void clear() {
+ packets_.clear();
+ pos_ = 0;
+ }
+ bool empty() const { return packets_.empty(); }
+ int size() const { return packets_.size(); }
+
+ static uint32_t Hash(rtc::ArrayView<const uint8_t> packet) {
+ return ComputeDtlsPacketHash(packet);
+ }
+
+ private:
+ struct StashedPacket {
+ uint32_t hash;
+ std::unique_ptr<rtc::Buffer> buffer;
+ };
+
+ // This vector will only contain very few items,
+ // so it is appropriate to use a vector rather than
+ // e.g. a hash map.
+ uint32_t pos_ = 0;
+ std::vector<StashedPacket> packets_;
+};
+
} // namespace webrtc
// Re-export symbols from the webrtc namespace for backwards compatibility.
diff --git a/p2p/dtls/dtls_utils_unittest.cc b/p2p/dtls/dtls_utils_unittest.cc
index 5e7265f..2991f0a 100644
--- a/p2p/dtls/dtls_utils_unittest.cc
+++ b/p2p/dtls/dtls_utils_unittest.cc
@@ -14,6 +14,8 @@
#include <optional>
#include <vector>
+#include "absl/container/flat_hash_set.h"
+#include "api/array_view.h"
#include "test/gmock.h"
#include "test/gtest.h"
@@ -198,4 +200,141 @@
EXPECT_EQ(acks->size(), 0u);
}
+std::vector<uint8_t> ToVector(rtc::ArrayView<const uint8_t> array) {
+ return std::vector<uint8_t>(array.begin(), array.end());
+}
+
+TEST(PacketStash, Add) {
+ PacketStash stash;
+ std::vector<uint8_t> packet = {
+ 0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+ 0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+ 0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+ 0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+ };
+
+ stash.Add(packet);
+ EXPECT_EQ(stash.size(), 1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet);
+
+ stash.Add(packet);
+ EXPECT_EQ(stash.size(), 2);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet);
+}
+
+TEST(PacketStash, AddIfUnique) {
+ PacketStash stash;
+ std::vector<uint8_t> packet1 = {
+ 0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+ 0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+ 0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+ 0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+ };
+
+ std::vector<uint8_t> packet2 = {
+ 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
+ 0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+
+ stash.AddIfUnique(packet1);
+ EXPECT_EQ(stash.size(), 1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+
+ stash.AddIfUnique(packet1);
+ EXPECT_EQ(stash.size(), 1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+
+ stash.AddIfUnique(packet2);
+ EXPECT_EQ(stash.size(), 2);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+
+ stash.AddIfUnique(packet2);
+ EXPECT_EQ(stash.size(), 2);
+}
+
+TEST(PacketStash, Prune) {
+ PacketStash stash;
+ std::vector<uint8_t> packet1 = {
+ 0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+ 0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+ 0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+ 0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+ };
+
+ std::vector<uint8_t> packet2 = {
+ 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
+ 0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+
+ stash.AddIfUnique(packet1);
+ stash.AddIfUnique(packet2);
+ EXPECT_EQ(stash.size(), 2);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+
+ absl::flat_hash_set<uint32_t> remove;
+ remove.insert(PacketStash::Hash(packet1));
+ stash.Prune(remove);
+
+ EXPECT_EQ(stash.size(), 1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+}
+
+TEST(PacketStash, PruneSize) {
+ PacketStash stash;
+ std::vector<uint8_t> packet1 = {
+ 0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+ 0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+ 0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+ 0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+ };
+
+ std::vector<uint8_t> packet2 = {
+ 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
+ 0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+
+ std::vector<uint8_t> packet3 = {0x3};
+ std::vector<uint8_t> packet4 = {0x4};
+ std::vector<uint8_t> packet5 = {0x5};
+ std::vector<uint8_t> packet6 = {0x6};
+
+ stash.AddIfUnique(packet1);
+ stash.AddIfUnique(packet2);
+ stash.AddIfUnique(packet3);
+ stash.AddIfUnique(packet4);
+ stash.AddIfUnique(packet5);
+ stash.AddIfUnique(packet6);
+ EXPECT_EQ(stash.size(), 6);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet3);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet4);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet5);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet6);
+
+ // Should be NOP.
+ stash.Prune(/* max_size= */ 6);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet3);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet4);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet5);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet6);
+
+ // Move "cursor" forward.
+ EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+ stash.Prune(/* max_size= */ 4);
+ EXPECT_EQ(stash.size(), 4);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet3);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet4);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet5);
+ EXPECT_EQ(ToVector(stash.GetNext()), packet6);
+}
+
} // namespace webrtc