dcsctp: Exit deferred stream reset on FORWARD-TSN

https://datatracker.ietf.org/doc/html/rfc6525#section-5.2.2:

E2:  If the Sender's Last Assigned TSN is greater than the cumulative
        acknowledgment point, then the endpoint MUST enter "deferred
        reset processing". ...  until the cumulative
        acknowledgment point reaches the Sender's Last Assigned TSN.

The cumulative acknowledgement point can not only be reached by
receiving DATA chunks, but also by receiving a FORWARD-TSN that
instructs the receiver to skip them. This was only done for DATA and not
for FORWARD-TSN, which is now corrected.

Additionally, an unnecessary implicit sending of SACK after having
received FORWARD-TSN was removed as this is done anyway every time a
packet has been received. This unifies the processing of DATA and
FORWARD-TSN more.

Bug: webrtc:14600
Change-Id: If797d3c46e741074fe05e322d0aebec765a87968
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/321400
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#40811}
diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h
index d072562..7d69875 100644
--- a/net/dcsctp/public/types.h
+++ b/net/dcsctp/public/types.h
@@ -41,6 +41,9 @@
   constexpr explicit DurationMs(const UnderlyingType& v)
       : webrtc::StrongAlias<class DurationMsTag, int32_t>(v) {}
 
+  static constexpr DurationMs InfiniteDuration() {
+    return DurationMs(std::numeric_limits<int32_t>::max());
+  }
   // Convenience methods for working with time.
   constexpr DurationMs& operator+=(DurationMs d) {
     value_ += d.value_;
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index a6845e3..6101007 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -1102,9 +1102,7 @@
 
   if (tcb_->data_tracker().Observe(tsn, immediate_ack)) {
     tcb_->reassembly_queue().Add(tsn, std::move(data));
-    tcb_->reassembly_queue().MaybeResetStreamsDeferred(
-        tcb_->data_tracker().last_cumulative_acked_tsn());
-    DeliverReassembledMessages();
+    MaybeResetStreamsDeferredAndDeliverMessages();
   }
 }
 
@@ -1455,12 +1453,15 @@
   callbacks_.OnConnected();
 }
 
-void DcSctpSocket::DeliverReassembledMessages() {
-  if (tcb_->reassembly_queue().HasMessages()) {
-    for (auto& message : tcb_->reassembly_queue().FlushMessages()) {
-      ++metrics_.rx_messages_count;
-      callbacks_.OnMessageReceived(std::move(message));
-    }
+void DcSctpSocket::MaybeResetStreamsDeferredAndDeliverMessages() {
+  // As new data has been received, see if paused streams can be resumed, which
+  // results in even more data added to the reassembly queue.
+  tcb_->reassembly_queue().MaybeResetStreamsDeferred(
+      tcb_->data_tracker().last_cumulative_acked_tsn());
+
+  for (auto& message : tcb_->reassembly_queue().FlushMessages()) {
+    ++metrics_.rx_messages_count;
+    callbacks_.OnMessageReceived(std::move(message));
   }
 }
 
@@ -1710,12 +1711,10 @@
   }
   tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn());
   tcb_->reassembly_queue().Handle(chunk);
+
   // A forward TSN - for ordered streams - may allow messages to be
   // delivered.
-  DeliverReassembledMessages();
-
-  // Processing a FORWARD_TSN might result in sending a SACK.
-  tcb_->MaybeSendSack();
+  MaybeResetStreamsDeferredAndDeliverMessages();
 }
 
 void DcSctpSocket::MaybeSendShutdownOrAck() {
diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h
index 157c515..4f7d178 100644
--- a/net/dcsctp/socket/dcsctp_socket.h
+++ b/net/dcsctp/socket/dcsctp_socket.h
@@ -179,8 +179,10 @@
   // Parses `payload`, which is a serialized packet that is just going to be
   // sent and prints all chunks.
   void DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload);
-  // Called whenever there may be reassembled messages, and delivers those.
-  void DeliverReassembledMessages();
+  // Called whenever data has been received, or the cumulative acknowledgment
+  // TSN has moved, that may result in performing deferred stream resetting and
+  // delivering messages.
+  void MaybeResetStreamsDeferredAndDeliverMessages();
   // Returns true if there is a TCB, and false otherwise (and reports an error).
   bool ValidateHasTCB();
 
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index 4d8fc8a..1320284 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -9,6 +9,7 @@
  */
 #include "net/dcsctp/socket/dcsctp_socket.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <deque>
 #include <memory>
@@ -30,6 +31,7 @@
 #include "net/dcsctp/packet/chunk/data_chunk.h"
 #include "net/dcsctp/packet/chunk/data_common.h"
 #include "net/dcsctp/packet/chunk/error_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
 #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
 #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
 #include "net/dcsctp/packet/chunk/idata_chunk.h"
@@ -275,6 +277,26 @@
   RunTimers(z);
 }
 
+// Exchanges messages between `a` and `z`, advancing time until there are no
+// more pending timers, or until `max_timeout` is reached.
+void ExchangeMessagesAndAdvanceTime(
+    SocketUnderTest& a,
+    SocketUnderTest& z,
+    DurationMs max_timeout = DurationMs(10000)) {
+  TimeMs time_started = a.cb.TimeMillis();
+  while (a.cb.TimeMillis() - time_started < max_timeout) {
+    ExchangeMessages(a, z);
+
+    DurationMs time_to_next_timeout =
+        std::min(a.cb.GetTimeToNextTimeout(), z.cb.GetTimeToNextTimeout());
+    if (time_to_next_timeout == DurationMs::InfiniteDuration()) {
+      // No more pending timer.
+      return;
+    }
+    AdvanceTime(a, z, time_to_next_timeout);
+  }
+}
+
 // Calls Connect() on `sock_a_` and make the connection established.
 void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) {
   EXPECT_CALL(a.cb, OnConnected).Times(1);
@@ -2977,5 +2999,60 @@
 
   MaybeHandoverSocketAndSendMessage(a, std::move(z));
 }
+
+TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) {
+  // This test ensures that receiving FORWARD-TSN and RECONFIG out of order is
+  // handled correctly.
+  SocketUnderTest a("A", {.heartbeat_interval = DurationMs(0)});
+  SocketUnderTest z("Z", {.heartbeat_interval = DurationMs(0)});
+
+  ConnectSockets(a, z);
+  std::vector<uint8_t> payload(kSmallMessageSize);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+                {
+                    .max_retransmissions = 0,
+                });
+
+  // Packet is lost.
+  EXPECT_THAT(a.cb.ConsumeSentPacket(),
+              HasChunks(ElementsAre(
+                  IsDataChunk(AllOf(Property(&DataChunk::ssn, SSN(0)),
+                                    Property(&DataChunk::ppid, PPID(51)))))));
+  AdvanceTime(a, z, a.options.rto_initial);
+
+  auto fwd_tsn_packet = a.cb.ConsumeSentPacket();
+  EXPECT_THAT(fwd_tsn_packet,
+              HasChunks(ElementsAre(IsChunkType(ForwardTsnChunk::kType))));
+  // Reset stream 1
+  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+  auto reconfig_packet = a.cb.ConsumeSentPacket();
+  EXPECT_THAT(reconfig_packet,
+              HasChunks(ElementsAre(IsChunkType(ReConfigChunk::kType))));
+
+  // These two packets are received in the wrong order.
+  z.socket.ReceivePacket(reconfig_packet);
+  z.socket.ReceivePacket(fwd_tsn_packet);
+  ExchangeMessagesAndAdvanceTime(a, z);
+
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), {});
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+  auto data_packet_2 = a.cb.ConsumeSentPacket();
+  auto data_packet_3 = a.cb.ConsumeSentPacket();
+  EXPECT_THAT(data_packet_2, HasChunks(ElementsAre(IsDataChunk(AllOf(
+                                 Property(&DataChunk::ssn, SSN(0)),
+                                 Property(&DataChunk::ppid, PPID(52)))))));
+  EXPECT_THAT(data_packet_3, HasChunks(ElementsAre(IsDataChunk(AllOf(
+                                 Property(&DataChunk::ssn, SSN(1)),
+                                 Property(&DataChunk::ppid, PPID(53)))))));
+
+  z.socket.ReceivePacket(data_packet_2);
+  z.socket.ReceivePacket(data_packet_3);
+  ASSERT_THAT(z.cb.ConsumeReceivedMessage(),
+              testing::Optional(Property(&DcSctpMessage::ppid, PPID(52))));
+  ASSERT_THAT(z.cb.ConsumeReceivedMessage(),
+              testing::Optional(Property(&DcSctpMessage::ppid, PPID(53))));
+}
+
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
index 8b2a772..150c1b9 100644
--- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
+++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
@@ -166,6 +166,10 @@
     return timeout_manager_.GetNextExpiredTimeout();
   }
 
+  DurationMs GetTimeToNextTimeout() const {
+    return timeout_manager_.GetTimeToNextTimeout();
+  }
+
  private:
   const std::string log_prefix_;
   TimeMs now_ = TimeMs(0);
diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h
index 74ffe5a..4621b2c 100644
--- a/net/dcsctp/timer/fake_timeout.h
+++ b/net/dcsctp/timer/fake_timeout.h
@@ -20,6 +20,7 @@
 #include "absl/types/optional.h"
 #include "api/task_queue/task_queue_base.h"
 #include "net/dcsctp/public/timeout.h"
+#include "net/dcsctp/public/types.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/containers/flat_set.h"
 
@@ -53,6 +54,7 @@
   }
 
   TimeoutID timeout_id() const { return timeout_id_; }
+  TimeMs expiry() const { return expiry_; }
 
  private:
   const std::function<TimeMs()> get_time_;
@@ -97,6 +99,19 @@
     return absl::nullopt;
   }
 
+  DurationMs GetTimeToNextTimeout() const {
+    TimeMs next_expiry = TimeMs::InfiniteFuture();
+    for (const FakeTimeout* timer : timers_) {
+      if (timer->expiry() < next_expiry) {
+        next_expiry = timer->expiry();
+      }
+    }
+    TimeMs now = get_time_();
+    return next_expiry != TimeMs::InfiniteFuture() && next_expiry >= now
+               ? next_expiry - now
+               : DurationMs::InfiniteDuration();
+  }
+
  private:
   const std::function<TimeMs()> get_time_;
   webrtc::flat_set<FakeTimeout*> timers_;