Add ability to flush packets from pacer queue on a key frame

Bug: webrtc:11340
Change-Id: I70a97ab3ea576e54f1b4cf02042af5e6d5d6c2de
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/300541
Auto-Submit: Erik Språng <sprang@webrtc.org>
Reviewed-by: Ying Wang <yinwa@webrtc.org>
Commit-Queue: Ying Wang <yinwa@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39805}
diff --git a/modules/pacing/pacing_controller.cc b/modules/pacing/pacing_controller.cc
index cd94c7b..a526fc13 100644
--- a/modules/pacing/pacing_controller.cc
+++ b/modules/pacing/pacing_controller.cc
@@ -67,6 +67,8 @@
           IsEnabled(field_trials_, "WebRTC-Pacer-IgnoreTransportOverhead")),
       fast_retransmissions_(
           IsEnabled(field_trials_, "WebRTC-Pacer-FastRetransmissions")),
+      keyframe_flushing_(
+          IsEnabled(field_trials_, "WebRTC-Pacer-KeyframeFlushing")),
       transport_overhead_per_packet_(DataSize::Zero()),
       send_burst_interval_(TimeDelta::Zero()),
       last_timestamp_(clock_->CurrentTime()),
@@ -188,6 +190,21 @@
       << "SetPacingRate must be called before InsertPacket.";
   RTC_CHECK(packet->packet_type());
 
+  if (keyframe_flushing_ &&
+      packet->packet_type() == RtpPacketMediaType::kVideo &&
+      packet->is_key_frame() && packet->is_first_packet_of_frame() &&
+      !packet_queue_.HasKeyframePackets(packet->Ssrc())) {
+    // First packet of a keyframe (and no keyframe packets currently in the
+    // queue). Flush any pending packets currently in the queue for that stream
+    // in order to get the new keyframe out as quickly as possible.
+    packet_queue_.RemovePacketsForSsrc(packet->Ssrc());
+    absl::optional<uint32_t> rtx_ssrc =
+        packet_sender_->GetRtxSsrcForMedia(packet->Ssrc());
+    if (rtx_ssrc) {
+      packet_queue_.RemovePacketsForSsrc(*rtx_ssrc);
+    }
+  }
+
   prober_.OnIncomingPacket(DataSize::Bytes(packet->payload_size()));
 
   const Timestamp now = CurrentTime();
diff --git a/modules/pacing/pacing_controller.h b/modules/pacing/pacing_controller.h
index 5b94837..b0d802b 100644
--- a/modules/pacing/pacing_controller.h
+++ b/modules/pacing/pacing_controller.h
@@ -205,6 +205,7 @@
   const bool pace_audio_;
   const bool ignore_transport_overhead_;
   const bool fast_retransmissions_;
+  const bool keyframe_flushing_;
 
   DataSize transport_overhead_per_packet_;
   TimeDelta send_burst_interval_;
diff --git a/modules/pacing/pacing_controller_unittest.cc b/modules/pacing/pacing_controller_unittest.cc
index 688a3cc..d2ee908 100644
--- a/modules/pacing/pacing_controller_unittest.cc
+++ b/modules/pacing/pacing_controller_unittest.cc
@@ -2255,5 +2255,44 @@
   EXPECT_LE(callback.padding_sent(), kMaxPadding.bytes<size_t>());
 }
 
+TEST_F(PacingControllerTest, FlushesPacketsOnKeyFrames) {
+  const uint32_t kSsrc = 12345;
+  const uint32_t kRtxSsrc = 12346;
+
+  const test::ExplicitKeyValueConfig trials(
+      "WebRTC-Pacer-KeyframeFlushing/Enabled/");
+  auto pacer = std::make_unique<PacingController>(&clock_, &callback_, trials);
+  EXPECT_CALL(callback_, GetRtxSsrcForMedia(kSsrc))
+      .WillRepeatedly(Return(kRtxSsrc));
+  pacer->SetPacingRates(kTargetRate, DataRate::Zero());
+
+  // Enqueue a video packet and a retransmission of that video stream.
+  pacer->EnqueuePacket(BuildPacket(RtpPacketMediaType::kVideo, kSsrc,
+                                   /*sequence_number=*/1, /*capture_time=*/1,
+                                   /*size_bytes=*/100));
+  pacer->EnqueuePacket(BuildPacket(RtpPacketMediaType::kRetransmission,
+                                   kRtxSsrc,
+                                   /*sequence_number=*/10, /*capture_time=*/1,
+                                   /*size_bytes=*/100));
+  EXPECT_EQ(pacer->QueueSizePackets(), 2u);
+
+  // Enqueue the first packet of a keyframe for said stream.
+  auto packet = BuildPacket(RtpPacketMediaType::kVideo, kSsrc,
+                            /*sequence_number=*/2, /*capture_time=*/2,
+                            /*size_bytes=*/1000);
+  packet->set_is_key_frame(true);
+  packet->set_first_packet_of_frame(true);
+  pacer->EnqueuePacket(std::move(packet));
+
+  // Only they new keyframe packet should be left in the queue.
+  EXPECT_EQ(pacer->QueueSizePackets(), 1u);
+
+  EXPECT_CALL(callback_, SendPacket(kSsrc, /*sequence_number=*/2,
+                                    /*timestamp=*/2, /*is_retrnamission=*/false,
+                                    /*is_padding=*/false));
+  AdvanceTimeUntil(pacer->NextSendTime());
+  pacer->ProcessPackets();
+}
+
 }  // namespace
 }  // namespace webrtc
diff --git a/modules/pacing/prioritized_packet_queue.cc b/modules/pacing/prioritized_packet_queue.cc
index 0c285c4..ea211ea 100644
--- a/modules/pacing/prioritized_packet_queue.cc
+++ b/modules/pacing/prioritized_packet_queue.cc
@@ -50,10 +50,13 @@
 }
 
 PrioritizedPacketQueue::StreamQueue::StreamQueue(Timestamp creation_time)
-    : last_enqueue_time_(creation_time) {}
+    : last_enqueue_time_(creation_time), num_keyframe_packets_(0) {}
 
 bool PrioritizedPacketQueue::StreamQueue::EnqueuePacket(QueuedPacket packet,
                                                         int priority_level) {
+  if (packet.packet->is_key_frame()) {
+    ++num_keyframe_packets_;
+  }
   bool first_packet_at_level = packets_[priority_level].empty();
   packets_[priority_level].push_back(std::move(packet));
   return first_packet_at_level;
@@ -64,6 +67,10 @@
   RTC_DCHECK(!packets_[priority_level].empty());
   QueuedPacket packet = std::move(packets_[priority_level].front());
   packets_[priority_level].pop_front();
+  if (packet.packet->is_key_frame()) {
+    RTC_DCHECK_GT(num_keyframe_packets_, 0);
+    --num_keyframe_packets_;
+  }
   return packet;
 }
 
@@ -98,6 +105,7 @@
   for (int i = 0; i < kNumPriorityLevels; ++i) {
     packets_by_prio[i].swap(packets_[i]);
   }
+  num_keyframe_packets_ = 0;
   return packets_by_prio;
 }
 
@@ -292,6 +300,14 @@
   }
 }
 
+bool PrioritizedPacketQueue::HasKeyframePackets(uint32_t ssrc) const {
+  auto it = streams_.find(ssrc);
+  if (it != streams_.end()) {
+    return it->second->has_keyframe_packets();
+  }
+  return false;
+}
+
 void PrioritizedPacketQueue::DequeuePacketInternal(QueuedPacket& packet) {
   --size_packets_;
   RTC_DCHECK(packet.packet->packet_type().has_value());
diff --git a/modules/pacing/prioritized_packet_queue.h b/modules/pacing/prioritized_packet_queue.h
index 364b53a..935c530 100644
--- a/modules/pacing/prioritized_packet_queue.h
+++ b/modules/pacing/prioritized_packet_queue.h
@@ -85,6 +85,10 @@
   // Remove any packets matching the given SSRC.
   void RemovePacketsForSsrc(uint32_t ssrc);
 
+  // Checks if the queue for the given SSRC has original (retransmissions not
+  // counted) video packets containing keyframe data.
+  bool HasKeyframePackets(uint32_t ssrc) const;
+
  private:
   static constexpr int kNumPriorityLevels = 4;
 
@@ -118,12 +122,14 @@
     bool IsEmpty() const;
     Timestamp LeadingPacketEnqueueTime(int priority_level) const;
     Timestamp LastEnqueueTime() const;
+    bool has_keyframe_packets() const { return num_keyframe_packets_ > 0; }
 
     std::array<std::deque<QueuedPacket>, kNumPriorityLevels> DequeueAll();
 
    private:
     std::deque<QueuedPacket> packets_[kNumPriorityLevels];
     Timestamp last_enqueue_time_;
+    int num_keyframe_packets_;
   };
 
   // Remove the packet from the internal state, e.g. queue time / size etc.
diff --git a/modules/pacing/prioritized_packet_queue_unittest.cc b/modules/pacing/prioritized_packet_queue_unittest.cc
index 964051c..9ed1964 100644
--- a/modules/pacing/prioritized_packet_queue_unittest.cc
+++ b/modules/pacing/prioritized_packet_queue_unittest.cc
@@ -27,12 +27,14 @@
 
 std::unique_ptr<RtpPacketToSend> CreatePacket(RtpPacketMediaType type,
                                               uint16_t sequence_number,
-                                              uint32_t ssrc = kDefaultSsrc) {
+                                              uint32_t ssrc = kDefaultSsrc,
+                                              bool is_key_frame = false) {
   auto packet = std::make_unique<RtpPacketToSend>(/*extensions=*/nullptr);
   packet->set_packet_type(type);
   packet->SetSsrc(ssrc);
   packet->SetSequenceNumber(sequence_number);
   packet->SetPayloadSize(kDefaultPayloadSize);
+  packet->set_is_key_frame(is_key_frame);
   return packet;
 }
 
@@ -360,4 +362,55 @@
   EXPECT_TRUE(queue.Empty());
 }
 
+TEST(PrioritizedPacketQueue, ReportsKeyframePackets) {
+  Timestamp now = Timestamp::Zero();
+  PrioritizedPacketQueue queue(now);
+  const uint32_t kVideoSsrc1 = 1234;
+  const uint32_t kVideoSsrc2 = 2345;
+
+  EXPECT_FALSE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_FALSE(queue.HasKeyframePackets(kVideoSsrc2));
+
+  queue.Push(now, CreatePacket(RtpPacketMediaType::kVideo, /*seq=*/1,
+                               kVideoSsrc1, /*is_key_frame=*/true));
+  queue.Push(now, CreatePacket(RtpPacketMediaType::kVideo, /*seq=*/11,
+                               kVideoSsrc2, /*is_key_frame=*/false));
+
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_FALSE(queue.HasKeyframePackets(kVideoSsrc2));
+
+  queue.Push(now, CreatePacket(RtpPacketMediaType::kVideo, /*seq=*/2,
+                               kVideoSsrc1, /*is_key_frame=*/true));
+  queue.Push(now, CreatePacket(RtpPacketMediaType::kVideo, /*seq=*/12,
+                               kVideoSsrc2, /*is_key_frame=*/true));
+
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc2));
+
+  queue.Push(now, CreatePacket(RtpPacketMediaType::kVideo, /*seq=*/3,
+                               kVideoSsrc1, /*is_key_frame=*/false));
+  queue.Push(now, CreatePacket(RtpPacketMediaType::kVideo, /*seq=*/13,
+                               kVideoSsrc2, /*is_key_frame=*/true));
+
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc2));
+
+  EXPECT_EQ(queue.Pop()->SequenceNumber(), 1);
+  EXPECT_EQ(queue.Pop()->SequenceNumber(), 11);
+
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc2));
+
+  EXPECT_EQ(queue.Pop()->SequenceNumber(), 2);
+  EXPECT_EQ(queue.Pop()->SequenceNumber(), 12);
+
+  EXPECT_FALSE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_TRUE(queue.HasKeyframePackets(kVideoSsrc2));
+
+  queue.RemovePacketsForSsrc(kVideoSsrc2);
+
+  EXPECT_FALSE(queue.HasKeyframePackets(kVideoSsrc1));
+  EXPECT_FALSE(queue.HasKeyframePackets(kVideoSsrc2));
+}
+
 }  // namespace webrtc