dcsctp: Add public API for setting priorities

This is a reland of commit 17a02a31d7d2897b75ad69fdac5d10e7475a5865.

This is the first part of supporting stream priorities, and adds the API
and very basic support for setting and retrieving the stream priority.

This commit doesn't in any way change the actual packet sending - the
specified priority values are stored, but not acted on.

This is all that is client visible, so clients can start using the API
as written, and they would never notice that things are missing.

Bug: webrtc:5696
Change-Id: I04d64a63cbaec67568496ad99667e14eba85f2e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/264424
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37081}
diff --git a/net/dcsctp/public/dcsctp_options.h b/net/dcsctp/public/dcsctp_options.h
index c394552..4511bed 100644
--- a/net/dcsctp/public/dcsctp_options.h
+++ b/net/dcsctp/public/dcsctp_options.h
@@ -71,6 +71,11 @@
   // `max_receiver_window_buffer_size`).
   size_t max_message_size = 256 * 1024;
 
+  // The default stream priority, if not overridden by
+  // `SctpSocket::SetStreamPriority`. The default value is selected to be
+  // compatible with https://www.w3.org/TR/webrtc-priority/, section 4.2-4.3.
+  StreamPriority default_stream_priority = StreamPriority(256);
+
   // Maximum received window buffer size. This should be a bit larger than the
   // largest sized message you want to be able to receive. This essentially
   // limits the memory usage on the receive side. Note that memory is allocated
diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h
index e15a5bf..0a65dae 100644
--- a/net/dcsctp/public/dcsctp_socket.h
+++ b/net/dcsctp/public/dcsctp_socket.h
@@ -430,6 +430,15 @@
   // Update the options max_message_size.
   virtual void SetMaxMessageSize(size_t max_message_size) = 0;
 
+  // Sets the priority of an outgoing stream. The initial value, when not set,
+  // is `DcSctpOptions::default_stream_priority`.
+  virtual void SetStreamPriority(StreamID stream_id,
+                                 StreamPriority priority) = 0;
+
+  // Returns the currently set priority for an outgoing stream. The initial
+  // value, when not set, is `DcSctpOptions::default_stream_priority`.
+  virtual StreamPriority GetStreamPriority(StreamID stream_id) const = 0;
+
   // Sends the message `message` using the provided send options.
   // Sending a message is an asynchrous operation, and the `OnError` callback
   // may be invoked to indicate any errors in sending the message.
diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h
index 6560a3f..0fd572b 100644
--- a/net/dcsctp/public/mock_dcsctp_socket.h
+++ b/net/dcsctp/public/mock_dcsctp_socket.h
@@ -41,6 +41,16 @@
 
   MOCK_METHOD(void, SetMaxMessageSize, (size_t max_message_size), (override));
 
+  MOCK_METHOD(void,
+              SetStreamPriority,
+              (StreamID stream_id, StreamPriority priority),
+              (override));
+
+  MOCK_METHOD(StreamPriority,
+              GetStreamPriority,
+              (StreamID stream_id),
+              (const, override));
+
   MOCK_METHOD(SendStatus,
               Send,
               (DcSctpMessage message, const SendOptions& send_options),
diff --git a/net/dcsctp/public/types.h b/net/dcsctp/public/types.h
index caa03bb..358e243 100644
--- a/net/dcsctp/public/types.h
+++ b/net/dcsctp/public/types.h
@@ -31,6 +31,10 @@
 // other messages on the same stream.
 using IsUnordered = webrtc::StrongAlias<class IsUnorderedTag, bool>;
 
+// Stream priority, where higher values indicate higher priority. The meaning of
+// this value and how it's used depends on the stream scheduler.
+using StreamPriority = webrtc::StrongAlias<class StreamPriorityTag, uint16_t>;
+
 // Duration, as milliseconds. Overflows after 24 days.
 class DurationMs : public webrtc::StrongAlias<class DurationMsTag, int32_t> {
  public:
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 9d6ae0e..5f8312d 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -189,6 +189,7 @@
       send_queue_(
           log_prefix_,
           options_.max_send_buffer_size,
+          options_.default_stream_priority,
           [this](StreamID stream_id) {
             callbacks_.OnBufferedAmountLow(stream_id);
           },
@@ -420,6 +421,16 @@
   RTC_DCHECK(IsConsistent());
 }
 
+void DcSctpSocket::SetStreamPriority(StreamID stream_id,
+                                     StreamPriority priority) {
+  RTC_DCHECK_RUN_ON(&thread_checker_);
+  send_queue_.SetStreamPriority(stream_id, priority);
+}
+StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const {
+  RTC_DCHECK_RUN_ON(&thread_checker_);
+  return send_queue_.GetStreamPriority(stream_id);
+}
+
 SendStatus DcSctpSocket::Send(DcSctpMessage message,
                               const SendOptions& send_options) {
   RTC_DCHECK_RUN_ON(&thread_checker_);
diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h
index 07e760a..d70d0fc 100644
--- a/net/dcsctp/socket/dcsctp_socket.h
+++ b/net/dcsctp/socket/dcsctp_socket.h
@@ -96,6 +96,8 @@
   SocketState state() const override;
   const DcSctpOptions& options() const override { return options_; }
   void SetMaxMessageSize(size_t max_message_size) override;
+  void SetStreamPriority(StreamID stream_id, StreamPriority priority) override;
+  StreamPriority GetStreamPriority(StreamID stream_id) const override;
   size_t buffered_amount(StreamID stream_id) const override;
   size_t buffered_amount_low_threshold(StreamID stream_id) const override;
   void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index 770fd84..cc5566f 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -2333,6 +2333,51 @@
   absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage();
   ASSERT_TRUE(msg6.has_value());
   EXPECT_EQ(msg6->stream_id(), StreamID(3));
-}  // namespace
+}
+
+TEST(DcSctpSocketTest, StreamsHaveInitialPriority) {
+  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+  SocketUnderTest a("A", options);
+
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)),
+            options.default_stream_priority);
+
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)),
+            options.default_stream_priority);
+}
+
+TEST(DcSctpSocketTest, CanChangeStreamPriority) {
+  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+  SocketUnderTest a("A", options);
+
+  a.socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+
+  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+  a.socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) {
+  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+  auto a = std::make_unique<SocketUnderTest>("A", options);
+  SocketUnderTest z("Z");
+
+  ConnectSockets(*a, z);
+
+  a->socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+  a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+  a->socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+
+  ExchangeMessages(*a, z);
+
+  a = MaybeHandoverSocket(std::move(a));
+
+  EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+  EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc
index d4ce59d..3a2166b 100644
--- a/net/dcsctp/tx/rr_send_queue.cc
+++ b/net/dcsctp/tx/rr_send_queue.cc
@@ -30,11 +30,13 @@
 
 RRSendQueue::RRSendQueue(absl::string_view log_prefix,
                          size_t buffer_size,
+                         StreamPriority default_priority,
                          std::function<void(StreamID)> on_buffered_amount_low,
                          size_t total_buffered_amount_low_threshold,
                          std::function<void()> on_total_buffered_amount_low)
     : log_prefix_(std::string(log_prefix) + "fcfs: "),
       buffer_size_(buffer_size),
+      default_priority_(default_priority),
       on_buffered_amount_low_(std::move(on_buffered_amount_low)),
       total_buffered_amount_(std::move(on_total_buffered_amount_low)) {
   total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold);
@@ -75,6 +77,7 @@
   state.next_ssn = next_ssn_.value();
   state.next_ordered_mid = next_ordered_mid_.value();
   state.next_unordered_mid = next_unordered_mid_.value();
+  state.priority = *priority_;
 }
 
 bool RRSendQueue::IsConsistent() const {
@@ -515,12 +518,28 @@
   return streams_
       .emplace(stream_id,
                OutgoingStream(
-                   stream_id,
+                   stream_id, default_priority_,
                    [this, stream_id]() { on_buffered_amount_low_(stream_id); },
                    total_buffered_amount_))
       .first->second;
 }
 
+void RRSendQueue::SetStreamPriority(StreamID stream_id,
+                                    StreamPriority priority) {
+  OutgoingStream& stream = GetOrCreateStreamInfo(stream_id);
+
+  stream.set_priority(priority);
+  RTC_DCHECK(IsConsistent());
+}
+
+StreamPriority RRSendQueue::GetStreamPriority(StreamID stream_id) const {
+  auto stream_it = streams_.find(stream_id);
+  if (stream_it == streams_.end()) {
+    return default_priority_;
+  }
+  return stream_it->second.priority();
+}
+
 HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const {
   HandoverReadinessStatus status;
   if (!IsEmpty()) {
@@ -542,12 +561,12 @@
   for (const DcSctpSocketHandoverState::OutgoingStream& state_stream :
        state.tx.streams) {
     StreamID stream_id(state_stream.id);
-    streams_.emplace(stream_id, OutgoingStream(
-                                    stream_id,
-                                    [this, stream_id]() {
-                                      on_buffered_amount_low_(stream_id);
-                                    },
-                                    total_buffered_amount_, &state_stream));
+    streams_.emplace(
+        stream_id,
+        OutgoingStream(
+            stream_id, StreamPriority(state_stream.priority),
+            [this, stream_id]() { on_buffered_amount_low_(stream_id); },
+            total_buffered_amount_, &state_stream));
   }
 }
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h
index 57a43cc..7ddb426 100644
--- a/net/dcsctp/tx/rr_send_queue.h
+++ b/net/dcsctp/tx/rr_send_queue.h
@@ -43,6 +43,7 @@
  public:
   RRSendQueue(absl::string_view log_prefix,
               size_t buffer_size,
+              StreamPriority default_priority,
               std::function<void(StreamID)> on_buffered_amount_low,
               size_t total_buffered_amount_low_threshold,
               std::function<void()> on_total_buffered_amount_low);
@@ -79,6 +80,8 @@
   size_t buffered_amount_low_threshold(StreamID stream_id) const override;
   void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
 
+  void SetStreamPriority(StreamID stream_id, StreamPriority priority);
+  StreamPriority GetStreamPriority(StreamID stream_id) const;
   HandoverReadinessStatus GetHandoverReadiness() const;
   void AddHandoverState(DcSctpSocketHandoverState& state);
   void RestoreFromState(const DcSctpSocketHandoverState& state);
@@ -112,10 +115,12 @@
    public:
     OutgoingStream(
         StreamID stream_id,
+        StreamPriority priority,
         std::function<void()> on_buffered_amount_low,
         ThresholdWatcher& total_buffered_amount,
         const DcSctpSocketHandoverState::OutgoingStream* state = nullptr)
         : stream_id_(stream_id),
+          priority_(priority),
           next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)),
           next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)),
           next_ssn_(SSN(state ? state->next_ssn : 0)),
@@ -166,6 +171,9 @@
     // expired non-partially sent message.
     bool HasDataToSend(TimeMs now);
 
+    void set_priority(StreamPriority priority) { priority_ = priority; }
+    StreamPriority priority() const { return priority_; }
+
     void AddHandoverState(
         DcSctpSocketHandoverState::OutgoingStream& state) const;
 
@@ -218,6 +226,7 @@
     bool IsConsistent() const;
 
     const StreamID stream_id_;
+    StreamPriority priority_;
     PauseState pause_state_ = PauseState::kNotPaused;
     // MIDs are different for unordered and ordered messages sent on a stream.
     MID next_unordered_mid_;
@@ -247,6 +256,7 @@
 
   const std::string log_prefix_;
   const size_t buffer_size_;
+  const StreamPriority default_priority_;
 
   // Called when the buffered amount is below what has been set using
   // `SetBufferedAmountLowThreshold`.
diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc
index fbbce58..3966c17 100644
--- a/net/dcsctp/tx/rr_send_queue_test.cc
+++ b/net/dcsctp/tx/rr_send_queue_test.cc
@@ -32,6 +32,7 @@
 constexpr StreamID kStreamID(1);
 constexpr PPID kPPID(53);
 constexpr size_t kMaxQueueSize = 1000;
+constexpr StreamPriority kDefaultPriority(10);
 constexpr size_t kBufferedAmountLowThreshold = 500;
 constexpr size_t kOneFragmentPacketSize = 100;
 constexpr size_t kTwoFragmentPacketSize = 101;
@@ -41,6 +42,7 @@
   RRSendQueueTest()
       : buf_("log: ",
              kMaxQueueSize,
+             kDefaultPriority,
              on_buffered_amount_low_.AsStdFunction(),
              kBufferedAmountLowThreshold,
              on_total_buffered_amount_low_.AsStdFunction()) {}
@@ -759,5 +761,39 @@
 
   EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value());
 }
+
+TEST_F(RRSendQueueTest, StreamsHaveInitialPriority) {
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), kDefaultPriority);
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40)));
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), kDefaultPriority);
+}
+
+TEST_F(RRSendQueueTest, CanChangeStreamPriority) {
+  buf_.SetStreamPriority(StreamID(1), StreamPriority(42));
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), StreamPriority(42));
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40)));
+  buf_.SetStreamPriority(StreamID(2), StreamPriority(42));
+  EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), StreamPriority(42));
+}
+
+TEST_F(RRSendQueueTest, WillHandoverPriority) {
+  buf_.SetStreamPriority(StreamID(1), StreamPriority(42));
+
+  buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40)));
+  buf_.SetStreamPriority(StreamID(2), StreamPriority(42));
+
+  DcSctpSocketHandoverState state;
+  buf_.AddHandoverState(state);
+
+  RRSendQueue q2("log: ", kMaxQueueSize, kDefaultPriority,
+                 on_buffered_amount_low_.AsStdFunction(),
+                 kBufferedAmountLowThreshold,
+                 on_total_buffered_amount_low_.AsStdFunction());
+  q2.RestoreFromState(state);
+  EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42));
+  EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42));
+}
 }  // namespace
 }  // namespace dcsctp