Skip calling rtcp callback on packets containing invalid rtcp message

Low-Coverage-Reason: added code handles invalid rtcp packet scenarios which are covered by rtcp_receiver_fuzzer
Bug: webrtc:5260
Change-Id: Ia6ba0b736f5732fa6516edb2a5ed1b96079eb4ef
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/296580
Reviewed-by: Åsa Persson <asapersson@webrtc.org>
Commit-Queue: Danil Chapovalov <danilchap@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#39503}
diff --git a/modules/rtp_rtcp/source/rtcp_receiver.cc b/modules/rtp_rtcp/source/rtcp_receiver.cc
index 9b4cf18..69fd1f6 100644
--- a/modules/rtp_rtcp/source/rtcp_receiver.cc
+++ b/modules/rtp_rtcp/source/rtcp_receiver.cc
@@ -416,59 +416,56 @@
   // For each remote SSRC we store if we've received a sender report or a DLRR
   // block.
   flat_map<uint32_t, RtcpReceivedBlock> received_blocks;
-  for (const uint8_t* next_block = packet.begin(); next_block != packet.end();
+  bool valid = true;
+  for (const uint8_t* next_block = packet.begin();
+       valid && next_block != packet.end();
        next_block = rtcp_block.NextPacket()) {
     ptrdiff_t remaining_blocks_size = packet.end() - next_block;
     RTC_DCHECK_GT(remaining_blocks_size, 0);
     if (!rtcp_block.Parse(next_block, remaining_blocks_size)) {
-      if (next_block == packet.begin()) {
-        // Failed to parse 1st header, nothing was extracted from this packet.
-        RTC_LOG(LS_WARNING) << "Incoming invalid RTCP packet";
-        return false;
-      }
-      ++num_skipped_packets_;
+      valid = false;
       break;
     }
 
     switch (rtcp_block.type()) {
       case rtcp::SenderReport::kPacketType:
-        HandleSenderReport(rtcp_block, packet_information);
+        valid = HandleSenderReport(rtcp_block, packet_information);
         received_blocks[packet_information->remote_ssrc].sender_report = true;
         break;
       case rtcp::ReceiverReport::kPacketType:
-        HandleReceiverReport(rtcp_block, packet_information);
+        valid = HandleReceiverReport(rtcp_block, packet_information);
         break;
       case rtcp::Sdes::kPacketType:
-        HandleSdes(rtcp_block, packet_information);
+        valid = HandleSdes(rtcp_block, packet_information);
         break;
       case rtcp::ExtendedReports::kPacketType: {
         bool contains_dlrr = false;
         uint32_t ssrc = 0;
-        HandleXr(rtcp_block, packet_information, contains_dlrr, ssrc);
+        valid = HandleXr(rtcp_block, packet_information, contains_dlrr, ssrc);
         if (contains_dlrr) {
           received_blocks[ssrc].dlrr = true;
         }
         break;
       }
       case rtcp::Bye::kPacketType:
-        HandleBye(rtcp_block);
+        valid = HandleBye(rtcp_block);
         break;
       case rtcp::App::kPacketType:
-        HandleApp(rtcp_block, packet_information);
+        valid = HandleApp(rtcp_block, packet_information);
         break;
       case rtcp::Rtpfb::kPacketType:
         switch (rtcp_block.fmt()) {
           case rtcp::Nack::kFeedbackMessageType:
-            HandleNack(rtcp_block, packet_information);
+            valid = HandleNack(rtcp_block, packet_information);
             break;
           case rtcp::Tmmbr::kFeedbackMessageType:
-            HandleTmmbr(rtcp_block, packet_information);
+            valid = HandleTmmbr(rtcp_block, packet_information);
             break;
           case rtcp::Tmmbn::kFeedbackMessageType:
-            HandleTmmbn(rtcp_block, packet_information);
+            valid = HandleTmmbn(rtcp_block, packet_information);
             break;
           case rtcp::RapidResyncRequest::kFeedbackMessageType:
-            HandleSrReq(rtcp_block, packet_information);
+            valid = HandleSrReq(rtcp_block, packet_information);
             break;
           case rtcp::TransportFeedback::kFeedbackMessageType:
             HandleTransportFeedback(rtcp_block, packet_information);
@@ -481,10 +478,10 @@
       case rtcp::Psfb::kPacketType:
         switch (rtcp_block.fmt()) {
           case rtcp::Pli::kFeedbackMessageType:
-            HandlePli(rtcp_block, packet_information);
+            valid = HandlePli(rtcp_block, packet_information);
             break;
           case rtcp::Fir::kFeedbackMessageType:
-            HandleFir(rtcp_block, packet_information);
+            valid = HandleFir(rtcp_block, packet_information);
             break;
           case rtcp::Psfb::kAfbMessageType:
             HandlePsfbApp(rtcp_block, packet_information);
@@ -500,6 +497,23 @@
     }
   }
 
+  if (num_skipped_packets_ > 0) {
+    const int64_t now_ms = clock_->TimeInMilliseconds();
+    if (now_ms - last_skipped_packets_warning_ms_ >= kMaxWarningLogIntervalMs) {
+      last_skipped_packets_warning_ms_ = now_ms;
+      RTC_LOG(LS_WARNING)
+          << num_skipped_packets_
+          << " RTCP blocks were skipped due to being malformed or of "
+             "unrecognized/unsupported type, during the past "
+          << (kMaxWarningLogIntervalMs / 1000) << " second period.";
+    }
+  }
+
+  if (!valid) {
+    ++num_skipped_packets_;
+    return false;
+  }
+
   for (const auto& rb : received_blocks) {
     if (rb.second.sender_report && !rb.second.dlrr) {
       auto rtt_stats = non_sender_rtts_.find(rb.first);
@@ -514,27 +528,14 @@
         local_media_ssrc(), packet_type_counter_);
   }
 
-  if (num_skipped_packets_ > 0) {
-    const int64_t now_ms = clock_->TimeInMilliseconds();
-    if (now_ms - last_skipped_packets_warning_ms_ >= kMaxWarningLogIntervalMs) {
-      last_skipped_packets_warning_ms_ = now_ms;
-      RTC_LOG(LS_WARNING)
-          << num_skipped_packets_
-          << " RTCP blocks were skipped due to being malformed or of "
-             "unrecognized/unsupported type, during the past "
-          << (kMaxWarningLogIntervalMs / 1000) << " second period.";
-    }
-  }
-
   return true;
 }
 
-void RTCPReceiver::HandleSenderReport(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleSenderReport(const CommonHeader& rtcp_block,
                                       PacketInformation* packet_information) {
   rtcp::SenderReport sender_report;
   if (!sender_report.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   const uint32_t remote_ssrc = sender_report.sender_ssrc();
@@ -560,16 +561,18 @@
     packet_information->packet_type_flags |= kRtcpRr;
   }
 
-  for (const rtcp::ReportBlock& report_block : sender_report.report_blocks())
+  for (const rtcp::ReportBlock& report_block : sender_report.report_blocks()) {
     HandleReportBlock(report_block, packet_information, remote_ssrc);
+  }
+
+  return true;
 }
 
-void RTCPReceiver::HandleReceiverReport(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleReceiverReport(const CommonHeader& rtcp_block,
                                         PacketInformation* packet_information) {
   rtcp::ReceiverReport receiver_report;
   if (!receiver_report.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   const uint32_t remote_ssrc = receiver_report.sender_ssrc();
@@ -580,8 +583,11 @@
 
   packet_information->packet_type_flags |= kRtcpRr;
 
-  for (const ReportBlock& report_block : receiver_report.report_blocks())
+  for (const ReportBlock& report_block : receiver_report.report_blocks()) {
     HandleReportBlock(report_block, packet_information, remote_ssrc);
+  }
+
+  return true;
 }
 
 void RTCPReceiver::HandleReportBlock(const ReportBlock& report_block,
@@ -740,12 +746,11 @@
   return tmmbr_info->tmmbn;
 }
 
-void RTCPReceiver::HandleSdes(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleSdes(const CommonHeader& rtcp_block,
                               PacketInformation* packet_information) {
   rtcp::Sdes sdes;
   if (!sdes.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   for (const rtcp::Sdes::Chunk& chunk : sdes.chunks()) {
@@ -753,18 +758,19 @@
       cname_callback_->OnCname(chunk.ssrc, chunk.cname);
   }
   packet_information->packet_type_flags |= kRtcpSdes;
+
+  return true;
 }
 
-void RTCPReceiver::HandleNack(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleNack(const CommonHeader& rtcp_block,
                               PacketInformation* packet_information) {
   rtcp::Nack nack;
   if (!nack.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   if (receiver_only_ || local_media_ssrc() != nack.media_ssrc())  // Not to us.
-    return;
+    return true;
 
   packet_information->nack_sequence_numbers.insert(
       packet_information->nack_sequence_numbers.end(),
@@ -778,29 +784,34 @@
     packet_type_counter_.nack_requests = nack_stats_.requests();
     packet_type_counter_.unique_nack_requests = nack_stats_.unique_requests();
   }
+
+  return true;
 }
 
-void RTCPReceiver::HandleApp(const rtcp::CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleApp(const rtcp::CommonHeader& rtcp_block,
                              PacketInformation* packet_information) {
   rtcp::App app;
-  if (app.Parse(rtcp_block)) {
-    if (app.name() == rtcp::RemoteEstimate::kName &&
-        app.sub_type() == rtcp::RemoteEstimate::kSubType) {
-      rtcp::RemoteEstimate estimate(std::move(app));
-      if (estimate.ParseData()) {
-        packet_information->network_state_estimate = estimate.estimate();
-        return;
-      }
-    }
+  if (!app.Parse(rtcp_block)) {
+    return false;
   }
-  ++num_skipped_packets_;
+  if (app.name() == rtcp::RemoteEstimate::kName &&
+      app.sub_type() == rtcp::RemoteEstimate::kSubType) {
+    rtcp::RemoteEstimate estimate(std::move(app));
+    if (estimate.ParseData()) {
+      packet_information->network_state_estimate = estimate.estimate();
+    }
+    // RemoteEstimate is not a standard RTCP message. Failing to parse it
+    // doesn't indicates RTCP packet is invalid. It may indicate sender happens
+    // to use the same id for a different message. Thus don't return false.
+  }
+
+  return true;
 }
 
-void RTCPReceiver::HandleBye(const CommonHeader& rtcp_block) {
+bool RTCPReceiver::HandleBye(const CommonHeader& rtcp_block) {
   rtcp::Bye bye;
   if (!bye.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   // Clear our lists.
@@ -820,16 +831,16 @@
     received_rrtrs_ssrc_it_.erase(it);
   }
   xr_rr_rtt_ms_ = 0;
+  return true;
 }
 
-void RTCPReceiver::HandleXr(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleXr(const CommonHeader& rtcp_block,
                             PacketInformation* packet_information,
                             bool& contains_dlrr,
                             uint32_t& ssrc) {
   rtcp::ExtendedReports xr;
   if (!xr.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
   ssrc = xr.sender_ssrc();
   contains_dlrr = !xr.dlrr().sub_blocks().empty();
@@ -844,6 +855,7 @@
     HandleXrTargetBitrate(xr.sender_ssrc(), *xr.target_bitrate(),
                           packet_information);
   }
+  return true;
 }
 
 void RTCPReceiver::HandleXrReceiveReferenceTime(uint32_t sender_ssrc,
@@ -922,12 +934,11 @@
   packet_information->target_bitrate_allocation.emplace(bitrate_allocation);
 }
 
-void RTCPReceiver::HandlePli(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandlePli(const CommonHeader& rtcp_block,
                              PacketInformation* packet_information) {
   rtcp::Pli pli;
   if (!pli.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   if (local_media_ssrc() == pli.media_ssrc()) {
@@ -935,14 +946,14 @@
     // Received a signal that we need to send a new key frame.
     packet_information->packet_type_flags |= kRtcpPli;
   }
+  return true;
 }
 
-void RTCPReceiver::HandleTmmbr(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleTmmbr(const CommonHeader& rtcp_block,
                                PacketInformation* packet_information) {
   rtcp::Tmmbr tmmbr;
   if (!tmmbr.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   uint32_t sender_ssrc = tmmbr.sender_ssrc();
@@ -967,14 +978,14 @@
     packet_information->packet_type_flags |= kRtcpTmmbr;
     break;
   }
+  return true;
 }
 
-void RTCPReceiver::HandleTmmbn(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleTmmbn(const CommonHeader& rtcp_block,
                                PacketInformation* packet_information) {
   rtcp::Tmmbn tmmbn;
   if (!tmmbn.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   TmmbrInformation* tmmbr_info = FindOrCreateTmmbrInfo(tmmbn.sender_ssrc());
@@ -982,17 +993,18 @@
   packet_information->packet_type_flags |= kRtcpTmmbn;
 
   tmmbr_info->tmmbn = tmmbn.items();
+  return true;
 }
 
-void RTCPReceiver::HandleSrReq(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleSrReq(const CommonHeader& rtcp_block,
                                PacketInformation* packet_information) {
   rtcp::RapidResyncRequest sr_req;
   if (!sr_req.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   packet_information->packet_type_flags |= kRtcpSrReq;
+  return true;
 }
 
 void RTCPReceiver::HandlePsfbApp(const CommonHeader& rtcp_block,
@@ -1017,20 +1029,20 @@
   }
 
   RTC_LOG(LS_WARNING) << "Unknown PSFB-APP packet.";
-
   ++num_skipped_packets_;
+  // Application layer feedback message doesn't have a standard format.
+  // Failing to parse one of known messages doesn't indicate an invalid RTCP.
 }
 
-void RTCPReceiver::HandleFir(const CommonHeader& rtcp_block,
+bool RTCPReceiver::HandleFir(const CommonHeader& rtcp_block,
                              PacketInformation* packet_information) {
   rtcp::Fir fir;
   if (!fir.Parse(rtcp_block)) {
-    ++num_skipped_packets_;
-    return;
+    return false;
   }
 
   if (fir.requests().empty())
-    return;
+    return true;
 
   const int64_t now_ms = clock_->TimeInMilliseconds();
   for (const rtcp::Fir::Request& fir_request : fir.requests()) {
@@ -1059,6 +1071,7 @@
     // Received signal that we need to send a new key frame.
     packet_information->packet_type_flags |= kRtcpFir;
   }
+  return true;
 }
 
 void RTCPReceiver::HandleTransportFeedback(
@@ -1068,6 +1081,9 @@
       new rtcp::TransportFeedback());
   if (!transport_feedback->Parse(rtcp_block)) {
     ++num_skipped_packets_;
+    // Application layer feedback message doesn't have a standard format.
+    // Failing to parse it as transport feedback messages doesn't indicate an
+    // invalid RTCP.
     return;
   }
 
diff --git a/modules/rtp_rtcp/source/rtcp_receiver.h b/modules/rtp_rtcp/source/rtcp_receiver.h
index 5007b42..567b5b3 100644
--- a/modules/rtp_rtcp/source/rtcp_receiver.h
+++ b/modules/rtp_rtcp/source/rtcp_receiver.h
@@ -277,11 +277,11 @@
   TmmbrInformation* GetTmmbrInformation(uint32_t remote_ssrc)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleSenderReport(const rtcp::CommonHeader& rtcp_block,
+  bool HandleSenderReport(const rtcp::CommonHeader& rtcp_block,
                           PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleReceiverReport(const rtcp::CommonHeader& rtcp_block,
+  bool HandleReceiverReport(const rtcp::CommonHeader& rtcp_block,
                             PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
@@ -290,11 +290,11 @@
                          uint32_t remote_ssrc)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleSdes(const rtcp::CommonHeader& rtcp_block,
+  bool HandleSdes(const rtcp::CommonHeader& rtcp_block,
                   PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleXr(const rtcp::CommonHeader& rtcp_block,
+  bool HandleXr(const rtcp::CommonHeader& rtcp_block,
                 PacketInformation* packet_information,
                 bool& contains_dlrr,
                 uint32_t& ssrc)
@@ -312,18 +312,18 @@
                              PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleNack(const rtcp::CommonHeader& rtcp_block,
+  bool HandleNack(const rtcp::CommonHeader& rtcp_block,
                   PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleApp(const rtcp::CommonHeader& rtcp_block,
+  bool HandleApp(const rtcp::CommonHeader& rtcp_block,
                  PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleBye(const rtcp::CommonHeader& rtcp_block)
+  bool HandleBye(const rtcp::CommonHeader& rtcp_block)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandlePli(const rtcp::CommonHeader& rtcp_block,
+  bool HandlePli(const rtcp::CommonHeader& rtcp_block,
                  PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
@@ -331,19 +331,19 @@
                      PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleTmmbr(const rtcp::CommonHeader& rtcp_block,
+  bool HandleTmmbr(const rtcp::CommonHeader& rtcp_block,
                    PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleTmmbn(const rtcp::CommonHeader& rtcp_block,
+  bool HandleTmmbn(const rtcp::CommonHeader& rtcp_block,
                    PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleSrReq(const rtcp::CommonHeader& rtcp_block,
+  bool HandleSrReq(const rtcp::CommonHeader& rtcp_block,
                    PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
-  void HandleFir(const rtcp::CommonHeader& rtcp_block,
+  bool HandleFir(const rtcp::CommonHeader& rtcp_block,
                  PacketInformation* packet_information)
       RTC_EXCLUSIVE_LOCKS_REQUIRED(rtcp_receiver_lock_);
 
diff --git a/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc b/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc
index b366731..b64363a 100644
--- a/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc
+++ b/modules/rtp_rtcp/source/rtcp_receiver_unittest.cc
@@ -191,9 +191,8 @@
   // Too short feedback packet.
   const uint8_t bad_packet[] = {0x81, rtcp::Rtpfb::kPacketType, 0, 0};
 
-  // TODO(danilchap): Add expectation RtcpPacketTypesCounterUpdated
-  // is not called once parser would be adjusted to avoid that callback on
-  // semi-valid packets.
+  EXPECT_CALL(mocks.packet_type_counter_observer, RtcpPacketTypesCounterUpdated)
+      .Times(0);
   receiver.IncomingPacket(bad_packet);
 }