dcsctp: Avoid recalculation of outstanding bytes

Recalculating outstanding bytes is expensive when the congestion window
is large, as it iterates over all inflight data chunks. By doing it
incrementally, it will be a constant operation in most cases, and
in the remaining cases, a function of the number of chunks acked in a
single SACK, which is typically just a few chunks.

Implementing this fix required some refactoring to calculate it
correctly (and to be honest, it was likely done incorrectly previously).

Previously, the state of an item in the retransmission queue was
simplified as "in flight", "acked", "nacked", "abandoned", but these
were not completely orthogonal. A chunk could be abandoned while it was
in-flight or it could be abandoned because it was lost. The difference
between these if that chunk should be accounted for in
outstanding_bytes() or not.

Unit tests have been added to verify this.

Bug: webrtc:12799
Change-Id: I72341538bb0c4f8f89555b08f0c8a28815f0f828
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/219623
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#34139}
diff --git a/net/dcsctp/tx/retransmission_queue.cc b/net/dcsctp/tx/retransmission_queue.cc
index 704e6ab..17d7358 100644
--- a/net/dcsctp/tx/retransmission_queue.cc
+++ b/net/dcsctp/tx/retransmission_queue.cc
@@ -38,6 +38,7 @@
 #include "net/dcsctp/public/types.h"
 #include "net/dcsctp/timer/timer.h"
 #include "net/dcsctp/tx/send_queue.h"
+#include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
 #include "rtc_base/strings/string_builder.h"
 
@@ -83,6 +84,20 @@
       last_cumulative_tsn_ack_(tsn_unwrapper_.Unwrap(TSN(*initial_tsn - 1))),
       send_queue_(send_queue) {}
 
+bool RetransmissionQueue::IsConsistent() const {
+  size_t actual_outstanding_bytes = absl::c_accumulate(
+      outstanding_data_, 0,
+      [&](size_t r, const std::pair<const UnwrappedTSN, TxData>& d) {
+        // Packets that have been ACKED or NACKED are not outstanding, as they
+        // are received. And packets that are marked for retransmission or
+        // abandoned are lost, and not outstanding.
+        return r + (d.second.is_outstanding()
+                        ? GetSerializedChunkSize(d.second.data())
+                        : 0);
+      });
+  return actual_outstanding_bytes == outstanding_bytes_;
+}
+
 // Returns how large a chunk will be, serialized, carrying the data
 size_t RetransmissionQueue::GetSerializedChunkSize(const Data& data) const {
   return RoundUpTo4(data_chunk_header_size_ + data.size());
@@ -95,6 +110,9 @@
   for (auto it = outstanding_data_.begin(); it != first_unacked; ++it) {
     ack_info.bytes_acked_by_cumulative_tsn_ack += it->second.data().size();
     ack_info.acked_tsns.push_back(it->first.Wrap());
+    if (it->second.is_outstanding()) {
+      outstanding_bytes_ -= GetSerializedChunkSize(it->second.data());
+    }
   }
 
   outstanding_data_.erase(outstanding_data_.begin(), first_unacked);
@@ -115,10 +133,13 @@
     auto end = outstanding_data_.upper_bound(
         UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end));
     for (auto iter = start; iter != end; ++iter) {
-      if (iter->second.state() != State::kAcked) {
+      if (!iter->second.is_acked()) {
         ack_info.bytes_acked_by_new_gap_ack_blocks +=
             iter->second.data().size();
-        iter->second.SetState(State::kAcked);
+        if (iter->second.is_outstanding()) {
+          outstanding_bytes_ -= GetSerializedChunkSize(iter->second.data());
+        }
+        iter->second.Ack();
         ack_info.highest_tsn_acked =
             std::max(ack_info.highest_tsn_acked, iter->first);
         ack_info.acked_tsns.push_back(iter->first.Wrap());
@@ -159,9 +180,11 @@
     for (auto iter = outstanding_data_.upper_bound(prev_block_last_acked);
          iter != outstanding_data_.lower_bound(cur_block_first_acked); ++iter) {
       if (iter->first <= max_tsn_to_nack) {
-        iter->second.Nack();
+        if (iter->second.is_outstanding()) {
+          outstanding_bytes_ -= GetSerializedChunkSize(iter->second.data());
+        }
 
-        if (iter->second.state() == State::kToBeRetransmitted) {
+        if (iter->second.Nack()) {
           ack_info.has_packet_loss = true;
           RTC_DLOG(LS_VERBOSE) << log_prefix_ << *iter->first.Wrap()
                                << " marked for retransmission";
@@ -367,7 +390,6 @@
   // NACK and possibly mark for retransmit chunks that weren't acked.
   NackBetweenAckBlocks(cumulative_tsn_ack, sack.gap_ack_blocks(), ack_info);
 
-  RecalculateOutstandingBytes();
   // Update of outstanding_data_ is now done. Congestion control remains.
   UpdateReceiverWindow(sack.a_rwnd());
 
@@ -413,6 +435,7 @@
 
   last_cumulative_tsn_ack_ = cumulative_tsn_ack;
   StartT3RtxTimerIfOutstandingData();
+  RTC_DCHECK(IsConsistent());
   return true;
 }
 
@@ -440,19 +463,6 @@
   }
 }
 
-void RetransmissionQueue::RecalculateOutstandingBytes() {
-  outstanding_bytes_ = absl::c_accumulate(
-      outstanding_data_, 0,
-      [&](size_t r, const std::pair<const UnwrappedTSN, TxData>& d) {
-        // Packets that have been ACKED or NACKED are not outstanding, as they
-        // are received. And packets that are marked for retransmission or
-        // abandoned are lost, and not outstanding.
-        return r + (d.second.state() == State::kInFlight
-                        ? GetSerializedChunkSize(d.second.data())
-                        : 0);
-      });
-}
-
 void RetransmissionQueue::HandleT3RtxTimerExpiry() {
   size_t old_cwnd = cwnd_;
   size_t old_outstanding_bytes = outstanding_bytes_;
@@ -484,17 +494,18 @@
   for (auto& elem : outstanding_data_) {
     UnwrappedTSN tsn = elem.first;
     TxData& item = elem.second;
-    if (item.state() == State::kInFlight || item.state() == State::kNacked) {
-      RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Chunk " << *tsn.Wrap()
-                           << " will be retransmitted due to T3-RTX";
-      item.SetState(State::kToBeRetransmitted);
-      ++count;
+    if (!item.is_acked()) {
+      if (item.is_outstanding()) {
+        outstanding_bytes_ -= GetSerializedChunkSize(item.data());
+      }
+      if (item.Nack(/*retransmit_now=*/true)) {
+        RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Chunk " << *tsn.Wrap()
+                             << " will be retransmitted due to T3-RTX";
+        ++count;
+      }
     }
   }
 
-  // Marking some packets as retransmitted changes outstanding bytes.
-  RecalculateOutstandingBytes();
-
   // https://tools.ietf.org/html/rfc4960#section-6.3.3
   // "Start the retransmission timer T3-rtx on the destination address
   // to which the retransmission is sent, if rule R1 above indicates to do so."
@@ -506,6 +517,7 @@
                     << ", rtx-packets=" << count << ", outstanding_bytes "
                     << outstanding_bytes_ << " (" << old_outstanding_bytes
                     << ")";
+  RTC_DCHECK(IsConsistent());
 }
 
 std::vector<std::pair<TSN, Data>>
@@ -516,21 +528,21 @@
     TxData& item = elem.second;
 
     size_t serialized_size = GetSerializedChunkSize(item.data());
-    if (item.state() == State::kToBeRetransmitted &&
-        serialized_size <= max_size) {
+    if (item.should_be_retransmitted() && serialized_size <= max_size) {
+      RTC_DCHECK(!item.is_outstanding());
+      RTC_DCHECK(!item.is_abandoned());
+      RTC_DCHECK(!item.is_acked());
       item.Retransmit();
       result.emplace_back(tsn.Wrap(), item.data().Clone());
       max_size -= serialized_size;
+      outstanding_bytes_ += serialized_size;
     }
     // No point in continuing if the packet is full.
     if (max_size <= data_chunk_header_size_) {
       break;
     }
   }
-  // As some chunks may have switched state, that needs to be reflected here.
-  if (!result.empty()) {
-    RecalculateOutstandingBytes();
-  }
+
   return result;
 }
 
@@ -624,6 +636,7 @@
                          << " (" << old_outstanding_bytes << "), cwnd=" << cwnd_
                          << ", rwnd=" << rwnd_ << " (" << old_rwnd << ")";
   }
+  RTC_DCHECK(IsConsistent());
   return to_be_sent;
 }
 
@@ -632,7 +645,20 @@
   std::vector<std::pair<TSN, RetransmissionQueue::State>> states;
   states.emplace_back(last_cumulative_tsn_ack_.Wrap(), State::kAcked);
   for (const auto& elem : outstanding_data_) {
-    states.emplace_back(elem.first.Wrap(), elem.second.state());
+    State state;
+    if (elem.second.is_abandoned()) {
+      state = State::kAbandoned;
+    } else if (elem.second.should_be_retransmitted()) {
+      state = State::kToBeRetransmitted;
+    } else if (elem.second.is_acked()) {
+      state = State::kAcked;
+    } else if (elem.second.is_outstanding()) {
+      state = State::kInFlight;
+    } else {
+      state = State::kNacked;
+    }
+
+    states.emplace_back(elem.first.Wrap(), state);
   }
   return states;
 }
@@ -645,28 +671,43 @@
   if (!outstanding_data_.empty()) {
     auto it = outstanding_data_.begin();
     return it->first == last_cumulative_tsn_ack_.next_value() &&
-           it->second.state() == State::kAbandoned;
+           it->second.is_abandoned();
+  }
+  RTC_DCHECK(IsConsistent());
+  return false;
+}
+
+void RetransmissionQueue::TxData::Ack() {
+  ack_state_ = AckState::kAcked;
+  should_be_retransmitted_ = false;
+}
+
+bool RetransmissionQueue::TxData::Nack(bool retransmit_now) {
+  ack_state_ = AckState::kNacked;
+  ++nack_count_;
+  if ((retransmit_now || nack_count_ >= kNumberOfNacksForRetransmission) &&
+      !is_abandoned_) {
+    should_be_retransmitted_ = true;
+    return true;
   }
   return false;
 }
 
-void RetransmissionQueue::TxData::Nack() {
-  ++nack_count_;
-  if (nack_count_ >= kNumberOfNacksForRetransmission) {
-    state_ = State::kToBeRetransmitted;
-  } else {
-    state_ = State::kNacked;
-  }
-}
-
 void RetransmissionQueue::TxData::Retransmit() {
-  state_ = State::kInFlight;
+  ack_state_ = AckState::kUnacked;
+  should_be_retransmitted_ = false;
+
   nack_count_ = 0;
   ++num_retransmissions_;
 }
 
+void RetransmissionQueue::TxData::Abandon() {
+  is_abandoned_ = true;
+  should_be_retransmitted_ = false;
+}
+
 bool RetransmissionQueue::TxData::has_expired(TimeMs now) const {
-  if (state_ != State::kAcked && state_ != State::kAbandoned) {
+  if (ack_state_ != AckState::kAcked && !is_abandoned_) {
     if (max_retransmissions_.has_value() &&
         num_retransmissions_ >= *max_retransmissions_) {
       return true;
@@ -704,13 +745,13 @@
     UnwrappedTSN tsn = elem.first;
     TxData& other = elem.second;
 
-    if (other.state() != State::kAbandoned &&
+    if (!other.is_abandoned() &&
         other.data().stream_id == item.data().stream_id &&
         other.data().is_unordered == item.data().is_unordered &&
         other.data().message_id == item.data().message_id) {
       RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Marking chunk " << *tsn.Wrap()
                            << " as abandoned";
-      other.SetState(State::kAbandoned);
+      other.Abandon();
     }
   }
 }
@@ -724,8 +765,7 @@
     UnwrappedTSN tsn = elem.first;
     const TxData& item = elem.second;
 
-    if ((tsn != new_cumulative_ack.next_value()) ||
-        item.state() != State::kAbandoned) {
+    if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) {
       break;
     }
     new_cumulative_ack = tsn;
@@ -752,8 +792,7 @@
     UnwrappedTSN tsn = elem.first;
     const TxData& item = elem.second;
 
-    if ((tsn != new_cumulative_ack.next_value()) ||
-        item.state() != State::kAbandoned) {
+    if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) {
       break;
     }
     new_cumulative_ack = tsn;
diff --git a/net/dcsctp/tx/retransmission_queue.h b/net/dcsctp/tx/retransmission_queue.h
index c2599a4..e9586b2 100644
--- a/net/dcsctp/tx/retransmission_queue.h
+++ b/net/dcsctp/tx/retransmission_queue.h
@@ -43,7 +43,7 @@
 class RetransmissionQueue {
  public:
   static constexpr size_t kMinimumFragmentedPayload = 10;
-  // State for DATA chunks (message fragments) in the queue.
+  // State for DATA chunks (message fragments) in the queue - used in tests.
   enum class State {
     // The chunk has been sent but not received yet (from the sender's point of
     // view, as no SACK has been received yet that reference this chunk).
@@ -154,24 +154,50 @@
 
     TimeMs time_sent() const { return time_sent_; }
 
-    State state() const { return state_; }
-    void SetState(State state) { state_ = state; }
-
     const Data& data() const { return data_; }
 
-    // Nacks an item. If it has been nacked enough times, it will be marked for
-    // retransmission.
-    void Nack();
+    // Acks an item.
+    void Ack();
+
+    // Nacks an item. If it has been nacked enough times, or if `retransmit_now`
+    // is set, it might be marked for retransmission, which is indicated by the
+    // return value.
+    bool Nack(bool retransmit_now = false);
+
+    // Prepares the item to be retransmitted. Sets it as outstanding and
+    // clears all nack counters.
     void Retransmit();
 
-    bool has_been_retransmitted() { return num_retransmissions_ > 0; }
+    // Marks this item as abandoned.
+    void Abandon();
+
+    bool is_outstanding() const { return ack_state_ == AckState::kUnacked; }
+    bool is_acked() const { return ack_state_ == AckState::kAcked; }
+    bool is_abandoned() const { return is_abandoned_; }
+
+    // Indicates if this chunk should be retransmitted.
+    bool should_be_retransmitted() const { return should_be_retransmitted_; }
+    // Indicates if this chunk has ever been retransmitted.
+    bool has_been_retransmitted() const { return num_retransmissions_ > 0; }
 
     // Given the current time, and the current state of this DATA chunk, it will
     // indicate if it has expired (SCTP Partial Reliability Extension).
     bool has_expired(TimeMs now) const;
 
    private:
-    State state_ = State::kInFlight;
+    enum class AckState {
+      kUnacked,
+      kAcked,
+      kNacked,
+    };
+    // Indicates the presence of this chunk, if it's in flight (Unacked), has
+    // been received (Acked) or is lost (Nacked).
+    AckState ack_state_ = AckState::kUnacked;
+    // Indicates if this chunk has been abandoned, which is a terminal state.
+    bool is_abandoned_ = false;
+    // Indicates if this chunk should be retransmitted.
+    bool should_be_retransmitted_ = false;
+
     // The number of times the DATA chunk has been nacked (by having received a
     // SACK which doesn't include it). Will be cleared on retransmissions.
     size_t nack_count_ = 0;
@@ -214,6 +240,8 @@
     UnwrappedTSN highest_tsn_acked;
   };
 
+  bool IsConsistent() const;
+
   // Returns how large a chunk will be, serialized, carrying the data
   size_t GetSerializedChunkSize(const Data& data) const;
 
@@ -270,8 +298,6 @@
   // Update the congestion control algorithm, given as packet loss has been
   // detected, as reported in an incoming SACK chunk.
   void HandlePacketLoss(UnwrappedTSN highest_tsn_acked);
-  // Recalculate the number of in-flight payload bytes.
-  void RecalculateOutstandingBytes();
   // Update the view of the receiver window size.
   void UpdateReceiverWindow(uint32_t a_rwnd);
   // Given `max_size` of space left in a packet, which chunks can be added to
@@ -337,7 +363,7 @@
   // cumulative acked. Note that it also contains chunks that have been acked in
   // gap ack blocks.
   std::map<UnwrappedTSN, TxData> outstanding_data_;
-  // The sum of the message bytes of the send_queue_
+  // The number of bytes that are in-flight (sent but not yet acked or nacked).
   size_t outstanding_bytes_ = 0;
 };
 }  // namespace dcsctp
diff --git a/net/dcsctp/tx/retransmission_queue_test.cc b/net/dcsctp/tx/retransmission_queue_test.cc
index f36d91e..f7368d1 100644
--- a/net/dcsctp/tx/retransmission_queue_test.cc
+++ b/net/dcsctp/tx/retransmission_queue_test.cc
@@ -800,5 +800,117 @@
   EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _)));
 }
 
+TEST_F(RetransmissionQueueTest, AccountsInflightAbandonedChunksAsOutstanding) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B"));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, ""));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, ""));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+
+  // Send and ack first chunk (TSN 10)
+  std::vector<std::pair<TSN, Data>> chunks_to_send =
+      queue.GetChunksToSend(now_, 1000);
+  EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _),
+                                          Pair(TSN(12), _)));
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),      //
+                          Pair(TSN(10), State::kInFlight),  //
+                          Pair(TSN(11), State::kInFlight),  //
+                          Pair(TSN(12), State::kInFlight)));
+  EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 3u);
+
+  // Discard the message while it was outstanding.
+  EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42)))
+      .Times(1);
+  EXPECT_TRUE(queue.ShouldSendForwardTsn(now_));
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),       //
+                          Pair(TSN(10), State::kAbandoned),  //
+                          Pair(TSN(11), State::kAbandoned),  //
+                          Pair(TSN(12), State::kAbandoned)));
+  EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 3u);
+
+  // Now ACK those, one at a time.
+  queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}));
+  EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 2u);
+
+  queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {}));
+  EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 1u);
+
+  queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {}));
+  EXPECT_EQ(queue.outstanding_bytes(), 0u);
+}
+
+TEST_F(RetransmissionQueueTest, AccountsNackedAbandonedChunksAsNotOutstanding) {
+  RetransmissionQueue queue = CreateQueue();
+  EXPECT_CALL(producer_, Produce)
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B"));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, ""));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillOnce([this](TimeMs, size_t) {
+        SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, ""));
+        dts.max_retransmissions = 0;
+        return dts;
+      })
+      .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; });
+
+  // Send and ack first chunk (TSN 10)
+  std::vector<std::pair<TSN, Data>> chunks_to_send =
+      queue.GetChunksToSend(now_, 1000);
+  EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _),
+                                          Pair(TSN(12), _)));
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),      //
+                          Pair(TSN(10), State::kInFlight),  //
+                          Pair(TSN(11), State::kInFlight),  //
+                          Pair(TSN(12), State::kInFlight)));
+  EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 3u);
+
+  // Mark the message as lost.
+  queue.HandleT3RtxTimerExpiry();
+
+  EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42)))
+      .Times(1);
+  EXPECT_TRUE(queue.ShouldSendForwardTsn(now_));
+
+  EXPECT_THAT(queue.GetChunkStatesForTesting(),
+              ElementsAre(Pair(TSN(9), State::kAcked),       //
+                          Pair(TSN(10), State::kAbandoned),  //
+                          Pair(TSN(11), State::kAbandoned),  //
+                          Pair(TSN(12), State::kAbandoned)));
+  EXPECT_EQ(queue.outstanding_bytes(), 0u);
+
+  // Now ACK those, one at a time.
+  queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}));
+  EXPECT_EQ(queue.outstanding_bytes(), 0u);
+
+  queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {}));
+  EXPECT_EQ(queue.outstanding_bytes(), 0u);
+
+  queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {}));
+  EXPECT_EQ(queue.outstanding_bytes(), 0u);
+}
+
 }  // namespace
 }  // namespace dcsctp