dcsctp: implement socket handover in the DcSctpSocket class and expose the functionality in the API

Bug: webrtc:13154
Change-Id: Idf4f4028c8e65943cb6b41fab0baef1b3584205d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/232126
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Commit-Queue: Sergey Sukhanov <sergeysu@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35029}
diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h
index 8907669..3ad4ab7 100644
--- a/net/dcsctp/public/dcsctp_handover_state.h
+++ b/net/dcsctp/public/dcsctp_handover_state.h
@@ -24,6 +24,25 @@
 // for serialization. Serialization is not provided by dcSCTP. If needed it has
 // to be implemented in the calling client.
 struct DcSctpSocketHandoverState {
+  enum class SocketState {
+    kClosed,
+    kConnected,
+  };
+  SocketState socket_state = SocketState::kClosed;
+
+  uint32_t my_verification_tag = 0;
+  uint32_t my_initial_tsn = 0;
+  uint32_t peer_verification_tag = 0;
+  uint32_t peer_initial_tsn = 0;
+  uint64_t tie_tag = 0;
+
+  struct Capabilities {
+    bool partial_reliability = false;
+    bool message_interleaving = false;
+    bool reconfig = false;
+  };
+  Capabilities capabilities;
+
   struct Transmission {
     uint32_t next_tsn = 0;
     uint32_t next_reset_req_sn = 0;
@@ -98,6 +117,7 @@
     value() |= status.value();
     return *this;
   }
+  std::string ToString() const;
 };
 
 }  // namespace dcsctp
diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h
index 248646e..583d037 100644
--- a/net/dcsctp/public/dcsctp_socket.h
+++ b/net/dcsctp/public/dcsctp_socket.h
@@ -17,6 +17,7 @@
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
 #include "api/array_view.h"
+#include "net/dcsctp/public/dcsctp_handover_state.h"
 #include "net/dcsctp/public/dcsctp_message.h"
 #include "net/dcsctp/public/dcsctp_options.h"
 #include "net/dcsctp/public/packet_observer.h"
@@ -355,6 +356,14 @@
   // `DcSctpSocketCallbacks::OnConnected` will be called on success.
   virtual void Connect() = 0;
 
+  // Puts this socket to the state in which the original socket was when its
+  // `DcSctpSocketHandoverState` was captured by `GetHandoverStateAndClose`.
+  // `RestoreFromState` is allowed only on the closed socket.
+  // `DcSctpSocketCallbacks::OnConnected` will be called if a connected socket
+  // state is restored.
+  // `DcSctpSocketCallbacks::OnError` will be called on error.
+  virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0;
+
   // Gracefully shutdowns the socket and sends all outstanding data. This is an
   // asynchronous operation and `DcSctpSocketCallbacks::OnClosed` will be called
   // on success.
@@ -417,6 +426,20 @@
 
   // Retrieves the latest metrics.
   virtual Metrics GetMetrics() const = 0;
+
+  // Returns empty bitmask if the socket is in the state in which a snapshot of
+  // the state can be made by `GetHandoverStateAndClose()`. Return value is
+  // invalidated by a call to any non-const method.
+  virtual HandoverReadinessStatus GetHandoverReadiness() const = 0;
+
+  // Collects a snapshot of the socket state that can be used to reconstruct
+  // this socket in another process. On success this socket object is closed
+  // synchronously and no callbacks will be made after the method has returned.
+  // The method fails if the socket is not in a state ready for handover.
+  // nullopt indicates the failure. `DcSctpSocketCallbacks::OnClosed` will be
+  // called on success.
+  virtual absl::optional<DcSctpSocketHandoverState>
+  GetHandoverStateAndClose() = 0;
 };
 }  // namespace dcsctp
 
diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h
index b382773..eb1e8cc 100644
--- a/net/dcsctp/public/mock_dcsctp_socket.h
+++ b/net/dcsctp/public/mock_dcsctp_socket.h
@@ -26,6 +26,11 @@
 
   MOCK_METHOD(void, Connect, (), (override));
 
+  MOCK_METHOD(void,
+              RestoreFromState,
+              (const DcSctpSocketHandoverState&),
+              (override));
+
   MOCK_METHOD(void, Shutdown, (), (override));
 
   MOCK_METHOD(void, Close, (), (override));
@@ -59,6 +64,15 @@
               (override));
 
   MOCK_METHOD(Metrics, GetMetrics, (), (const, override));
+
+  MOCK_METHOD(HandoverReadinessStatus,
+              GetHandoverReadiness,
+              (),
+              (const, override));
+  MOCK_METHOD(absl::optional<DcSctpSocketHandoverState>,
+              GetHandoverStateAndClose,
+              (),
+              (override));
 };
 
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 5211cca..afc30f6 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -139,8 +139,57 @@
                 static_cast<uint64_t>(tie_tag_lower));
 }
 
+constexpr absl::string_view HandoverUnreadinessReasonToString(
+    HandoverUnreadinessReason reason) {
+  switch (reason) {
+    case HandoverUnreadinessReason::kWrongConnectionState:
+      return "WRONG_CONNECTION_STATE";
+    case HandoverUnreadinessReason::kSendQueueNotEmpty:
+      return "SEND_QUEUE_NOT_EMPTY";
+    case HandoverUnreadinessReason::kDataTrackerTsnBlocksPending:
+      return "DATA_TRACKER_TSN_BLOCKS_PENDING";
+    case HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap:
+      return "REASSEMBLY_QUEUE_DELIVERED_TSN_GAP";
+    case HandoverUnreadinessReason::kStreamResetDeferred:
+      return "STREAM_RESET_DEFERRED";
+    case HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks:
+      return "ORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS";
+    case HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks:
+      return "UNORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS";
+    case HandoverUnreadinessReason::kRetransmissionQueueOutstandingData:
+      return "RETRANSMISSION_QUEUE_OUTSTANDING_DATA";
+    case HandoverUnreadinessReason::kRetransmissionQueueFastRecovery:
+      return "RETRANSMISSION_QUEUE_FAST_RECOVERY";
+    case HandoverUnreadinessReason::kRetransmissionQueueNotEmpty:
+      return "RETRANSMISSION_QUEUE_NOT_EMPTY";
+    case HandoverUnreadinessReason::kPendingStreamReset:
+      return "PENDING_STREAM_RESET";
+    case HandoverUnreadinessReason::kPendingStreamResetRequest:
+      return "PENDING_STREAM_RESET_REQUEST";
+  }
+}
 }  // namespace
 
+std::string HandoverReadinessStatus::ToString() const {
+  std::string result;
+  for (uint32_t bit = 1;
+       bit <= static_cast<uint32_t>(HandoverUnreadinessReason::kMax);
+       bit *= 2) {
+    auto flag = static_cast<HandoverUnreadinessReason>(bit);
+    if (Contains(flag)) {
+      if (!result.empty()) {
+        result.append(",");
+      }
+      absl::string_view s = HandoverUnreadinessReasonToString(flag);
+      result.append(s.data(), s.size());
+    }
+  }
+  if (result.empty()) {
+    result = "READY";
+  }
+  return result;
+}
+
 DcSctpSocket::DcSctpSocket(absl::string_view log_prefix,
                            DcSctpSocketCallbacks& callbacks,
                            std::unique_ptr<PacketObserver> packet_observer,
@@ -286,6 +335,42 @@
   callbacks_.TriggerDeferred();
 }
 
+void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
+  if (state_ != State::kClosed) {
+    callbacks_.OnError(ErrorKind::kUnsupportedOperation,
+                       "Only closed socket can be restored from state");
+  } else {
+    if (state.socket_state ==
+        DcSctpSocketHandoverState::SocketState::kConnected) {
+      VerificationTag my_verification_tag =
+          VerificationTag(state.my_verification_tag);
+      connect_params_.verification_tag = my_verification_tag;
+
+      Capabilities capabilities;
+      capabilities.partial_reliability = state.capabilities.partial_reliability;
+      capabilities.message_interleaving =
+          state.capabilities.message_interleaving;
+      capabilities.reconfig = state.capabilities.reconfig;
+
+      tcb_ = std::make_unique<TransmissionControlBlock>(
+          timer_manager_, log_prefix_, options_, capabilities, callbacks_,
+          send_queue_, my_verification_tag, TSN(state.my_initial_tsn),
+          VerificationTag(state.peer_verification_tag),
+          TSN(state.peer_initial_tsn), static_cast<size_t>(0),
+          TieTag(state.tie_tag), packet_sender_,
+          [this]() { return state_ == State::kEstablished; }, &state);
+      RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created peer TCB from state: "
+                           << tcb_->ToString();
+
+      SetState(State::kEstablished, "restored from handover state");
+      callbacks_.OnConnected();
+    }
+  }
+
+  RTC_DCHECK(IsConsistent());
+  callbacks_.TriggerDeferred();
+}
+
 void DcSctpSocket::Shutdown() {
   if (tcb_ != nullptr) {
     // https://tools.ietf.org/html/rfc4960#section-9.2
@@ -1579,4 +1664,38 @@
   t2_shutdown_->Start();
 }
 
+HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
+  HandoverReadinessStatus status;
+  if (state_ != State::kClosed && state_ != State::kEstablished) {
+    status.Add(HandoverUnreadinessReason::kWrongConnectionState);
+  }
+  if (!send_queue_.IsEmpty()) {
+    status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty);
+  }
+  if (tcb_) {
+    status.Add(tcb_->GetHandoverReadiness());
+  }
+  return status;
+}
+
+absl::optional<DcSctpSocketHandoverState>
+DcSctpSocket::GetHandoverStateAndClose() {
+  if (!GetHandoverReadiness().IsReady()) {
+    return absl::nullopt;
+  }
+
+  DcSctpSocketHandoverState state;
+
+  if (state_ == State::kClosed) {
+    state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed;
+  } else if (state_ == State::kEstablished) {
+    state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected;
+    tcb_->AddHandoverState(state);
+    InternalClose(ErrorKind::kNoError, "handover");
+    callbacks_.TriggerDeferred();
+  }
+
+  return std::move(state);
+}
+
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h
index 60359bd..508a8a6 100644
--- a/net/dcsctp/socket/dcsctp_socket.h
+++ b/net/dcsctp/socket/dcsctp_socket.h
@@ -85,6 +85,7 @@
   void ReceivePacket(rtc::ArrayView<const uint8_t> data) override;
   void HandleTimeout(TimeoutID timeout_id) override;
   void Connect() override;
+  void RestoreFromState(const DcSctpSocketHandoverState& state) override;
   void Shutdown() override;
   void Close() override;
   SendStatus Send(DcSctpMessage message,
@@ -98,6 +99,8 @@
   size_t buffered_amount_low_threshold(StreamID stream_id) const override;
   void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
   Metrics GetMetrics() const override;
+  HandoverReadinessStatus GetHandoverReadiness() const override;
+  absl::optional<DcSctpSocketHandoverState> GetHandoverStateAndClose() override;
 
   // Returns this socket's verification tag, or zero if not yet connected.
   VerificationTag verification_tag() const {
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index 5f99cc9..2fadde8 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -315,6 +315,24 @@
     EXPECT_EQ(sock_z_->state(), SocketState::kConnected);
   }
 
+  void HandoverSocketZ() {
+    ASSERT_EQ(sock_z_->GetHandoverReadiness(), HandoverReadinessStatus());
+    bool is_closed = sock_z_->state() == SocketState::kClosed;
+    if (!is_closed) {
+      EXPECT_CALL(cb_z_, OnClosed).Times(1);
+    }
+    absl::optional<DcSctpSocketHandoverState> handover_state =
+        sock_z_->GetHandoverStateAndClose();
+    EXPECT_TRUE(handover_state.has_value());
+    cb_z_.Reset();
+    sock_z_ = std::make_unique<DcSctpSocket>("Z", cb_z_, GetPacketObserver("Z"),
+                                             options_);
+    if (!is_closed) {
+      EXPECT_CALL(cb_z_, OnConnected).Times(1);
+    }
+    sock_z_->RestoreFromState(*handover_state);
+  }
+
   const DcSctpOptions options_;
   testing::NiceMock<MockDcSctpSocketCallbacks> cb_a_;
   testing::NiceMock<MockDcSctpSocketCallbacks> cb_z_;
@@ -322,6 +340,52 @@
   std::unique_ptr<DcSctpSocket> sock_z_;
 };
 
+// Test parameter that controls whether to perform handovers during the test. A
+// test can have multiple points where it conditionally hands over socket Z.
+// Either socket Z will be handed over at all those points or handed over never.
+enum class HandoverMode {
+  kNoHandover,
+  kPerformHandovers,
+};
+
+class DcSctpSocketParametrizedTest
+    : public DcSctpSocketTest,
+      public ::testing::WithParamInterface<HandoverMode> {
+ protected:
+  // Trigger handover for socket Z depending on the current test param.
+  void MaybeHandoverSocketZ() {
+    if (GetParam() == HandoverMode::kPerformHandovers) {
+      HandoverSocketZ();
+    }
+  }
+  // Trigger handover for socket Z depending on the current test param.
+  // Then checks message passing to verify the handed over socket is functional.
+  void MaybeHandoverSocketZAndSendMessage() {
+    if (GetParam() == HandoverMode::kPerformHandovers) {
+      HandoverSocketZ();
+    }
+
+    ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+    sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+    ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+    absl::optional<DcSctpMessage> msg = cb_z_.ConsumeReceivedMessage();
+    ASSERT_TRUE(msg.has_value());
+    EXPECT_EQ(msg->stream_id(), StreamID(1));
+  }
+};
+
+INSTANTIATE_TEST_SUITE_P(Handovers,
+                         DcSctpSocketParametrizedTest,
+                         testing::Values(HandoverMode::kNoHandover,
+                                         HandoverMode::kPerformHandovers),
+                         [](const auto& test_info) {
+                           return test_info.param ==
+                                          HandoverMode::kPerformHandovers
+                                      ? "WithHandovers"
+                                      : "NoHandover";
+                         });
+
 TEST_F(DcSctpSocketTest, EstablishConnection) {
   EXPECT_CALL(cb_a_, OnConnected).Times(1);
   EXPECT_CALL(cb_z_, OnConnected).Times(1);
@@ -566,8 +630,8 @@
 
 TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kLargeMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
   sock_a_->Connect();
 
   // Z reads INIT, produces INIT_ACK
@@ -623,11 +687,13 @@
               SizeIs(kLargeMessageSize));
 }
 
-TEST_F(DcSctpSocketTest, ShutdownConnection) {
+TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   RTC_LOG(LS_INFO) << "Shutting down";
 
+  EXPECT_CALL(cb_z_, OnClosed).Times(1);
   sock_a_->Shutdown();
   // Z reads SHUTDOWN, produces SHUTDOWN_ACK
   sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
@@ -638,6 +704,9 @@
 
   EXPECT_EQ(sock_a_->state(), SocketState::kClosed);
   EXPECT_EQ(sock_z_->state(), SocketState::kClosed);
+
+  MaybeHandoverSocketZ();
+  EXPECT_EQ(sock_z_->state(), SocketState::kClosed);
 }
 
 TEST_F(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
@@ -704,8 +773,9 @@
   EXPECT_EQ(msg->stream_id(), StreamID(1));
 }
 
-TEST_F(DcSctpSocketTest, TimeoutResendsPacket) {
+TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
   cb_a_.ConsumeSentPacket();
@@ -719,10 +789,13 @@
   absl::optional<DcSctpMessage> msg = cb_z_.ConsumeReceivedMessage();
   ASSERT_TRUE(msg.has_value());
   EXPECT_EQ(msg->stream_id(), StreamID(1));
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) {
+TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   std::vector<uint8_t> payload(kLargeMessageSize);
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
@@ -739,10 +812,13 @@
   ASSERT_TRUE(msg.has_value());
   EXPECT_EQ(msg->stream_id(), StreamID(1));
   EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) {
+TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   // Inject a HEARTBEAT chunk
   SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@@ -761,10 +837,13 @@
       HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data));
   ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info());
   EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4));
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) {
+TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty());
 
@@ -786,11 +865,16 @@
   // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
   sock_z_->ReceivePacket(hb_packet_raw);
   sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) {
+TEST_P(DcSctpSocketParametrizedTest,
+       CloseConnectionAfterTooManyLostHeartbeats) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
+  EXPECT_CALL(cb_z_, OnClosed).Times(1);
   EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty());
   // Force-close socket Z so that it doesn't interfere from now on.
   sock_z_->Close();
@@ -825,12 +909,16 @@
   // Should suffice as exceeding RTO
   AdvanceTime(DurationMs(1000));
   RunTimers();
+
+  MaybeHandoverSocketZ();
 }
 
-TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) {
+TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty());
+  EXPECT_CALL(cb_z_, OnClosed).Times(1);
   // Force-close socket Z so that it doesn't interfere from now on.
   sock_z_->Close();
 
@@ -882,8 +970,9 @@
   EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
 }
 
-TEST_F(DcSctpSocketTest, ResetStream) {
+TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
   sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
@@ -906,10 +995,13 @@
   // Receiving a response will trigger a callback. Streams are now reset.
   EXPECT_CALL(cb_a_, OnStreamsResetPerformed).Times(1);
   sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
+TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   std::vector<uint8_t> payload(options_.mtu - 100);
 
@@ -956,10 +1048,14 @@
 
   // Handle SACK
   sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) {
+TEST_P(DcSctpSocketParametrizedTest,
+       ResetStreamWillOnlyResetTheRequestedStreams) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   std::vector<uint8_t> payload(options_.mtu - 100);
 
@@ -1034,10 +1130,13 @@
 
   // Handle SACK
   sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, OnePeerReconnects) {
+TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1);
   // Let's be evil here - reconnect while a fragmented packet was about to be
@@ -1064,8 +1163,9 @@
   EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
 }
 
-TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) {
+TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   SendOptions send_options;
   send_options.max_retransmissions = 0;
@@ -1117,10 +1217,13 @@
 
   absl::optional<DcSctpMessage> msg3 = cb_z_.ConsumeReceivedMessage();
   EXPECT_FALSE(msg3.has_value());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, SendManyFragmentedMessagesWithLimitedRtx) {
+TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   SendOptions send_options;
   send_options.unordered = IsUnordered(true);
@@ -1210,8 +1313,9 @@
   std::string ToString() const override { return "FAKE"; }
 };
 
-TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) {
+TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   // Inject a FAKE chunk
   SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@@ -1228,10 +1332,13 @@
       UnrecognizedChunkTypeCause cause,
       error.error_causes().get<UnrecognizedChunkTypeCause>());
   EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04));
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) {
+TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   // Inject a ERROR chunk
   SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@@ -1243,6 +1350,8 @@
   EXPECT_CALL(cb_a_, OnError(ErrorKind::kPeerReported,
                              HasSubstr("Unrecognized Chunk Type")));
   sock_a_->ReceivePacket(b.Build());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
 TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
@@ -1359,8 +1468,9 @@
   EXPECT_EQ(sock_a_->options().max_message_size, 42u);
 }
 
-TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) {
+TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   // Mock that the time always goes forward.
   TimeMs now(0);
@@ -1394,10 +1504,14 @@
 
   // Validate that the sockets really make the time move forward.
   EXPECT_GE(*now, kIterations * 2);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) {
+TEST_P(DcSctpSocketParametrizedTest,
+       DiscardsMessagesWithLowLifetimeIfMustBuffer) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   SendOptions lifetime_0;
   lifetime_0.unordered = IsUnordered(true);
@@ -1449,53 +1563,65 @@
 
   // But none of the smaller messages.
   EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value());
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, HasReasonableBufferedAmountValues) {
+TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u);
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kSmallMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
   // Sending a small message will directly send it as a single packet, so
   // nothing is left in the queue.
   EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u);
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kLargeMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
 
   // Sending a message will directly start sending a few packets, so the
   // buffered amount is not the full message size.
   EXPECT_GT(sock_a_->buffered_amount(StreamID(1)), 0u);
   EXPECT_LT(sock_a_->buffered_amount(StreamID(1)), kLargeMessageSize);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
 TEST_F(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
   EXPECT_EQ(sock_a_->buffered_amount_low_threshold(StreamID(1)), 0u);
 }
 
-TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowWithDefaultValueZero) {
+TEST_P(DcSctpSocketParametrizedTest,
+       TriggersOnBufferedAmountLowWithDefaultValueZero) {
   EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1)));
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kSmallMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
   ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+  EXPECT_CALL(cb_a_, OnBufferedAmountLow).WillRepeatedly(testing::Return());
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
+TEST_P(DcSctpSocketParametrizedTest,
+       DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
   static constexpr size_t kMessageSize = 1000;
   static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;
 
   sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
-                                        kBufferedAmountLowThreshold);
+                                         kBufferedAmountLowThreshold);
   EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(0);
   sock_a_->Send(
@@ -1507,16 +1633,19 @@
       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
       kSendOptions);
   ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) {
+TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
   static constexpr size_t kMessageSize = 1000;
   static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;
 
   sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
-                                        kBufferedAmountLowThreshold);
+                                         kBufferedAmountLowThreshold);
   EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(3);
   EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(2))).Times(2);
@@ -1544,16 +1673,20 @@
       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
       kSendOptions);
   ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
+TEST_P(DcSctpSocketParametrizedTest,
+       TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
   static constexpr size_t kMessageSize = 1000;
   static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5;
 
   sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
-                                        kBufferedAmountLowThreshold);
+                                         kBufferedAmountLowThreshold);
   EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
 
@@ -1561,8 +1694,8 @@
   // messages will start to be fully buffered.
   while (sock_a_->buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
     sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                               std::vector<uint8_t>(kMessageSize)),
-                 kSendOptions);
+                                std::vector<uint8_t>(kMessageSize)),
+                  kSendOptions);
   }
   size_t initial_buffered = sock_a_->buffered_amount(StreamID(1));
   ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);
@@ -1571,36 +1704,46 @@
   // callback.
   EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(1);
   ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
+TEST_P(DcSctpSocketParametrizedTest,
+       DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0);
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kLargeMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
 
   ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
+TEST_P(DcSctpSocketParametrizedTest,
+       TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0);
 
   // Fill up the send queue completely.
   for (;;) {
     if (sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                                   std::vector<uint8_t>(kLargeMessageSize)),
-                     kSendOptions) == SendStatus::kErrorResourceExhaustion) {
+                                    std::vector<uint8_t>(kLargeMessageSize)),
+                      kSendOptions) == SendStatus::kErrorResourceExhaustion) {
       break;
     }
   }
 
   EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(1);
   ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
 TEST_F(DcSctpSocketTest, InitialMetricsAreZeroed) {
@@ -1650,8 +1793,8 @@
 
   // Send one more (large - fragmented), and receive the delayed SACK.
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(options_.mtu * 2 + 1)),
-               kSendOptions);
+                              std::vector<uint8_t>(options_.mtu * 2 + 1)),
+                kSendOptions);
   EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 3u);
 
   sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());  // DATA
@@ -1683,12 +1826,13 @@
   EXPECT_EQ(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd);
 }
 
-TEST_F(DcSctpSocketTest, UnackDataAlsoIncludesSendQueue) {
+TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kLargeMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
   size_t payload_bytes =
       options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
 
@@ -1706,14 +1850,17 @@
 
   EXPECT_LE(sock_a_->GetMetrics().unack_data_count,
             expected_sent_packets + expected_queued_packets + 2);
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, DoesntSendMoreThanMaxBurstPackets) {
+TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
-                             std::vector<uint8_t>(kLargeMessageSize)),
-               kSendOptions);
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
 
   for (int i = 0; i < kMaxBurstPackets; ++i) {
     std::vector<uint8_t> packet = cb_a_.ConsumeSentPacket();
@@ -1722,10 +1869,14 @@
   }
 
   EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty());
+
+  ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, SendsOnlyLargePackets) {
+TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   // A really large message, to ensure that the congestion window is often full.
   constexpr size_t kMessageSize = 100000;
@@ -1765,10 +1916,13 @@
     // The 4 is for padding/alignment.
     EXPECT_GE(size, options_.mtu - 4);
   }
+
+  MaybeHandoverSocketZAndSendMessage();
 }
 
-TEST_F(DcSctpSocketTest, DoesntBundleForwardTsnWithData) {
+TEST_P(DcSctpSocketParametrizedTest, DoesntBundleForwardTsnWithData) {
   ConnectSockets();
+  MaybeHandoverSocketZ();
 
   // Force an RTT measurement using heartbeats.
   AdvanceTime(options_.heartbeat_interval);
@@ -1848,5 +2002,49 @@
   EXPECT_EQ(packet4.descriptors()[0].type, ForwardTsnChunk::kType);
 }
 
+TEST_F(DcSctpSocketTest, SendMessagesAfterHandover) {
+  ConnectSockets();
+
+  // Send message before handover to move socket to a not initial state
+  sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+  sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
+  cb_z_.ConsumeReceivedMessage();
+
+  HandoverSocketZ();
+
+  absl::optional<DcSctpMessage> msg;
+
+  RTC_LOG(LS_INFO) << "Sending A #1";
+
+  sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
+  sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
+
+  msg = cb_z_.ConsumeReceivedMessage();
+  ASSERT_TRUE(msg.has_value());
+  EXPECT_EQ(msg->stream_id(), StreamID(1));
+  EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4));
+
+  RTC_LOG(LS_INFO) << "Sending A #2";
+
+  sock_a_->Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
+  sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
+
+  msg = cb_z_.ConsumeReceivedMessage();
+  ASSERT_TRUE(msg.has_value());
+  EXPECT_EQ(msg->stream_id(), StreamID(2));
+  EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6));
+
+  RTC_LOG(LS_INFO) << "Sending Z #1";
+
+  sock_z_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
+  sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());  // ack
+  sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());  // data
+
+  msg = cb_a_.ConsumeReceivedMessage();
+  ASSERT_TRUE(msg.has_value());
+  EXPECT_EQ(msg->stream_id(), StreamID(1));
+  EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3));
+}
+
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
index 894dd9a..a49a0b3 100644
--- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
+++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
@@ -150,6 +150,12 @@
     return timeout_manager_.GetNextExpiredTimeout();
   }
 
+  void Reset() {
+    sent_packets_.clear();
+    received_messages_.clear();
+    timeout_manager_.Reset();
+  }
+
  private:
   const std::string log_prefix_;
   TimeMs now_ = TimeMs(0);
diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc
index f0f1ab9..2e4e968 100644
--- a/net/dcsctp/socket/transmission_control_block.cc
+++ b/net/dcsctp/socket/transmission_control_block.cc
@@ -183,4 +183,30 @@
   return sb.Release();
 }
 
+HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const {
+  HandoverReadinessStatus status;
+  status.Add(data_tracker_.GetHandoverReadiness());
+  status.Add(stream_reset_handler_.GetHandoverReadiness());
+  status.Add(reassembly_queue_.GetHandoverReadiness());
+  status.Add(retransmission_queue_.GetHandoverReadiness());
+  return status;
+}
+
+void TransmissionControlBlock::AddHandoverState(
+    DcSctpSocketHandoverState& state) {
+  state.capabilities.partial_reliability = capabilities_.partial_reliability;
+  state.capabilities.message_interleaving = capabilities_.message_interleaving;
+  state.capabilities.reconfig = capabilities_.reconfig;
+
+  state.my_verification_tag = my_verification_tag().value();
+  state.peer_verification_tag = peer_verification_tag().value();
+  state.my_initial_tsn = my_initial_tsn().value();
+  state.peer_initial_tsn = peer_initial_tsn().value();
+  state.tie_tag = tie_tag().value();
+
+  data_tracker_.AddHandoverState(state);
+  stream_reset_handler_.AddHandoverState(state);
+  reassembly_queue_.AddHandoverState(state);
+  retransmission_queue_.AddHandoverState(state);
+}
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h
index c3766d1..6d9dfc5 100644
--- a/net/dcsctp/socket/transmission_control_block.h
+++ b/net/dcsctp/socket/transmission_control_block.h
@@ -44,20 +44,22 @@
 // closed or restarted, this object will be deleted and/or replaced.
 class TransmissionControlBlock : public Context {
  public:
-  TransmissionControlBlock(TimerManager& timer_manager,
-                           absl::string_view log_prefix,
-                           const DcSctpOptions& options,
-                           const Capabilities& capabilities,
-                           DcSctpSocketCallbacks& callbacks,
-                           SendQueue& send_queue,
-                           VerificationTag my_verification_tag,
-                           TSN my_initial_tsn,
-                           VerificationTag peer_verification_tag,
-                           TSN peer_initial_tsn,
-                           size_t a_rwnd,
-                           TieTag tie_tag,
-                           PacketSender& packet_sender,
-                           std::function<bool()> is_connection_established)
+  TransmissionControlBlock(
+      TimerManager& timer_manager,
+      absl::string_view log_prefix,
+      const DcSctpOptions& options,
+      const Capabilities& capabilities,
+      DcSctpSocketCallbacks& callbacks,
+      SendQueue& send_queue,
+      VerificationTag my_verification_tag,
+      TSN my_initial_tsn,
+      VerificationTag peer_verification_tag,
+      TSN peer_initial_tsn,
+      size_t a_rwnd,
+      TieTag tie_tag,
+      PacketSender& packet_sender,
+      std::function<bool()> is_connection_established,
+      const DcSctpSocketHandoverState* handover_state = nullptr)
       : log_prefix_(log_prefix),
         options_(options),
         timer_manager_(timer_manager),
@@ -86,10 +88,14 @@
         packet_sender_(packet_sender),
         rto_(options),
         tx_error_counter_(log_prefix, options),
-        data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn),
+        data_tracker_(log_prefix,
+                      delayed_ack_timer_.get(),
+                      peer_initial_tsn,
+                      handover_state),
         reassembly_queue_(log_prefix,
                           peer_initial_tsn,
-                          options.max_receiver_window_buffer_size),
+                          options.max_receiver_window_buffer_size,
+                          handover_state),
         retransmission_queue_(
             log_prefix,
             my_initial_tsn,
@@ -100,13 +106,15 @@
             *t3_rtx_,
             options,
             capabilities.partial_reliability,
-            capabilities.message_interleaving),
+            capabilities.message_interleaving,
+            handover_state),
         stream_reset_handler_(log_prefix,
                               this,
                               &timer_manager,
                               &data_tracker_,
                               &reassembly_queue_,
-                              &retransmission_queue_),
+                              &retransmission_queue_,
+                              handover_state),
         heartbeat_handler_(log_prefix, options, this, &timer_manager_) {}
 
   // Implementation of `Context`.
@@ -188,6 +196,10 @@
   // Returns a textual representation of this object, for logging.
   std::string ToString() const;
 
+  HandoverReadinessStatus GetHandoverReadiness() const;
+
+  void AddHandoverState(DcSctpSocketHandoverState& state);
+
  private:
   // Will be called when the retransmission timer (t3-rtx) expires.
   absl::optional<DurationMs> OnRtxTimerExpiry();
diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h
index f2bf103..e8f50d9 100644
--- a/net/dcsctp/timer/fake_timeout.h
+++ b/net/dcsctp/timer/fake_timeout.h
@@ -91,6 +91,8 @@
     return absl::nullopt;
   }
 
+  void Reset() { timers_.clear(); }
+
  private:
   const std::function<TimeMs()> get_time_;
   webrtc::flat_set<FakeTimeout*> timers_;