Add support for caching more than 1 client hello packets.

This is "needed" for PQC (i.e. to avoid timeout/retransmit).

- move (inlined) code for PacketStash (from StunPiggyBackController) into class that is unit tested separately and also used for the cached
client hellos.
- Extend TestEventOrdering to cover dtls1.3 and PQC.

BUG=webrtc:404763475

Change-Id: I3a05f7685578d3e1de5bdd5e8992a0f60182b263
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/386901
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Sameer Vijaykar <samvi@google.com>
Cr-Commit-Position: refs/heads/main@{#44411}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index 63331cf..e03f295 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -287,6 +287,7 @@
     "../api/crypto:options",
     "../api/rtc_event_log",
     "../api/task_queue:pending_task_safety_flag",
+    "../api/transport:ecn_marking",
     "../api/transport:stun_types",
     "../api/units:time_delta",
     "../api/units:timestamp",
@@ -705,9 +706,11 @@
   ]
   deps = [
     "../api:array_view",
+    "../rtc_base:buffer",
     "../rtc_base:byte_buffer",
     "../rtc_base:checks",
     "../rtc_base:crc32",
+    "//third_party/abseil-cpp/absl/container:flat_hash_set",
   ]
 }
 
@@ -1284,6 +1287,7 @@
       "../test:test_support",
       "../test:wait_until",
       "//third_party/abseil-cpp/absl/algorithm:container",
+      "//third_party/abseil-cpp/absl/container:flat_hash_set",
       "//third_party/abseil-cpp/absl/functional:any_invocable",
       "//third_party/abseil-cpp/absl/memory",
       "//third_party/abseil-cpp/absl/strings",
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.cc b/p2p/dtls/dtls_stun_piggyback_controller.cc
index 92247c5..db7f484 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.cc
+++ b/p2p/dtls/dtls_stun_piggyback_controller.cc
@@ -12,7 +12,6 @@
 
 #include <algorithm>
 #include <cstdint>
-#include <memory>
 #include <optional>
 #include <utility>
 #include <vector>
@@ -24,7 +23,6 @@
 #include "api/sequence_checker.h"
 #include "api/transport/stun.h"
 #include "p2p/dtls/dtls_utils.h"
-#include "rtc_base/buffer.h"
 #include "rtc_base/byte_buffer.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
@@ -50,7 +48,6 @@
   // the last flight from the server.
   // For DTLS 1.3 this is reversed since the handshake has one round trip less.
   if ((is_dtls_client && !is_dtls13) || (!is_dtls_client && is_dtls13)) {
-    pending_packet_pos_ = 0;
     pending_packets_.clear();
   }
 
@@ -72,19 +69,15 @@
   // is made for 1-packet at a time. Use the writing_packets_ variable to keep
   // track of a full batch. The writing_packets_ is reset in Flush.
   if (!writing_packets_) {
-    pending_packet_pos_ = 0;
     pending_packets_.clear();
     writing_packets_ = true;
   }
 
-  pending_packets_.push_back(std::make_pair(
-      ComputeDtlsPacketHash(data),
-      std::make_unique<webrtc::Buffer>(data.data(), data.size())));
+  pending_packets_.Add(data);
 }
 
 void DtlsStunPiggybackController::ClearCachedPacketForTesting() {
   RTC_DCHECK_RUN_ON(&sequence_checker_);
-  pending_packet_pos_ = 0;
   pending_packets_.clear();
 }
 
@@ -116,13 +109,13 @@
     return std::nullopt;
   }
 
-  if (pending_packets_.size() == 0) {
+  if (pending_packets_.empty()) {
     return std::nullopt;
   }
 
-  auto pos = pending_packet_pos_;
-  pending_packet_pos_ = (pos + 1) % pending_packets_.size();
-  return absl::string_view(*pending_packets_[pos].second.get());
+  const auto packet = pending_packets_.GetNext();
+  return absl::string_view(reinterpret_cast<const char*>(packet.data()),
+                           packet.size());
 }
 
 std::optional<absl::string_view> DtlsStunPiggybackController::GetAckToPiggyback(
@@ -162,7 +155,6 @@
   if (state_ == State::PENDING && data == nullptr && ack == nullptr) {
     RTC_LOG(LS_INFO) << "DTLS-STUN piggybacking complete.";
     state_ = State::COMPLETE;
-    pending_packet_pos_ = 0;
     pending_packets_.clear();
     handshake_ack_writer_.Clear();
     handshake_messages_received_.clear();
@@ -189,20 +181,7 @@
                           << webrtc::StrJoin(acked_packets, ",");
 
       // Remove all acked packets from pending_packets_.
-      if (!acked_packets.empty()) {
-        uint32_t before = pending_packets_.size();
-        pending_packets_.erase(
-            std::remove_if(pending_packets_.begin(), pending_packets_.end(),
-                           [&](const auto& val) {
-                             return acked_packets.contains(val.first);
-                           }),
-            pending_packets_.end());
-        uint32_t after = pending_packets_.size();
-        uint32_t removed = before - after;
-        if (pending_packet_pos_ >= removed) {
-          pending_packet_pos_ -= removed;
-        }
-      }
+      pending_packets_.Prune(acked_packets);
     }
   }
 
@@ -213,7 +192,6 @@
   if (data == nullptr && ack != nullptr && state_ == State::PENDING) {
     RTC_LOG(LS_INFO) << "DTLS-STUN piggybacking complete.";
     state_ = State::COMPLETE;
-    pending_packet_pos_ = 0;
     pending_packets_.clear();
     handshake_ack_writer_.Clear();
     handshake_messages_received_.clear();
diff --git a/p2p/dtls/dtls_stun_piggyback_controller.h b/p2p/dtls/dtls_stun_piggyback_controller.h
index 71379d2..02aa71c 100644
--- a/p2p/dtls/dtls_stun_piggyback_controller.h
+++ b/p2p/dtls/dtls_stun_piggyback_controller.h
@@ -12,9 +12,7 @@
 #define P2P_DTLS_DTLS_STUN_PIGGYBACK_CONTROLLER_H_
 
 #include <cstdint>
-#include <memory>
 #include <optional>
-#include <utility>
 #include <vector>
 
 #include "absl/functional/any_invocable.h"
@@ -22,7 +20,7 @@
 #include "api/array_view.h"
 #include "api/sequence_checker.h"
 #include "api/transport/stun.h"
-#include "rtc_base/buffer.h"
+#include "p2p/dtls/dtls_utils.h"
 #include "rtc_base/byte_buffer.h"
 #include "rtc_base/system/no_unique_address.h"
 #include "rtc_base/thread_annotations.h"
@@ -91,11 +89,8 @@
  private:
   State state_ RTC_GUARDED_BY(sequence_checker_) = State::TENTATIVE;
   bool writing_packets_ RTC_GUARDED_BY(sequence_checker_) = false;
-  uint32_t pending_packet_pos_ RTC_GUARDED_BY(sequence_checker_) = 0;
-  std::vector<std::pair<uint32_t, std::unique_ptr<Buffer>>> pending_packets_
-      RTC_GUARDED_BY(sequence_checker_);
-  absl::AnyInvocable<void(webrtc::ArrayView<const uint8_t>)>
-      dtls_data_callback_;
+  PacketStash pending_packets_ RTC_GUARDED_BY(sequence_checker_);
+  absl::AnyInvocable<void(ArrayView<const uint8_t>)> dtls_data_callback_;
   absl::AnyInvocable<void()> disable_piggybacking_callback_;
 
   std::vector<uint32_t> handshake_messages_received_
diff --git a/p2p/dtls/dtls_transport.cc b/p2p/dtls/dtls_transport.cc
index 38588df..f8d118a 100644
--- a/p2p/dtls/dtls_transport.cc
+++ b/p2p/dtls/dtls_transport.cc
@@ -28,6 +28,7 @@
 #include "api/scoped_refptr.h"
 #include "api/sequence_checker.h"
 #include "api/task_queue/pending_task_safety_flag.h"
+#include "api/transport/ecn_marking.h"
 #include "api/transport/stun.h"
 #include "api/units/time_delta.h"
 #include "api/units/timestamp.h"
@@ -39,11 +40,12 @@
 #include "p2p/dtls/dtls_stun_piggyback_controller.h"
 #include "p2p/dtls/dtls_transport_internal.h"
 #include "p2p/dtls/dtls_utils.h"
+#include "rtc_base/async_packet_socket.h"
 #include "rtc_base/buffer.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
-#include "rtc_base/network/ecn_marking.h"
 #include "rtc_base/network/received_packet.h"
+#include "rtc_base/network/sent_packet.h"
 #include "rtc_base/network_route.h"
 #include "rtc_base/rtc_certificate.h"
 #include "rtc_base/socket.h"
@@ -97,6 +99,8 @@
 // This effectively disables the handshake timeout.
 constexpr int kDisabledHandshakeTimeoutMs = 3600 * 1000 * 24;
 
+constexpr uint32_t kMaxCachedClientHello = 4;
+
 static bool IsRtpPacket(ArrayView<const uint8_t> payload) {
   const uint8_t* u = payload.data();
   return (payload.size() >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
@@ -773,7 +777,8 @@
         RTC_LOG(LS_INFO) << ToString()
                          << ": Caching DTLS ClientHello packet until DTLS is "
                             "started.";
-        cached_client_hello_.SetData(packet.payload());
+        cached_client_hello_.AddIfUnique(packet.payload());
+        cached_client_hello_.Prune(kMaxCachedClientHello);
         // If we haven't started setting up DTLS yet (because we don't have a
         // remote fingerprint/role), we can use the client hello as a clue that
         // the peer has chosen the client role, and proceed with the handshake.
@@ -946,19 +951,24 @@
     set_dtls_state(webrtc::DtlsTransportState::kConnecting);
     // Now that the handshake has started, we can process a cached ClientHello
     // (if one exists).
-    if (cached_client_hello_.size()) {
+    if (!cached_client_hello_.empty()) {
       if (*dtls_role_ == webrtc::SSL_SERVER) {
-        RTC_LOG(LS_INFO) << ToString()
-                         << ": Handling cached DTLS ClientHello packet.";
-        if (!HandleDtlsPacket(cached_client_hello_)) {
-          RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet.";
+        int size = cached_client_hello_.size();
+        RTC_LOG(LS_INFO) << ToString() << ": Handling #" << size
+                         << " cached DTLS ClientHello packet(s).";
+        for (int i = 0; i < size; i++) {
+          if (!HandleDtlsPacket(cached_client_hello_.GetNext())) {
+            RTC_LOG(LS_ERROR)
+                << ToString() << ": Failed to handle DTLS packet.";
+            break;
+          }
         }
       } else {
         RTC_LOG(LS_WARNING) << ToString()
                             << ": Discarding cached DTLS ClientHello packet "
                                "because we don't have the server role.";
       }
-      cached_client_hello_.Clear();
+      cached_client_hello_.clear();
     }
   }
 }
diff --git a/p2p/dtls/dtls_transport.h b/p2p/dtls/dtls_transport.h
index 70a4a7f..e39508a 100644
--- a/p2p/dtls/dtls_transport.h
+++ b/p2p/dtls/dtls_transport.h
@@ -32,6 +32,7 @@
 #include "p2p/base/packet_transport_internal.h"
 #include "p2p/dtls/dtls_stun_piggyback_controller.h"
 #include "p2p/dtls/dtls_transport_internal.h"
+#include "p2p/dtls/dtls_utils.h"
 #include "rtc_base/async_packet_socket.h"
 #include "rtc_base/buffer.h"
 #include "rtc_base/buffer_queue.h"
@@ -289,7 +290,7 @@
   // Cached DTLS ClientHello packet that was received before we started the
   // DTLS handshake. This could happen if the hello was received before the
   // ice transport became writable, or before a remote fingerprint was received.
-  Buffer cached_client_hello_;
+  PacketStash cached_client_hello_;
 
   bool receiving_ = false;
   bool writable_ = false;
diff --git a/p2p/dtls/dtls_transport_unittest.cc b/p2p/dtls/dtls_transport_unittest.cc
index 0ef049c..a32c477 100644
--- a/p2p/dtls/dtls_transport_unittest.cc
+++ b/p2p/dtls/dtls_transport_unittest.cc
@@ -108,6 +108,7 @@
   void SetupMaxProtocolVersion(SSLProtocolVersion version) {
     ssl_max_version_ = version;
   }
+  void SetPqc(bool value) { pqc_ = value; }
   void set_async_delay(int async_delay_ms) { async_delay_ms_ = async_delay_ms; }
 
   // Set up fake ICE transport and real DTLS transport under test.
@@ -117,6 +118,10 @@
     dtls_transport_ = nullptr;
     fake_ice_transport_ = nullptr;
 
+    if (field_trials_string.empty() && pqc_) {
+      field_trials_string = "WebRTC-EnableDtlsPqc/Enabled/";
+    }
+
     fake_ice_transport_.reset(new FakeIceTransport(
         absl::StrCat("fake-", name_), 0,
         /* network_thread= */ nullptr, field_trials_string));
@@ -390,6 +395,7 @@
   SentPacketInfo sent_packet_;
   absl::AnyInvocable<void()> writable_func_;
   int async_delay_ms_ = 100;
+  bool pqc_ = false;
 };
 
 // Base class for DtlsTransportInternalImplTest and DtlsEventOrderingTest, which
@@ -404,10 +410,16 @@
     start_time_ns_ = fake_clock_.TimeNanos();
   }
 
+  void SetPqc(bool value) {
+    client1_.SetPqc(value);
+    client2_.SetPqc(value);
+  }
+
   void SetMaxProtocolVersions(SSLProtocolVersion c1, SSLProtocolVersion c2) {
     client1_.SetupMaxProtocolVersion(c1);
     client2_.SetupMaxProtocolVersion(c2);
   }
+
   // If not called, DtlsTransportInternalImpl will be used in SRTP bypass mode.
   void PrepareDtls(KeyType key_type) {
     client1_.CreateCertificate(key_type);
@@ -555,6 +567,7 @@
   DtlsTestClient client1_;
   DtlsTestClient client2_;
   bool use_dtls_;
+  bool pqc_ = false;
   uint64_t start_time_ns_;
   SSLProtocolVersion ssl_expected_version_;
 };
@@ -1362,13 +1375,25 @@
 class DtlsEventOrderingTest
     : public DtlsTransportInternalImplTestBase,
       public ::testing::TestWithParam<
-          ::testing::tuple<std::vector<DtlsTransportInternalImplEvent>, bool>> {
+          ::testing::tuple<std::vector<DtlsTransportInternalImplEvent>,
+                           bool /* valid_fingerprint */,
+                           SSLProtocolVersion,
+                           bool /* pqc */>> {
  protected:
   // If `valid_fingerprint` is false, the caller will receive a fingerprint
   // that doesn't match the callee's certificate, so the handshake should fail.
   void TestEventOrdering(
       const std::vector<DtlsTransportInternalImplEvent>& events,
       bool valid_fingerprint) {
+    bool pqc = ::testing::get<3>(GetParam());
+    if (pqc && ::testing::get<2>(GetParam()) != SSL_PROTOCOL_DTLS_13) {
+      GTEST_SKIP() << "PQC requires DTLS1.3";
+    }
+
+    SetPqc(::testing::get<3>(GetParam()));
+    SetMaxProtocolVersions(::testing::get<2>(GetParam()),
+                           ::testing::get<2>(GetParam()));
+
     // Pre-setup: Set local certificate on both caller and callee, and
     // remote fingerprint on callee, but neither is writable and the caller
     // doesn't have the callee's fingerprint.
@@ -1406,7 +1431,7 @@
           EXPECT_TRUE(WaitUntil(
               [&] { return client2_.fake_ice_transport()->writable(); }));
           EXPECT_TRUE(WaitUntil(
-              [&] { return client1_.received_dtls_client_hellos() == 1; }));
+              [&] { return client1_.received_dtls_client_hellos() >= 1; }));
           break;
         case HANDSHAKE_FINISHES:
           // Sanity check that the handshake hasn't already finished.
@@ -1442,8 +1467,9 @@
     EXPECT_EQ(valid_fingerprint, client1_.dtls_transport()->writable());
     EXPECT_EQ(valid_fingerprint, client2_.dtls_transport()->writable());
 
+    int count = pqc ? 2 : 1;
     // Check that no hello needed to be retransmitted.
-    EXPECT_EQ(1, client1_.received_dtls_client_hellos());
+    EXPECT_EQ(count, client1_.received_dtls_client_hellos());
     EXPECT_EQ(1, client2_.received_dtls_server_hellos());
 
     if (valid_fingerprint) {
@@ -1486,6 +1512,8 @@
             std::vector<DtlsTransportInternalImplEvent>{
                 CALLER_RECEIVES_CLIENTHELLO, CALLER_WRITABLE,
                 HANDSHAKE_FINISHES, CALLER_RECEIVES_FINGERPRINT}),
+        ::testing::Bool(),
+        ::testing::Values(SSL_PROTOCOL_DTLS_12, SSL_PROTOCOL_DTLS_13),
         ::testing::Bool()));
 
 class DtlsTransportInternalImplDtlsInStunTest
diff --git a/p2p/dtls/dtls_utils.cc b/p2p/dtls/dtls_utils.cc
index 69353c2..6a32d4c 100644
--- a/p2p/dtls/dtls_utils.cc
+++ b/p2p/dtls/dtls_utils.cc
@@ -10,11 +10,15 @@
 
 #include "p2p/dtls/dtls_utils.h"
 
+#include <algorithm>
 #include <cstdint>
+#include <memory>
 #include <optional>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "api/array_view.h"
+#include "rtc_base/buffer.h"
 #include "rtc_base/byte_buffer.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/crc32.h"
@@ -146,4 +150,62 @@
   return webrtc::ComputeCrc32(dtls_packet.data(), dtls_packet.size());
 }
 
+bool PacketStash::AddIfUnique(rtc::ArrayView<const uint8_t> packet) {
+  uint32_t h = ComputeDtlsPacketHash(packet);
+  for (const auto& [hash, p] : packets_) {
+    if (h == hash) {
+      return false;
+    }
+  }
+  packets_.push_back({.hash = h,
+                      .buffer = std::make_unique<webrtc::Buffer>(
+                          packet.data(), packet.size())});
+  return true;
+}
+
+void PacketStash::Add(rtc::ArrayView<const uint8_t> packet) {
+  packets_.push_back({.hash = ComputeDtlsPacketHash(packet),
+                      .buffer = std::make_unique<webrtc::Buffer>(
+                          packet.data(), packet.size())});
+}
+
+void PacketStash::Prune(const absl::flat_hash_set<uint32_t>& hashes) {
+  if (hashes.empty()) {
+    return;
+  }
+  uint32_t before = packets_.size();
+  packets_.erase(std::remove_if(packets_.begin(), packets_.end(),
+                                [&](const auto& val) {
+                                  return hashes.contains(val.hash);
+                                }),
+                 packets_.end());
+  uint32_t after = packets_.size();
+  uint32_t removed = before - after;
+  if (pos_ >= removed) {
+    pos_ -= removed;
+  }
+}
+
+void PacketStash::Prune(uint32_t max_size) {
+  auto size = packets_.size();
+  if (size <= max_size) {
+    return;
+  }
+  auto removed = size - max_size;
+  packets_.erase(packets_.begin(), packets_.begin() + removed);
+  if (pos_ <= removed) {
+    pos_ = 0;
+  } else {
+    pos_ -= removed;
+  }
+}
+
+rtc::ArrayView<const uint8_t> PacketStash::GetNext() {
+  RTC_DCHECK(!packets_.empty());
+  auto pos = pos_;
+  pos_ = (pos + 1) % packets_.size();
+  const auto& buffer = packets_[pos].buffer;
+  return rtc::ArrayView<const uint8_t>(buffer->data(), buffer->size());
+}
+
 }  // namespace webrtc
diff --git a/p2p/dtls/dtls_utils.h b/p2p/dtls/dtls_utils.h
index e9af4b9..773df9b 100644
--- a/p2p/dtls/dtls_utils.h
+++ b/p2p/dtls/dtls_utils.h
@@ -13,10 +13,13 @@
 
 #include <cstddef>
 #include <cstdint>
+#include <memory>
 #include <optional>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "api/array_view.h"
+#include "rtc_base/buffer.h"
 
 namespace webrtc {
 
@@ -32,6 +35,40 @@
 
 uint32_t ComputeDtlsPacketHash(ArrayView<const uint8_t> dtls_packet);
 
+class PacketStash {
+ public:
+  PacketStash() {}
+
+  void Add(rtc::ArrayView<const uint8_t> packet);
+  bool AddIfUnique(rtc::ArrayView<const uint8_t> packet);
+  void Prune(const absl::flat_hash_set<uint32_t>& packet_hashes);
+  void Prune(uint32_t max_size);
+  rtc::ArrayView<const uint8_t> GetNext();
+
+  void clear() {
+    packets_.clear();
+    pos_ = 0;
+  }
+  bool empty() const { return packets_.empty(); }
+  int size() const { return packets_.size(); }
+
+  static uint32_t Hash(rtc::ArrayView<const uint8_t> packet) {
+    return ComputeDtlsPacketHash(packet);
+  }
+
+ private:
+  struct StashedPacket {
+    uint32_t hash;
+    std::unique_ptr<rtc::Buffer> buffer;
+  };
+
+  // This vector will only contain very few items,
+  // so it is appropriate to use a vector rather than
+  // e.g. a hash map.
+  uint32_t pos_ = 0;
+  std::vector<StashedPacket> packets_;
+};
+
 }  //  namespace webrtc
 
 // Re-export symbols from the webrtc namespace for backwards compatibility.
diff --git a/p2p/dtls/dtls_utils_unittest.cc b/p2p/dtls/dtls_utils_unittest.cc
index 5e7265f..2991f0a 100644
--- a/p2p/dtls/dtls_utils_unittest.cc
+++ b/p2p/dtls/dtls_utils_unittest.cc
@@ -14,6 +14,8 @@
 #include <optional>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
+#include "api/array_view.h"
 #include "test/gmock.h"
 #include "test/gtest.h"
 
@@ -198,4 +200,141 @@
   EXPECT_EQ(acks->size(), 0u);
 }
 
+std::vector<uint8_t> ToVector(rtc::ArrayView<const uint8_t> array) {
+  return std::vector<uint8_t>(array.begin(), array.end());
+}
+
+TEST(PacketStash, Add) {
+  PacketStash stash;
+  std::vector<uint8_t> packet = {
+      0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+      0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+      0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+      0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+  };
+
+  stash.Add(packet);
+  EXPECT_EQ(stash.size(), 1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet);
+
+  stash.Add(packet);
+  EXPECT_EQ(stash.size(), 2);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet);
+}
+
+TEST(PacketStash, AddIfUnique) {
+  PacketStash stash;
+  std::vector<uint8_t> packet1 = {
+      0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+      0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+      0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+      0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+  };
+
+  std::vector<uint8_t> packet2 = {
+      0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+      0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
+      0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  };
+
+  stash.AddIfUnique(packet1);
+  EXPECT_EQ(stash.size(), 1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+
+  stash.AddIfUnique(packet1);
+  EXPECT_EQ(stash.size(), 1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+
+  stash.AddIfUnique(packet2);
+  EXPECT_EQ(stash.size(), 2);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+
+  stash.AddIfUnique(packet2);
+  EXPECT_EQ(stash.size(), 2);
+}
+
+TEST(PacketStash, Prune) {
+  PacketStash stash;
+  std::vector<uint8_t> packet1 = {
+      0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+      0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+      0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+      0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+  };
+
+  std::vector<uint8_t> packet2 = {
+      0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+      0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
+      0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  };
+
+  stash.AddIfUnique(packet1);
+  stash.AddIfUnique(packet2);
+  EXPECT_EQ(stash.size(), 2);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+
+  absl::flat_hash_set<uint32_t> remove;
+  remove.insert(PacketStash::Hash(packet1));
+  stash.Prune(remove);
+
+  EXPECT_EQ(stash.size(), 1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+}
+
+TEST(PacketStash, PruneSize) {
+  PacketStash stash;
+  std::vector<uint8_t> packet1 = {
+      0x2f, 0x5b, 0x4c, 0x00, 0x23, 0x47, 0xab, 0xe7, 0x90, 0x96,
+      0xc0, 0xac, 0x2f, 0x25, 0x40, 0x35, 0x35, 0xa3, 0x81, 0x50,
+      0x0c, 0x38, 0x0a, 0xf6, 0xd4, 0xd5, 0x7d, 0xbe, 0x9a, 0xa3,
+      0xcb, 0xcb, 0x67, 0xb0, 0x77, 0x79, 0x8b, 0x48, 0x60, 0xf8,
+  };
+
+  std::vector<uint8_t> packet2 = {
+      0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+      0x00, 0x00, 0x00, 0x0c, 0x0e, 0x00, 0x00, 0x00, 0x00,
+      0xac, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  };
+
+  std::vector<uint8_t> packet3 = {0x3};
+  std::vector<uint8_t> packet4 = {0x4};
+  std::vector<uint8_t> packet5 = {0x5};
+  std::vector<uint8_t> packet6 = {0x6};
+
+  stash.AddIfUnique(packet1);
+  stash.AddIfUnique(packet2);
+  stash.AddIfUnique(packet3);
+  stash.AddIfUnique(packet4);
+  stash.AddIfUnique(packet5);
+  stash.AddIfUnique(packet6);
+  EXPECT_EQ(stash.size(), 6);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet3);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet4);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet5);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet6);
+
+  // Should be NOP.
+  stash.Prune(/* max_size= */ 6);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet2);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet3);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet4);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet5);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet6);
+
+  // Move "cursor" forward.
+  EXPECT_EQ(ToVector(stash.GetNext()), packet1);
+  stash.Prune(/* max_size= */ 4);
+  EXPECT_EQ(stash.size(), 4);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet3);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet4);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet5);
+  EXPECT_EQ(ToVector(stash.GetNext()), packet6);
+}
+
 }  // namespace webrtc