dcsctp: Refactor send queue (2/2)

Let the send queue generate callbacks directly.

No functional change - pure refactoring.

Bug: webrtc:5696
Change-Id: Ic1e8ccba9612c5955e599c5d8257a5fa6980f666
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/264143
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37401}
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 9287b86..aa48649 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -186,16 +186,12 @@
                        options.max_retransmissions))),
       packet_sender_(callbacks_,
                      absl::bind_front(&DcSctpSocket::OnSentPacket, this)),
-      send_queue_(
-          log_prefix_,
-          options_.max_send_buffer_size,
-          options_.mtu,
-          options_.default_stream_priority,
-          [this](StreamID stream_id) {
-            callbacks_.OnBufferedAmountLow(stream_id);
-          },
-          options_.total_buffered_amount_low_threshold,
-          [this]() { callbacks_.OnTotalBufferedAmountLow(); }) {}
+      send_queue_(log_prefix_,
+                  &callbacks_,
+                  options_.max_send_buffer_size,
+                  options_.mtu,
+                  options_.default_stream_priority,
+                  options_.total_buffered_amount_low_threshold) {}
 
 std::string DcSctpSocket::log_prefix() const {
   return log_prefix_ + "[" + std::string(ToString(state_)) + "] ";
diff --git a/net/dcsctp/tx/BUILD.gn b/net/dcsctp/tx/BUILD.gn
index e691b76..e8fbce9 100644
--- a/net/dcsctp/tx/BUILD.gn
+++ b/net/dcsctp/tx/BUILD.gn
@@ -189,6 +189,7 @@
       "../packet:sctp_packet",
       "../public:socket",
       "../public:types",
+      "../socket:mock_callbacks",
       "../testing:data_generator",
       "../testing:testing_macros",
       "../timer",
diff --git a/net/dcsctp/tx/rr_send_queue.cc b/net/dcsctp/tx/rr_send_queue.cc
index bee9d51..9e45486 100644
--- a/net/dcsctp/tx/rr_send_queue.cc
+++ b/net/dcsctp/tx/rr_send_queue.cc
@@ -31,18 +31,18 @@
 namespace dcsctp {
 
 RRSendQueue::RRSendQueue(absl::string_view log_prefix,
+                         DcSctpSocketCallbacks* callbacks,
                          size_t buffer_size,
                          size_t mtu,
                          StreamPriority default_priority,
-                         std::function<void(StreamID)> on_buffered_amount_low,
-                         size_t total_buffered_amount_low_threshold,
-                         std::function<void()> on_total_buffered_amount_low)
+                         size_t total_buffered_amount_low_threshold)
     : log_prefix_(std::string(log_prefix) + "fcfs: "),
+      callbacks_(*callbacks),
       buffer_size_(buffer_size),
       default_priority_(default_priority),
       scheduler_(mtu),
-      on_buffered_amount_low_(std::move(on_buffered_amount_low)),
-      total_buffered_amount_(std::move(on_total_buffered_amount_low)) {
+      total_buffered_amount_(
+          [this]() { callbacks_.OnTotalBufferedAmountLow(); }) {
   total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold);
 }
 
@@ -472,10 +472,12 @@
   }
 
   return streams_
-      .emplace(std::piecewise_construct, std::forward_as_tuple(stream_id),
-               std::forward_as_tuple(
-                   this, &scheduler_, stream_id, default_priority_,
-                   [this, stream_id]() { on_buffered_amount_low_(stream_id); }))
+      .emplace(
+          std::piecewise_construct, std::forward_as_tuple(stream_id),
+          std::forward_as_tuple(this, &scheduler_, stream_id, default_priority_,
+                                [this, stream_id]() {
+                                  callbacks_.OnBufferedAmountLow(stream_id);
+                                }))
       .first->second;
 }
 
@@ -520,7 +522,7 @@
         std::piecewise_construct, std::forward_as_tuple(stream_id),
         std::forward_as_tuple(
             this, &scheduler_, stream_id, StreamPriority(state_stream.priority),
-            [this, stream_id]() { on_buffered_amount_low_(stream_id); },
+            [this, stream_id]() { callbacks_.OnBufferedAmountLow(stream_id); },
             &state_stream));
   }
 }
diff --git a/net/dcsctp/tx/rr_send_queue.h b/net/dcsctp/tx/rr_send_queue.h
index 8e6085f..9152f27 100644
--- a/net/dcsctp/tx/rr_send_queue.h
+++ b/net/dcsctp/tx/rr_send_queue.h
@@ -41,15 +41,18 @@
 //
 // As messages can be (requested to be) sent before the connection is properly
 // established, this send queue is always present - even for closed connections.
+//
+// The send queue may trigger callbacks:
+//  * `OnBufferedAmountLow`, `OnTotalBufferedAmountLow`
+//    These will be triggered as defined in their documentation.
 class RRSendQueue : public SendQueue {
  public:
   RRSendQueue(absl::string_view log_prefix,
+              DcSctpSocketCallbacks* callbacks,
               size_t buffer_size,
               size_t mtu,
               StreamPriority default_priority,
-              std::function<void(StreamID)> on_buffered_amount_low,
-              size_t total_buffered_amount_low_threshold,
-              std::function<void()> on_total_buffered_amount_low);
+              size_t total_buffered_amount_low_threshold);
 
   // Indicates if the buffer is full. Note that it's up to the caller to ensure
   // that the buffer is not full prior to adding new items to it.
@@ -255,18 +258,11 @@
       size_t max_size);
 
   const std::string log_prefix_;
+  DcSctpSocketCallbacks& callbacks_;
   const size_t buffer_size_;
   const StreamPriority default_priority_;
   StreamScheduler scheduler_;
 
-  // Called when the buffered amount is below what has been set using
-  // `SetBufferedAmountLowThreshold`.
-  const std::function<void(StreamID)> on_buffered_amount_low_;
-
-  // Called when the total buffered amount is below what has been set using
-  // `SetTotalBufferedAmountLowThreshold`.
-  const std::function<void()> on_total_buffered_amount_low_;
-
   // The total amount of buffer data, for all streams.
   ThresholdWatcher total_buffered_amount_;
 
diff --git a/net/dcsctp/tx/rr_send_queue_test.cc b/net/dcsctp/tx/rr_send_queue_test.cc
index 7471ccc..78b5ecd 100644
--- a/net/dcsctp/tx/rr_send_queue_test.cc
+++ b/net/dcsctp/tx/rr_send_queue_test.cc
@@ -18,6 +18,7 @@
 #include "net/dcsctp/public/dcsctp_options.h"
 #include "net/dcsctp/public/dcsctp_socket.h"
 #include "net/dcsctp/public/types.h"
+#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
 #include "net/dcsctp/testing/testing_macros.h"
 #include "net/dcsctp/tx/send_queue.h"
 #include "rtc_base/gunit.h"
@@ -42,18 +43,14 @@
  protected:
   RRSendQueueTest()
       : buf_("log: ",
+             &callbacks_,
              kMaxQueueSize,
              kMtu,
              kDefaultPriority,
-             on_buffered_amount_low_.AsStdFunction(),
-             kBufferedAmountLowThreshold,
-             on_total_buffered_amount_low_.AsStdFunction()) {}
+             kBufferedAmountLowThreshold) {}
 
+  testing::NiceMock<MockDcSctpSocketCallbacks> callbacks_;
   const DcSctpOptions options_;
-  testing::NiceMock<testing::MockFunction<void(StreamID)>>
-      on_buffered_amount_low_;
-  testing::NiceMock<testing::MockFunction<void()>>
-      on_total_buffered_amount_low_;
   RRSendQueue buf_;
 };
 
@@ -546,7 +543,7 @@
 }
 
 TEST_F(RRSendQueueTest, DoesntTriggerOnBufferedAmountLowWhenSetToZero) {
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 0u);
 }
 
@@ -554,7 +551,7 @@
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1)));
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u);
 
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
 
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1,
                               buf_.Produce(kNow, kOneFragmentPacketSize));
@@ -566,20 +563,20 @@
 TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowIfAddingMore) {
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1)));
 
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
 
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1,
                               buf_.Produce(kNow, kOneFragmentPacketSize));
   EXPECT_EQ(chunk1.data.stream_id, StreamID(1));
   EXPECT_THAT(chunk1.data.payload, SizeIs(1));
 
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
 
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1)));
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u);
 
   // Should now trigger again, as buffer_amount went above the threshold.
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2,
                               buf_.Produce(kNow, kOneFragmentPacketSize));
   EXPECT_EQ(chunk2.data.stream_id, StreamID(1));
@@ -592,7 +589,7 @@
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(10)));
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 10u);
 
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1,
                               buf_.Produce(kNow, kOneFragmentPacketSize));
   EXPECT_EQ(chunk1.data.stream_id, StreamID(1));
@@ -610,7 +607,7 @@
 }
 
 TEST_F(RRSendQueueTest, WillTriggerOnBufferedAmountLowSetAboveZero) {
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
 
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 700);
 
@@ -629,7 +626,7 @@
   EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize));
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u);
 
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
 
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3,
                               buf_.Produce(kNow, kOneFragmentPacketSize));
@@ -638,7 +635,7 @@
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 700u);
 
   // Doesn't trigger when reducing even further.
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
 
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4,
                               buf_.Produce(kNow, kOneFragmentPacketSize));
@@ -648,25 +645,25 @@
 }
 
 TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowSetAboveZero) {
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
 
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 700);
 
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1000)));
 
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1,
                               buf_.Produce(kNow, 400));
   EXPECT_EQ(chunk1.data.stream_id, StreamID(1));
   EXPECT_THAT(chunk1.data.payload, SizeIs(400));
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u);
 
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(200)));
   EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u);
 
   // Will trigger again, as it went above the limit.
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
   ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2,
                               buf_.Produce(kNow, 200));
   EXPECT_EQ(chunk2.data.stream_id, StreamID(1));
@@ -675,7 +672,7 @@
 }
 
 TEST_F(RRSendQueueTest, TriggersOnBufferedAmountLowOnThresholdChanged) {
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
 
   buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(100)));
 
@@ -684,25 +681,25 @@
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 99);
 
   // When the threshold reaches buffered_amount, it will trigger.
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 100);
 
   // But not when it's set low again.
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 50);
 
   // But it will trigger when it overshoots.
-  EXPECT_CALL(on_buffered_amount_low_, Call(StreamID(1)));
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1)));
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 150);
 
   // But not when it's set low again.
-  EXPECT_CALL(on_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0);
   buf_.SetBufferedAmountLowThreshold(StreamID(1), 0);
 }
 
 TEST_F(RRSendQueueTest,
        OnTotalBufferedAmountLowDoesNotTriggerOnBufferFillingUp) {
-  EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(0);
   std::vector<uint8_t> payload(kBufferedAmountLowThreshold - 1);
   buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload));
   EXPECT_EQ(buf_.total_buffered_amount(), payload.size());
@@ -713,7 +710,7 @@
 }
 
 TEST_F(RRSendQueueTest, TriggersOnTotalBufferedAmountLowWhenCrossing) {
-  EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(0);
+  EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(0);
   std::vector<uint8_t> payload(kBufferedAmountLowThreshold);
   buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload));
   EXPECT_EQ(buf_.total_buffered_amount(), payload.size());
@@ -722,7 +719,7 @@
   buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, std::vector<uint8_t>(1)));
 
   // Drain it a bit - will trigger.
-  EXPECT_CALL(on_total_buffered_amount_low_, Call).Times(1);
+  EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(1);
   absl::optional<SendQueue::DataToSend> chunk_two =
       buf_.Produce(kNow, kOneFragmentPacketSize);
 }
@@ -789,10 +786,8 @@
   DcSctpSocketHandoverState state;
   buf_.AddHandoverState(state);
 
-  RRSendQueue q2("log: ", kMaxQueueSize, kMtu, kDefaultPriority,
-                 on_buffered_amount_low_.AsStdFunction(),
-                 kBufferedAmountLowThreshold,
-                 on_total_buffered_amount_low_.AsStdFunction());
+  RRSendQueue q2("log: ", &callbacks_, kMaxQueueSize, kMtu, kDefaultPriority,
+                 kBufferedAmountLowThreshold);
   q2.RestoreFromState(state);
   EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42));
   EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42));