dcsctp: Enable message interleaving

This adds support to enable message interleaving in the stream scheduler
from the socket, proxied by the send queue.

It also adds socket unit tests to ensure that prioritization and
interleaving works. Also, send queue test has been added to validate the
integration of the stream scheduler. But the actual scheduling parts of
it will be tested in the stream scheduler unit tests.

Bug: webrtc:5696
Change-Id: Ic7d3d2dc28405c77a107f0148f0096882961eec7
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/262248
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37355}
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 822040e..56abb49 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -189,6 +189,7 @@
       send_queue_(
           log_prefix_,
           options_.max_send_buffer_size,
+          options_.mtu,
           options_.default_stream_priority,
           [this](StreamID stream_id) {
             callbacks_.OnBufferedAmountLow(stream_id);
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index 82fbb1b..e70378f 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -371,6 +371,18 @@
   return handover_socket;
 }
 
+std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) {
+  std::vector<uint32_t> ppids;
+  for (;;) {
+    absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+    if (!msg.has_value()) {
+      break;
+    }
+    ppids.push_back(*msg->ppid());
+  }
+  return ppids;
+}
+
 // Test parameter that controls whether to perform handovers during the test. A
 // test can have multiple points where it conditionally hands over socket Z.
 // Either socket Z will be handed over at all those points or handed over never.
@@ -2403,5 +2415,110 @@
   ExchangeMessages(a, z);
   a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
 }
+
+TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) {
+  DcSctpOptions options = {.enable_message_interleaving = true};
+  SocketUnderTest a("A", options);
+  SocketUnderTest z("A", options);
+
+  a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
+  a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
+  a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
+
+  // Enqueue messages before connecting the socket, to ensure they aren't send
+  // as soon as Send() is called.
+  a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(103),
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
+
+  ConnectSockets(a, z);
+  ExchangeMessages(a, z);
+
+  std::vector<uint32_t> received_ppids;
+  for (;;) {
+    absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+    if (!msg.has_value()) {
+      break;
+    }
+    received_ppids.push_back(*msg->ppid());
+  }
+
+  EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301));
+}
+
+TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) {
+  DcSctpOptions options = {.enable_message_interleaving = true};
+  SocketUnderTest a("A", options);
+  SocketUnderTest z("A", options);
+
+  a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
+  a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
+  a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
+
+  // Enqueue messages before connecting the socket, to ensure they aren't send
+  // as soon as Send() is called.
+  a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
+
+  ConnectSockets(a, z);
+  ExchangeMessages(a, z);
+
+  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301));
+}
+
+TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) {
+  DcSctpOptions options = {.enable_message_interleaving = true};
+  SocketUnderTest a("A", options);
+  SocketUnderTest z("Z", options);
+
+  ConnectSockets(a, z);
+
+  a.socket.SetStreamPriority(StreamID(2), StreamPriority(128));
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
+
+  // Due to a non-zero initial congestion window, the message will already start
+  // to send, but will not succeed to be sent completely before filling the
+  // congestion window or stopping due to reaching how many packets that can be
+  // sent at once (max burst). The important thing is that the entire message
+  // doesn't get sent in full.
+
+  // Now enqueue two messages; one small and one large higher priority message.
+  a.socket.SetStreamPriority(StreamID(1), StreamPriority(512));
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+                              std::vector<uint8_t>(kSmallMessageSize)),
+                kSendOptions);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+                              std::vector<uint8_t>(kLargeMessageSize)),
+                kSendOptions);
+
+  ExchangeMessages(a, z);
+
+  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201));
+}
+
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h
index 038ad36..f212788 100644
--- a/net/dcsctp/socket/transmission_control_block.h
+++ b/net/dcsctp/socket/transmission_control_block.h
@@ -129,6 +129,7 @@
     if (handover_state == nullptr) {
       send_queue.Reset();
     }
+    send_queue.EnableMessageInterleaving(capabilities.message_interleaving);
   }
 
   // Implementation of `Context`.
diff --git a/net/dcsctp/tx/mock_send_queue.h b/net/dcsctp/tx/mock_send_queue.h
index 82e96b7..0c8f5d1 100644
--- a/net/dcsctp/tx/mock_send_queue.h
+++ b/net/dcsctp/tx/mock_send_queue.h
@@ -52,6 +52,7 @@
               SetBufferedAmountLowThreshold,
               (StreamID stream_id, size_t bytes),
               (override));
+  MOCK_METHOD(void, EnableMessageInterleaving, (bool enabled), (override));
 };
 
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc
index bec6f08..174d19b 100644
--- a/net/dcsctp/tx/rr_send_queue.cc
+++ b/net/dcsctp/tx/rr_send_queue.cc
@@ -32,6 +32,7 @@
 
 RRSendQueue::RRSendQueue(absl::string_view log_prefix,
                          size_t buffer_size,
+                         size_t mtu,
                          StreamPriority default_priority,
                          std::function<void(StreamID)> on_buffered_amount_low,
                          size_t total_buffered_amount_low_threshold,
@@ -39,8 +40,7 @@
     : log_prefix_(std::string(log_prefix) + "fcfs: "),
       buffer_size_(buffer_size),
       default_priority_(default_priority),
-      // TODO(webrtc:5696): Provide correct MTU.
-      scheduler_(DcSctpOptions::kMaxSafeMTUSize),
+      scheduler_(mtu),
       on_buffered_amount_low_(std::move(on_buffered_amount_low)),
       total_buffered_amount_(std::move(on_total_buffered_amount_low)) {
   total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold);
diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h
index c2f1ee8..49c36fe 100644
--- a/net/dcsctp/tx/rr_send_queue.h
+++ b/net/dcsctp/tx/rr_send_queue.h
@@ -45,6 +45,7 @@
  public:
   RRSendQueue(absl::string_view log_prefix,
               size_t buffer_size,
+              size_t mtu,
               StreamPriority default_priority,
               std::function<void(StreamID)> on_buffered_amount_low,
               size_t total_buffered_amount_low_threshold,
@@ -81,6 +82,9 @@
   }
   size_t buffered_amount_low_threshold(StreamID stream_id) const override;
   void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
+  void EnableMessageInterleaving(bool enabled) override {
+    scheduler_.EnableMessageInterleaving(enabled);
+  }
 
   void SetStreamPriority(StreamID stream_id, StreamPriority priority);
   StreamPriority GetStreamPriority(StreamID stream_id) const;
diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc
index 3966c17..7471ccc 100644
--- a/net/dcsctp/tx/rr_send_queue_test.cc
+++ b/net/dcsctp/tx/rr_send_queue_test.cc
@@ -36,12 +36,14 @@
 constexpr size_t kBufferedAmountLowThreshold = 500;
 constexpr size_t kOneFragmentPacketSize = 100;
 constexpr size_t kTwoFragmentPacketSize = 101;
+constexpr size_t kMtu = 1100;
 
 class RRSendQueueTest : public testing::Test {
  protected:
   RRSendQueueTest()
       : buf_("log: ",
              kMaxQueueSize,
+             kMtu,
              kDefaultPriority,
              on_buffered_amount_low_.AsStdFunction(),
              kBufferedAmountLowThreshold,
@@ -787,7 +789,7 @@
   DcSctpSocketHandoverState state;
   buf_.AddHandoverState(state);
 
-  RRSendQueue q2("log: ", kMaxQueueSize, kDefaultPriority,
+  RRSendQueue q2("log: ", kMaxQueueSize, kMtu, kDefaultPriority,
                  on_buffered_amount_low_.AsStdFunction(),
                  kBufferedAmountLowThreshold,
                  on_total_buffered_amount_low_.AsStdFunction());
@@ -795,5 +797,25 @@
   EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42));
   EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42));
 }
+
+TEST_F(RRSendQueueTest, WillSendMessagesByPrio) {
+  buf_.EnableMessageInterleaving(true);
+  buf_.SetStreamPriority(StreamID(1), StreamPriority(10));
+  buf_.SetStreamPriority(StreamID(2), StreamPriority(20));
+  buf_.SetStreamPriority(StreamID(3), StreamPriority(30));
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(40)));
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(20)));
+  buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(10)));
+  std::vector<uint16_t> expected_streams = {3, 2, 2, 1, 1, 1, 1};
+
+  for (uint16_t stream_num : expected_streams) {
+    ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk,
+                                buf_.Produce(kNow, 10));
+    EXPECT_EQ(chunk.data.stream_id, StreamID(stream_num));
+  }
+  EXPECT_FALSE(buf_.Produce(kNow, 1).has_value());
+}
+
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/send_queue.h b/net/dcsctp/tx/send_queue.h
index b2e5a9d..a7e6635 100644
--- a/net/dcsctp/tx/send_queue.h
+++ b/net/dcsctp/tx/send_queue.h
@@ -126,6 +126,12 @@
   // Sets a limit for the `OnBufferedAmountLow` event.
   virtual void SetBufferedAmountLowThreshold(StreamID stream_id,
                                              size_t bytes) = 0;
+
+  // Configures the send queue to support interleaved message sending as
+  // described in RFC8260. Every send queue starts with this value set as
+  // disabled, but can later change it when the capabilities of the connection
+  // have been negotiated. This affects the behavior of the `Produce` method.
+  virtual void EnableMessageInterleaving(bool enabled) = 0;
 };
 }  // namespace dcsctp