dcsctp: support socket handover in RetransmissionQueue

Bug: webrtc:13154
Change-Id: I9c73b1153b65409eb026e015804c22f3e874ff82
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/232022
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Commit-Queue: Sergey Sukhanov <sergeysu@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35009}
diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h
index 2cd77ed..886ee46 100644
--- a/net/dcsctp/public/dcsctp_handover_state.h
+++ b/net/dcsctp/public/dcsctp_handover_state.h
@@ -24,6 +24,16 @@
 // for serialization. Serialization is not provided by dcSCTP. If needed it has
 // to be implemented in the calling client.
 struct DcSctpSocketHandoverState {
+  struct Transmission {
+    uint32_t next_tsn = 0;
+    uint32_t next_reset_req_sn = 0;
+    uint32_t cwnd = 0;
+    uint32_t rwnd = 0;
+    uint32_t ssthresh = 0;
+    uint32_t partial_bytes_acked = 0;
+  };
+  Transmission tx;
+
   struct OrderedStream {
     uint32_t id = 0;
     uint32_t next_ssn = 0;
diff --git a/net/dcsctp/tx/BUILD.gn b/net/dcsctp/tx/BUILD.gn
index 50e424c..a02f5dc 100644
--- a/net/dcsctp/tx/BUILD.gn
+++ b/net/dcsctp/tx/BUILD.gn
@@ -79,6 +79,7 @@
     "../common:str_join",
     "../packet:chunk",
     "../packet:data",
+    "../public:socket",
     "../public:types",
     "../timer",
   ]
diff --git a/net/dcsctp/tx/retransmission_queue.cc b/net/dcsctp/tx/retransmission_queue.cc
index 0156a56..8468143 100644
--- a/net/dcsctp/tx/retransmission_queue.cc
+++ b/net/dcsctp/tx/retransmission_queue.cc
@@ -54,7 +54,7 @@
 
 RetransmissionQueue::RetransmissionQueue(
     absl::string_view log_prefix,
-    TSN initial_tsn,
+    TSN my_initial_tsn,
     size_t a_rwnd,
     SendQueue& send_queue,
     std::function<void(DurationMs rtt)> on_new_rtt,
@@ -62,7 +62,8 @@
     Timer& t3_rtx,
     const DcSctpOptions& options,
     bool supports_partial_reliability,
-    bool use_message_interleaving)
+    bool use_message_interleaving,
+    const DcSctpSocketHandoverState* handover_state)
     : options_(options),
       min_bytes_required_to_send_(options.mtu * kMinBytesRequiredToSendFactor),
       partial_reliability_(supports_partial_reliability),
@@ -74,15 +75,21 @@
       on_clear_retransmission_counter_(
           std::move(on_clear_retransmission_counter)),
       t3_rtx_(t3_rtx),
-      cwnd_(options_.cwnd_mtus_initial * options_.mtu),
-      rwnd_(a_rwnd),
+      cwnd_(handover_state ? handover_state->tx.cwnd
+                           : options_.cwnd_mtus_initial * options_.mtu),
+      rwnd_(handover_state ? handover_state->tx.rwnd : a_rwnd),
       // https://tools.ietf.org/html/rfc4960#section-7.2.1
       // "The initial value of ssthresh MAY be arbitrarily high (for
       // example, implementations MAY use the size of the receiver advertised
       // window).""
-      ssthresh_(rwnd_),
-      next_tsn_(tsn_unwrapper_.Unwrap(initial_tsn)),
-      last_cumulative_tsn_ack_(tsn_unwrapper_.Unwrap(TSN(*initial_tsn - 1))),
+      ssthresh_(handover_state ? handover_state->tx.ssthresh : rwnd_),
+      partial_bytes_acked_(
+          handover_state ? handover_state->tx.partial_bytes_acked : 0),
+      next_tsn_(tsn_unwrapper_.Unwrap(
+          handover_state ? TSN(handover_state->tx.next_tsn) : my_initial_tsn)),
+      last_cumulative_tsn_ack_(tsn_unwrapper_.Unwrap(
+          handover_state ? TSN(handover_state->tx.next_tsn - 1)
+                         : TSN(*my_initial_tsn - 1))),
       send_queue_(send_queue) {}
 
 bool RetransmissionQueue::IsConsistent() const {
@@ -919,4 +926,25 @@
   send_queue_.RollbackResetStreams();
 }
 
+HandoverReadinessStatus RetransmissionQueue::GetHandoverReadiness() const {
+  HandoverReadinessStatus status;
+  if (!outstanding_data_.empty()) {
+    status.Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData);
+  }
+  if (fast_recovery_exit_tsn_.has_value()) {
+    status.Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery);
+  }
+  if (!to_be_retransmitted_.empty()) {
+    status.Add(HandoverUnreadinessReason::kRetransmissionQueueNotEmpty);
+  }
+  return status;
+}
+
+void RetransmissionQueue::AddHandoverState(DcSctpSocketHandoverState& state) {
+  state.tx.next_tsn = next_tsn().value();
+  state.tx.rwnd = rwnd_;
+  state.tx.cwnd = cwnd_;
+  state.tx.ssthresh = ssthresh_;
+  state.tx.partial_bytes_acked = partial_bytes_acked_;
+}
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/retransmission_queue.h b/net/dcsctp/tx/retransmission_queue.h
index e4175c6..943df48 100644
--- a/net/dcsctp/tx/retransmission_queue.h
+++ b/net/dcsctp/tx/retransmission_queue.h
@@ -26,6 +26,7 @@
 #include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h"
 #include "net/dcsctp/packet/chunk/sack_chunk.h"
 #include "net/dcsctp/packet/data.h"
+#include "net/dcsctp/public/dcsctp_handover_state.h"
 #include "net/dcsctp/public/dcsctp_options.h"
 #include "net/dcsctp/timer/timer.h"
 #include "net/dcsctp/tx/retransmission_timeout.h"
@@ -61,23 +62,25 @@
     kAbandoned,
   };
 
-  // Creates a RetransmissionQueue which will send data using `initial_tsn` as
-  // the first TSN to use for sent fragments. It will poll data from
-  // `send_queue` and call `on_send_queue_empty` when it is empty. When
-  // SACKs are received, it will estimate the RTT, and call `on_new_rtt`. When
-  // an outstanding chunk has been ACKed, it will call
+  // Creates a RetransmissionQueue which will send data using `my_initial_tsn`
+  // (or a value from `DcSctpSocketHandoverState` if given) as the first TSN
+  // to use for sent fragments. It will poll data from `send_queue`. When SACKs
+  // are received, it will estimate the RTT, and call `on_new_rtt`. When an
+  // outstanding chunk has been ACKed, it will call
   // `on_clear_retransmission_counter` and will also use `t3_rtx`, which is the
   // SCTP retransmission timer to manage retransmissions.
-  RetransmissionQueue(absl::string_view log_prefix,
-                      TSN initial_tsn,
-                      size_t a_rwnd,
-                      SendQueue& send_queue,
-                      std::function<void(DurationMs rtt)> on_new_rtt,
-                      std::function<void()> on_clear_retransmission_counter,
-                      Timer& t3_rtx,
-                      const DcSctpOptions& options,
-                      bool supports_partial_reliability = true,
-                      bool use_message_interleaving = false);
+  RetransmissionQueue(
+      absl::string_view log_prefix,
+      TSN my_initial_tsn,
+      size_t a_rwnd,
+      SendQueue& send_queue,
+      std::function<void(DurationMs rtt)> on_new_rtt,
+      std::function<void()> on_clear_retransmission_counter,
+      Timer& t3_rtx,
+      const DcSctpOptions& options,
+      bool supports_partial_reliability = true,
+      bool use_message_interleaving = false,
+      const DcSctpSocketHandoverState* handover_state = nullptr);
 
   // Handles a received SACK. Returns true if the `sack` was processed and
   // false if it was discarded due to received out-of-order and not relevant.
@@ -139,6 +142,10 @@
   void CommitResetStreams();
   void RollbackResetStreams();
 
+  HandoverReadinessStatus GetHandoverReadiness() const;
+
+  void AddHandoverState(DcSctpSocketHandoverState& state);
+
  private:
   enum class CongestionAlgorithmPhase {
     kSlowStart,
@@ -279,9 +286,7 @@
   // action indicated when nacking an item (e.g. retransmitting or abandoning).
   // The return value indicate if an action was performed, meaning that packet
   // loss was detected and acted upon.
-  bool NackItem(UnwrappedTSN cumulative_tsn_ack,
-                TxData& item,
-                bool retransmit_now);
+  bool NackItem(UnwrappedTSN tsn, TxData& item, bool retransmit_now);
 
   // Will mark the chunks covered by the `gap_ack_blocks` from an incoming SACK
   // as "acked" and update `ack_info` by adding new TSNs to `added_tsns`.
@@ -375,7 +380,7 @@
   // Slow Start Threshold. See RFC4960.
   size_t ssthresh_;
   // Partial Bytes Acked. See RFC4960.
-  size_t partial_bytes_acked_ = 0;
+  size_t partial_bytes_acked_;
   // If set, fast recovery is enabled until this TSN has been cumulative
   // acked.
   absl::optional<UnwrappedTSN> fast_recovery_exit_tsn_ = absl::nullopt;
diff --git a/net/dcsctp/tx/retransmission_queue_test.cc b/net/dcsctp/tx/retransmission_queue_test.cc
index c64aeb1..5f524de 100644
--- a/net/dcsctp/tx/retransmission_queue_test.cc
+++ b/net/dcsctp/tx/retransmission_queue_test.cc
@@ -89,6 +89,17 @@
         supports_partial_reliability, use_message_interleaving);
   }
 
+  RetransmissionQueue CreateQueueByHandover(RetransmissionQueue& queue) {
+    EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus());
+    DcSctpSocketHandoverState state;
+    queue.AddHandoverState(state);
+    return RetransmissionQueue(
+        "", TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(),
+        on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_,
+        /*supports_partial_reliability=*/true,
+        /*use_message_interleaving=*/false, &state);
+  }
+
   DcSctpOptions options_;
   DataGenerator gen_;
   TimeMs now_ = TimeMs(0);
@@ -1275,5 +1286,125 @@
   EXPECT_TRUE(queue.can_send_data());
 }
 
+TEST_F(RetransmissionQueueTest, ReadyForHandoverWhenHasNoOutstandingData) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1));
+  EXPECT_EQ(
+      queue.GetHandoverReadiness(),
+      HandoverReadinessStatus(
+          HandoverUnreadinessReason::kRetransmissionQueueOutstandingData));
+
+  queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}));
+  EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus());
+}
+
+TEST_F(RetransmissionQueueTest, ReadyForHandoverWhenNothingToRetransmit) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(8));
+  EXPECT_EQ(
+      queue.GetHandoverReadiness(),
+      HandoverReadinessStatus(
+          HandoverUnreadinessReason::kRetransmissionQueueOutstandingData));
+
+  // Send more chunks, but leave some chunks unacked to force retransmission
+  // after three NACKs.
+
+  // Send 18
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1));
+
+  // Ack 12, 14-15, 17-18
+  queue.HandleSack(now_, SackChunk(TSN(12), kArwnd,
+                                   {SackChunk::GapAckBlock(2, 3),
+                                    SackChunk::GapAckBlock(5, 6)},
+                                   {}));
+
+  // Send 19
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1));
+
+  // Ack 12, 14-15, 17-19
+  queue.HandleSack(now_, SackChunk(TSN(12), kArwnd,
+                                   {SackChunk::GapAckBlock(2, 3),
+                                    SackChunk::GapAckBlock(5, 7)},
+                                   {}));
+
+  // Send 20
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1));
+
+  // Ack 12, 14-15, 17-20
+  // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be
+  // resent right now. The send queue will not even be queried.
+  queue.HandleSack(now_, SackChunk(TSN(12), kArwnd,
+                                   {SackChunk::GapAckBlock(2, 3),
+                                    SackChunk::GapAckBlock(5, 8)},
+                                   {}));
+  EXPECT_EQ(
+      queue.GetHandoverReadiness(),
+      HandoverReadinessStatus()
+          .Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData)
+          .Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery)
+          .Add(HandoverUnreadinessReason::kRetransmissionQueueNotEmpty));
+
+  // Send "fast retransmit" mode chunks
+  EXPECT_CALL(producer_, Produce).Times(0);
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(2));
+  EXPECT_EQ(
+      queue.GetHandoverReadiness(),
+      HandoverReadinessStatus()
+          .Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData)
+          .Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery));
+
+  // Ack 20 to confirm the retransmission
+  queue.HandleSack(now_, SackChunk(TSN(20), kArwnd, {}, {}));
+  EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus());
+}
+
+TEST_F(RetransmissionQueueTest, HandoverTest) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+  EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(2));
+  queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {}));
+
+  RetransmissionQueue handedover_queue = CreateQueueByHandover(queue);
+
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillOnce(CreateChunk())
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+  EXPECT_THAT(GetSentPacketTSNs(handedover_queue),
+              testing::ElementsAre(TSN(12), TSN(13), TSN(14)));
+
+  handedover_queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {}));
+  EXPECT_THAT(handedover_queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(13), State::kAcked),  //
+                          Pair(TSN(14), State::kInFlight)));
+}
+
 }  // namespace
 }  // namespace dcsctp