dcsctp: Ensure callbacks are always triggered

The previous manual way of triggering the deferred callbacks was very
error-prone, and this was also forgotten at a few places.

We can do better.

Using the RAII programming idiom, the callbacks are now ensured to be
called before returning from public methods.

Also added additional debug checks to ensure that there is a
ScopedDeferrer active whenever callbacks are deferred.

Bug: webrtc:13217
Change-Id: I16a8343b52c00fb30acb018d3846acd0a64318e0
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/233242
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35117}
diff --git a/net/dcsctp/socket/callback_deferrer.cc b/net/dcsctp/socket/callback_deferrer.cc
index 1b7fbac..b4af10e 100644
--- a/net/dcsctp/socket/callback_deferrer.cc
+++ b/net/dcsctp/socket/callback_deferrer.cc
@@ -36,12 +36,19 @@
 };
 }  // namespace
 
+void CallbackDeferrer::Prepare() {
+  RTC_DCHECK(!prepared_);
+  prepared_ = true;
+}
+
 void CallbackDeferrer::TriggerDeferred() {
   // Need to swap here. The client may call into the library from within a
   // callback, and that might result in adding new callbacks to this instance,
   // and the vector can't be modified while iterated on.
+  RTC_DCHECK(prepared_);
   std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred;
   deferred.swap(deferred_);
+  prepared_ = false;
 
   for (auto& cb : deferred) {
     cb(underlying_);
@@ -70,12 +77,14 @@
 }
 
 void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [deliverer = MessageDeliverer(std::move(message))](
           DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); });
 }
 
 void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
         cb.OnError(error, message);
@@ -83,6 +92,7 @@
 }
 
 void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
         cb.OnAborted(error, message);
@@ -90,14 +100,17 @@
 }
 
 void CallbackDeferrer::OnConnected() {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); });
 }
 
 void CallbackDeferrer::OnClosed() {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); });
 }
 
 void CallbackDeferrer::OnConnectionRestarted() {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); });
 }
@@ -105,6 +118,7 @@
 void CallbackDeferrer::OnStreamsResetFailed(
     rtc::ArrayView<const StreamID> outgoing_streams,
     absl::string_view reason) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [streams = std::vector<StreamID>(outgoing_streams.begin(),
                                        outgoing_streams.end()),
@@ -115,6 +129,7 @@
 
 void CallbackDeferrer::OnStreamsResetPerformed(
     rtc::ArrayView<const StreamID> outgoing_streams) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [streams = std::vector<StreamID>(outgoing_streams.begin(),
                                        outgoing_streams.end())](
@@ -123,6 +138,7 @@
 
 void CallbackDeferrer::OnIncomingStreamsReset(
     rtc::ArrayView<const StreamID> incoming_streams) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [streams = std::vector<StreamID>(incoming_streams.begin(),
                                        incoming_streams.end())](
@@ -130,12 +146,14 @@
 }
 
 void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) {
     cb.OnBufferedAmountLow(stream_id);
   });
 }
 
 void CallbackDeferrer::OnTotalBufferedAmountLow() {
+  RTC_DCHECK(prepared_);
   deferred_.emplace_back(
       [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); });
 }
diff --git a/net/dcsctp/socket/callback_deferrer.h b/net/dcsctp/socket/callback_deferrer.h
index ab2739f..918b1df 100644
--- a/net/dcsctp/socket/callback_deferrer.h
+++ b/net/dcsctp/socket/callback_deferrer.h
@@ -26,7 +26,6 @@
 #include "rtc_base/ref_counted_object.h"
 
 namespace dcsctp {
-
 // Defers callbacks until they can be safely triggered.
 //
 // There are a lot of callbacks from the dcSCTP library to the client,
@@ -44,11 +43,22 @@
 // There are a number of exceptions, which is clearly annotated in the API.
 class CallbackDeferrer : public DcSctpSocketCallbacks {
  public:
+  class ScopedDeferrer {
+   public:
+    explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer)
+        : callback_deferrer_(callback_deferrer) {
+      callback_deferrer_.Prepare();
+    }
+
+    ~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); }
+
+   private:
+    CallbackDeferrer& callback_deferrer_;
+  };
+
   explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying)
       : underlying_(underlying) {}
 
-  void TriggerDeferred();
-
   // Implementation of DcSctpSocketCallbacks
   SendPacketStatus SendPacketWithStatus(
       rtc::ArrayView<const uint8_t> data) override;
@@ -71,7 +81,11 @@
   void OnTotalBufferedAmountLow() override;
 
  private:
+  void Prepare();
+  void TriggerDeferred();
+
   DcSctpSocketCallbacks& underlying_;
+  bool prepared_ = false;
   std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_;
 };
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index a1cc12d..1001813 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -281,6 +281,8 @@
 }
 
 void DcSctpSocket::Connect() {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (state_ == State::kClosed) {
     MakeConnectionParameters();
     RTC_DLOG(LS_INFO)
@@ -296,10 +298,11 @@
                          << "Called Connect on a socket that is not closed";
   }
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
 }
 
 void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (state_ != State::kClosed) {
     callbacks_.OnError(ErrorKind::kUnsupportedOperation,
                        "Only closed socket can be restored from state");
@@ -334,10 +337,11 @@
   }
 
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
 }
 
 void DcSctpSocket::Shutdown() {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (tcb_ != nullptr) {
     // https://tools.ietf.org/html/rfc4960#section-9.2
     // "Upon receipt of the SHUTDOWN primitive from its upper layer, the
@@ -361,10 +365,11 @@
     InternalClose(ErrorKind::kNoError, "");
   }
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
 }
 
 void DcSctpSocket::Close() {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (state_ != State::kClosed) {
     if (tcb_ != nullptr) {
       SctpPacket::Builder b = tcb_->PacketBuilder();
@@ -379,7 +384,6 @@
     RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket";
   }
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
 }
 
 void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() {
@@ -411,6 +415,8 @@
 
 SendStatus DcSctpSocket::Send(DcSctpMessage message,
                               const SendOptions& send_options) {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (message.payload().empty()) {
     callbacks_.OnError(ErrorKind::kProtocolViolation,
                        "Unable to send empty message");
@@ -445,12 +451,13 @@
   }
 
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
   return SendStatus::kSuccess;
 }
 
 ResetStreamsStatus DcSctpSocket::ResetStreams(
     rtc::ArrayView<const StreamID> outgoing_streams) {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (tcb_ == nullptr) {
     callbacks_.OnError(ErrorKind::kWrongSequence,
                        "Can't reset streams as the socket is not connected");
@@ -472,7 +479,6 @@
   }
 
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
   return ResetStreamsStatus::kPerformed;
 }
 
@@ -654,6 +660,8 @@
 }
 
 void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   timer_manager_.HandleTimeout(timeout_id);
 
   if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) {
@@ -662,10 +670,11 @@
   }
 
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
 }
 
 void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   ++metrics_.rx_packets_count;
 
   if (packet_observer_ != nullptr) {
@@ -681,7 +690,6 @@
     callbacks_.OnError(ErrorKind::kParseFailed,
                        "Failed to parse received SCTP packet");
     RTC_DCHECK(IsConsistent());
-    callbacks_.TriggerDeferred();
     return;
   }
 
@@ -696,7 +704,6 @@
     RTC_DLOG(LS_VERBOSE) << log_prefix()
                          << "Packet failed verification tag check - dropping";
     RTC_DCHECK(IsConsistent());
-    callbacks_.TriggerDeferred();
     return;
   }
 
@@ -714,7 +721,6 @@
   }
 
   RTC_DCHECK(IsConsistent());
-  callbacks_.TriggerDeferred();
 }
 
 void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload) {
@@ -1646,6 +1652,8 @@
 
 absl::optional<DcSctpSocketHandoverState>
 DcSctpSocket::GetHandoverStateAndClose() {
+  CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
   if (!GetHandoverReadiness().IsReady()) {
     return absl::nullopt;
   }
@@ -1659,7 +1667,6 @@
     tcb_->AddHandoverState(state);
     send_queue_.AddHandoverState(state);
     InternalClose(ErrorKind::kNoError, "handover");
-    callbacks_.TriggerDeferred();
   }
 
   return std::move(state);