dcsctp: Add public API for setting priorities

This is the first part of supporting stream priorities, and adds the API
and very basic support for setting and retrieving the stream priority.

This commit doesn't in any way change the actual packet sending - the
specified priority values are stored, but not acted on.

This is all that is client visible, so clients can start using the API
as written, and they would never notice that things are missing.

Bug: webrtc:5696
Change-Id: I24fce8cbb6f3cba187df99d1d3f45e73621c93c6
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/261943
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37034}
diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h
index a58535d..36fc37b 100644
--- a/net/dcsctp/public/dcsctp_handover_state.h
+++ b/net/dcsctp/public/dcsctp_handover_state.h
@@ -48,6 +48,7 @@
     uint32_t next_ssn = 0;
     uint32_t next_unordered_mid = 0;
     uint32_t next_ordered_mid = 0;
+    uint16_t priority = 0;
   };
   struct Transmission {
     uint32_t next_tsn = 0;
diff --git a/net/dcsctp/public/dcsctp_options.h b/net/dcsctp/public/dcsctp_options.h
index c394552..4511bed 100644
--- a/net/dcsctp/public/dcsctp_options.h
+++ b/net/dcsctp/public/dcsctp_options.h
@@ -71,6 +71,11 @@
   // `max_receiver_window_buffer_size`).
   size_t max_message_size = 256 * 1024;
 
+  // The default stream priority, if not overridden by
+  // `SctpSocket::SetStreamPriority`. The default value is selected to be
+  // compatible with https://www.w3.org/TR/webrtc-priority/, section 4.2-4.3.
+  StreamPriority default_stream_priority = StreamPriority(256);
+
   // Maximum received window buffer size. This should be a bit larger than the
   // largest sized message you want to be able to receive. This essentially
   // limits the memory usage on the receive side. Note that memory is allocated
diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h
index e15a5bf..0a65dae 100644
--- a/net/dcsctp/public/dcsctp_socket.h
+++ b/net/dcsctp/public/dcsctp_socket.h
@@ -430,6 +430,15 @@
   // Update the options max_message_size.
   virtual void SetMaxMessageSize(size_t max_message_size) = 0;
 
+  // Sets the priority of an outgoing stream. The initial value, when not set,
+  // is `DcSctpOptions::default_stream_priority`.
+  virtual void SetStreamPriority(StreamID stream_id,
+                                 StreamPriority priority) = 0;
+
+  // Returns the currently set priority for an outgoing stream. The initial
+  // value, when not set, is `DcSctpOptions::default_stream_priority`.
+  virtual StreamPriority GetStreamPriority(StreamID stream_id) const = 0;
+
   // Sends the message `message` using the provided send options.
   // Sending a message is an asynchrous operation, and the `OnError` callback
   // may be invoked to indicate any errors in sending the message.
diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h
index 6560a3f..0fd572b 100644
--- a/net/dcsctp/public/mock_dcsctp_socket.h
+++ b/net/dcsctp/public/mock_dcsctp_socket.h
@@ -41,6 +41,16 @@
 
   MOCK_METHOD(void, SetMaxMessageSize, (size_t max_message_size), (override));
 
+  MOCK_METHOD(void,
+              SetStreamPriority,
+              (StreamID stream_id, StreamPriority priority),
+              (override));
+
+  MOCK_METHOD(StreamPriority,
+              GetStreamPriority,
+              (StreamID stream_id),
+              (const, override));
+
   MOCK_METHOD(SendStatus,
               Send,
               (DcSctpMessage message, const SendOptions& send_options),
diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h
index caa03bb..358e243 100644
--- a/net/dcsctp/public/types.h
+++ b/net/dcsctp/public/types.h
@@ -31,6 +31,10 @@
 // other messages on the same stream.
 using IsUnordered = webrtc::StrongAlias<class IsUnorderedTag, bool>;
 
+// Stream priority, where higher values indicate higher priority. The meaning of
+// this value and how it's used depends on the stream scheduler.
+using StreamPriority = webrtc::StrongAlias<class StreamPriorityTag, uint16_t>;
+
 // Duration, as milliseconds. Overflows after 24 days.
 class DurationMs : public webrtc::StrongAlias<class DurationMsTag, int32_t> {
  public:
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 9d6ae0e..e0a912c 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -189,6 +189,7 @@
       send_queue_(
           log_prefix_,
           options_.max_send_buffer_size,
+          options_.default_stream_priority,
           [this](StreamID stream_id) {
             callbacks_.OnBufferedAmountLow(stream_id);
           },
@@ -420,6 +421,14 @@
   RTC_DCHECK(IsConsistent());
 }
 
+void DcSctpSocket::SetStreamPriority(StreamID stream_id,
+                                     StreamPriority priority) {
+  send_queue_.SetStreamPriority(stream_id, priority);
+}
+StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const {
+  return send_queue_.GetStreamPriority(stream_id);
+}
+
 SendStatus DcSctpSocket::Send(DcSctpMessage message,
                               const SendOptions& send_options) {
   RTC_DCHECK_RUN_ON(&thread_checker_);
diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h
index 07e760a..d70d0fc 100644
--- a/net/dcsctp/socket/dcsctp_socket.h
+++ b/net/dcsctp/socket/dcsctp_socket.h
@@ -96,6 +96,8 @@
   SocketState state() const override;
   const DcSctpOptions& options() const override { return options_; }
   void SetMaxMessageSize(size_t max_message_size) override;
+  void SetStreamPriority(StreamID stream_id, StreamPriority priority) override;
+  StreamPriority GetStreamPriority(StreamID stream_id) const override;
   size_t buffered_amount(StreamID stream_id) const override;
   size_t buffered_amount_low_threshold(StreamID stream_id) const override;
   void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index 770fd84..cc5566f 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -2333,6 +2333,51 @@
   absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage();
   ASSERT_TRUE(msg6.has_value());
   EXPECT_EQ(msg6->stream_id(), StreamID(3));
-}  // namespace
+}
+
+TEST(DcSctpSocketTest, StreamsHaveInitialPriority) {
+  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+  SocketUnderTest a("A", options);
+
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)),
+            options.default_stream_priority);
+
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)),
+            options.default_stream_priority);
+}
+
+TEST(DcSctpSocketTest, CanChangeStreamPriority) {
+  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+  SocketUnderTest a("A", options);
+
+  a.socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+  a.socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) {
+  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+  auto a = std::make_unique<SocketUnderTest>("A", options);
+  SocketUnderTest z("Z");
+
+  ConnectSockets(*a, z);
+
+  a->socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+  a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+  a->socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+
+  ExchangeMessages(*a, z);
+
+  a = MaybeHandoverSocket(std::move(a));
+
+  EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+  EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc
index d4ce59d..3a2166b 100644
--- a/net/dcsctp/tx/rr_send_queue.cc
+++ b/net/dcsctp/tx/rr_send_queue.cc
@@ -30,11 +30,13 @@
 
 RRSendQueue::RRSendQueue(absl::string_view log_prefix,
                          size_t buffer_size,
+                         StreamPriority default_priority,
                          std::function<void(StreamID)> on_buffered_amount_low,
                          size_t total_buffered_amount_low_threshold,
                          std::function<void()> on_total_buffered_amount_low)
     : log_prefix_(std::string(log_prefix) + "fcfs: "),
       buffer_size_(buffer_size),
+      default_priority_(default_priority),
       on_buffered_amount_low_(std::move(on_buffered_amount_low)),
       total_buffered_amount_(std::move(on_total_buffered_amount_low)) {
   total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold);
@@ -75,6 +77,7 @@
   state.next_ssn = next_ssn_.value();
   state.next_ordered_mid = next_ordered_mid_.value();
   state.next_unordered_mid = next_unordered_mid_.value();
+  state.priority = *priority_;
 }
 
 bool RRSendQueue::IsConsistent() const {
@@ -515,12 +518,28 @@
   return streams_
       .emplace(stream_id,
                OutgoingStream(
-                   stream_id,
+                   stream_id, default_priority_,
                    [this, stream_id]() { on_buffered_amount_low_(stream_id); },
                    total_buffered_amount_))
       .first->second;
 }
 
+void RRSendQueue::SetStreamPriority(StreamID stream_id,
+                                    StreamPriority priority) {
+  OutgoingStream& stream = GetOrCreateStreamInfo(stream_id);
+
+  stream.set_priority(priority);
+  RTC_DCHECK(IsConsistent());
+}
+
+StreamPriority RRSendQueue::GetStreamPriority(StreamID stream_id) const {
+  auto stream_it = streams_.find(stream_id);
+  if (stream_it == streams_.end()) {
+    return default_priority_;
+  }
+  return stream_it->second.priority();
+}
+
 HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const {
   HandoverReadinessStatus status;
   if (!IsEmpty()) {
@@ -542,12 +561,12 @@
   for (const DcSctpSocketHandoverState::OutgoingStream& state_stream :
        state.tx.streams) {
     StreamID stream_id(state_stream.id);
-    streams_.emplace(stream_id, OutgoingStream(
-                                    stream_id,
-                                    [this, stream_id]() {
-                                      on_buffered_amount_low_(stream_id);
-                                    },
-                                    total_buffered_amount_, &state_stream));
+    streams_.emplace(
+        stream_id,
+        OutgoingStream(
+            stream_id, StreamPriority(state_stream.priority),
+            [this, stream_id]() { on_buffered_amount_low_(stream_id); },
+            total_buffered_amount_, &state_stream));
   }
 }
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h
index 57a43cc..7ddb426 100644
--- a/net/dcsctp/tx/rr_send_queue.h
+++ b/net/dcsctp/tx/rr_send_queue.h
@@ -43,6 +43,7 @@
  public:
   RRSendQueue(absl::string_view log_prefix,
               size_t buffer_size,
+              StreamPriority default_priority,
               std::function<void(StreamID)> on_buffered_amount_low,
               size_t total_buffered_amount_low_threshold,
               std::function<void()> on_total_buffered_amount_low);
@@ -79,6 +80,8 @@
   size_t buffered_amount_low_threshold(StreamID stream_id) const override;
   void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
 
+  void SetStreamPriority(StreamID stream_id, StreamPriority priority);
+  StreamPriority GetStreamPriority(StreamID stream_id) const;
   HandoverReadinessStatus GetHandoverReadiness() const;
   void AddHandoverState(DcSctpSocketHandoverState& state);
   void RestoreFromState(const DcSctpSocketHandoverState& state);
@@ -112,10 +115,12 @@
    public:
     OutgoingStream(
         StreamID stream_id,
+        StreamPriority priority,
         std::function<void()> on_buffered_amount_low,
         ThresholdWatcher& total_buffered_amount,
         const DcSctpSocketHandoverState::OutgoingStream* state = nullptr)
         : stream_id_(stream_id),
+          priority_(priority),
           next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)),
           next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)),
           next_ssn_(SSN(state ? state->next_ssn : 0)),
@@ -166,6 +171,9 @@
     // expired non-partially sent message.
     bool HasDataToSend(TimeMs now);
 
+    void set_priority(StreamPriority priority) { priority_ = priority; }
+    StreamPriority priority() const { return priority_; }
+
     void AddHandoverState(
         DcSctpSocketHandoverState::OutgoingStream& state) const;
 
@@ -218,6 +226,7 @@
     bool IsConsistent() const;
 
     const StreamID stream_id_;
+    StreamPriority priority_;
     PauseState pause_state_ = PauseState::kNotPaused;
     // MIDs are different for unordered and ordered messages sent on a stream.
     MID next_unordered_mid_;
@@ -247,6 +256,7 @@
 
   const std::string log_prefix_;
   const size_t buffer_size_;
+  const StreamPriority default_priority_;
 
   // Called when the buffered amount is below what has been set using
   // `SetBufferedAmountLowThreshold`.
diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc
index fbbce58..3966c17 100644
--- a/net/dcsctp/tx/rr_send_queue_test.cc
+++ b/net/dcsctp/tx/rr_send_queue_test.cc
@@ -32,6 +32,7 @@
 constexpr StreamID kStreamID(1);
 constexpr PPID kPPID(53);
 constexpr size_t kMaxQueueSize = 1000;
+constexpr StreamPriority kDefaultPriority(10);
 constexpr size_t kBufferedAmountLowThreshold = 500;
 constexpr size_t kOneFragmentPacketSize = 100;
 constexpr size_t kTwoFragmentPacketSize = 101;
@@ -41,6 +42,7 @@
   RRSendQueueTest()
       : buf_("log: ",
              kMaxQueueSize,
+             kDefaultPriority,
              on_buffered_amount_low_.AsStdFunction(),
              kBufferedAmountLowThreshold,
              on_total_buffered_amount_low_.AsStdFunction()) {}
@@ -759,5 +761,39 @@
 
   EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value());
 }
+
+TEST_F(RRSendQueueTest, StreamsHaveInitialPriority) {
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), kDefaultPriority);
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40)));
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), kDefaultPriority);
+}
+
+TEST_F(RRSendQueueTest, CanChangeStreamPriority) {
+  buf_.SetStreamPriority(StreamID(1), StreamPriority(42));
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), StreamPriority(42));
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40)));
+  buf_.SetStreamPriority(StreamID(2), StreamPriority(42));
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), StreamPriority(42));
+}
+
+TEST_F(RRSendQueueTest, WillHandoverPriority) {
+  buf_.SetStreamPriority(StreamID(1), StreamPriority(42));
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40)));
+  buf_.SetStreamPriority(StreamID(2), StreamPriority(42));
+
+  DcSctpSocketHandoverState state;
+  buf_.AddHandoverState(state);
+
+  RRSendQueue q2("log: ", kMaxQueueSize, kDefaultPriority,
+                 on_buffered_amount_low_.AsStdFunction(),
+                 kBufferedAmountLowThreshold,
+                 on_total_buffered_amount_low_.AsStdFunction());
+  q2.RestoreFromState(state);
+  EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42));
+  EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42));
+}
 }  // namespace
 }  // namespace dcsctp