Minor fixes and refactoring for RtpTransport until the Demux.

This change fixes some inefficiencies and quirks in the code that
originates in RtpTransport leading up to the demux.

This work is in preparation for more refactoring of the Demux stage
onwards.

Bug: webrtc:10297
Change-Id: I7b8f00134657d62c722939618a55a91a2b6040bd
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/128220
Commit-Queue: Amit Hilbuch <amithi@webrtc.org>
Reviewed-by: Steve Anton <steveanton@webrtc.org>
Reviewed-by: Seth Hampson <shampson@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27185}
diff --git a/media/base/rtp_utils.cc b/media/base/rtp_utils.cc
index 57e719b..0669933 100644
--- a/media/base/rtp_utils.cc
+++ b/media/base/rtp_utils.cc
@@ -277,26 +277,22 @@
           SetRtpSsrc(data, len, header.ssrc));
 }
 
-bool IsRtpPacket(const void* data, size_t len) {
-  if (len < kMinRtpPacketLen)
-    return false;
+static bool HasCorrectRtpVersion(rtc::ArrayView<const char> packet) {
+  return reinterpret_cast<const uint8_t*>(packet.data())[0] >> 6 == kRtpVersion;
+}
 
-  return (static_cast<const uint8_t*>(data)[0] >> 6) == kRtpVersion;
+bool IsRtpPacket(rtc::ArrayView<const char> packet) {
+  return packet.size() >= kMinRtpPacketLen && HasCorrectRtpVersion(packet);
 }
 
 // Check the RTP payload type. If 63 < payload type < 96, it's RTCP.
 // For additional details, see http://tools.ietf.org/html/rfc5761.
-bool IsRtcpPacket(const char* data, size_t len) {
-  if (len < kMinRtcpPacketLen) {
+bool IsRtcpPacket(rtc::ArrayView<const char> packet) {
+  if (packet.size() < kMinRtcpPacketLen || !HasCorrectRtpVersion(packet)) {
     return false;
   }
 
-  // RTCP must be a valid RTP packet.
-  if ((static_cast<uint8_t>(data[0]) >> 6) != kRtpVersion) {
-    return false;
-  }
-
-  char pt = data[1] & 0x7F;
+  char pt = packet[1] & 0x7F;
   return (63 < pt) && (pt < 96);
 }
 
@@ -304,13 +300,35 @@
   return payload_type >= 0 && payload_type <= 127;
 }
 
-bool IsValidRtpRtcpPacketSize(bool rtcp, size_t size) {
-  return (rtcp ? size >= kMinRtcpPacketLen : size >= kMinRtpPacketLen) &&
-         size <= kMaxRtpPacketLen;
+bool IsValidRtpPacketSize(RtpPacketType packet_type, size_t size) {
+  // TODO(webrtc:10418): uncomment when relands.
+  // RTC_DCHECK_NE(RtpPacketType::kUnknown, packet_type);
+  size_t min_packet_length = packet_type == RtpPacketType::kRtcp
+                                 ? kMinRtcpPacketLen
+                                 : kMinRtpPacketLen;
+  return size >= min_packet_length && size <= kMaxRtpPacketLen;
 }
 
-const char* RtpRtcpStringLiteral(bool rtcp) {
-  return rtcp ? "RTCP" : "RTP";
+absl::string_view RtpPacketTypeToString(RtpPacketType packet_type) {
+  switch (packet_type) {
+    case RtpPacketType::kRtp:
+      return "RTP";
+    case RtpPacketType::kRtcp:
+      return "RTCP";
+    case RtpPacketType::kUnknown:
+      return "Unknown";
+  }
+}
+
+RtpPacketType InferRtpPacketType(rtc::ArrayView<const char> packet) {
+  // RTCP packets are RTP packets so must check that first.
+  if (IsRtcpPacket(packet)) {
+    return RtpPacketType::kRtcp;
+  }
+  if (IsRtpPacket(packet)) {
+    return RtpPacketType::kRtp;
+  }
+  return RtpPacketType::kUnknown;
 }
 
 bool ValidateRtpHeader(const uint8_t* rtp,
@@ -475,7 +493,9 @@
   }
 
   // Making sure we have a valid RTP packet at the end.
-  if (!IsRtpPacket(data + rtp_start_pos, rtp_length) ||
+  auto packet = rtc::MakeArrayView(
+      reinterpret_cast<const char*>(data + rtp_start_pos), rtp_length);
+  if (!IsRtpPacket(packet) ||
       !ValidateRtpHeader(data + rtp_start_pos, rtp_length, nullptr)) {
     RTC_NOTREACHED();
     return false;
diff --git a/media/base/rtp_utils.h b/media/base/rtp_utils.h
index 93f3103..9ef9f9c 100644
--- a/media/base/rtp_utils.h
+++ b/media/base/rtp_utils.h
@@ -11,6 +11,8 @@
 #ifndef MEDIA_BASE_RTP_UTILS_H_
 #define MEDIA_BASE_RTP_UTILS_H_
 
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
 #include "rtc_base/byte_order.h"
 #include "rtc_base/system/rtc_export.h"
 
@@ -41,6 +43,12 @@
   kRtcpTypePSFB = 206,   // Payload-specific Feedback message payload type.
 };
 
+enum class RtpPacketType {
+  kRtp,
+  kRtcp,
+  kUnknown,
+};
+
 bool GetRtpPayloadType(const void* data, size_t len, int* value);
 bool GetRtpSeqNum(const void* data, size_t len, int* value);
 bool GetRtpTimestamp(const void* data, size_t len, uint32_t* value);
@@ -54,19 +62,19 @@
 // Assumes version 2, no padding, no extensions, no csrcs.
 bool SetRtpHeader(void* data, size_t len, const RtpHeader& header);
 
-bool IsRtpPacket(const void* data, size_t len);
+bool IsRtpPacket(rtc::ArrayView<const char> packet);
 
-bool IsRtcpPacket(const char* data, size_t len);
+bool IsRtcpPacket(rtc::ArrayView<const char> packet);
+// Checks the packet header to determine if it can be an RTP or RTCP packet.
+RtpPacketType InferRtpPacketType(rtc::ArrayView<const char> packet);
 // True if |payload type| is 0-127.
 bool IsValidRtpPayloadType(int payload_type);
 
 // True if |size| is appropriate for the indicated packet type.
-bool IsValidRtpRtcpPacketSize(bool rtcp, size_t size);
+bool IsValidRtpPacketSize(RtpPacketType packet_type, size_t size);
 
-// TODO(zstein): Consider using an enum instead of a bool to differentiate
-// between RTP and RTCP.
-// Returns "RTCP" or "RTP" according to |rtcp|.
-const char* RtpRtcpStringLiteral(bool rtcp);
+// Returns "RTCP", "RTP" or "Unknown" according to |packet_type|.
+absl::string_view RtpPacketTypeToString(RtpPacketType packet_type);
 
 // Verifies that a packet has a valid RTP header.
 bool RTC_EXPORT ValidateRtpHeader(const uint8_t* rtp,
diff --git a/media/base/rtp_utils_unittest.cc b/media/base/rtp_utils_unittest.cc
index 8ac68a4..d88b160 100644
--- a/media/base/rtp_utils_unittest.cc
+++ b/media/base/rtp_utils_unittest.cc
@@ -79,8 +79,18 @@
 // Index of AbsSendTimeExtn data in message |kRtpMsgWithAbsSendTimeExtension|.
 static const int kAstIndexInRtpMsg = 21;
 
+static const rtc::ArrayView<const char> kPcmuFrameArrayView =
+    rtc::MakeArrayView(reinterpret_cast<const char*>(kPcmuFrame),
+                       sizeof(kPcmuFrame));
+static const rtc::ArrayView<const char> kRtcpReportArrayView =
+    rtc::MakeArrayView(reinterpret_cast<const char*>(kRtcpReport),
+                       sizeof(kRtcpReport));
+static const rtc::ArrayView<const char> kInvalidPacketArrayView =
+    rtc::MakeArrayView(reinterpret_cast<const char*>(kInvalidPacket),
+                       sizeof(kInvalidPacket));
+
 TEST(RtpUtilsTest, GetRtp) {
-  EXPECT_TRUE(IsRtpPacket(kPcmuFrame, sizeof(kPcmuFrame)));
+  EXPECT_TRUE(IsRtpPacket(kPcmuFrameArrayView));
 
   int pt;
   EXPECT_TRUE(GetRtpPayloadType(kPcmuFrame, sizeof(kPcmuFrame), &pt));
@@ -344,4 +354,11 @@
                       sizeof(kExpectedTimestamp)));
 }
 
+TEST(RtpUtilsTest, InferRtpPacketType) {
+  EXPECT_EQ(RtpPacketType::kRtp, InferRtpPacketType(kPcmuFrameArrayView));
+  EXPECT_EQ(RtpPacketType::kRtcp, InferRtpPacketType(kRtcpReportArrayView));
+  EXPECT_EQ(RtpPacketType::kUnknown,
+            InferRtpPacketType(kInvalidPacketArrayView));
+}
+
 }  // namespace cricket
diff --git a/pc/channel.cc b/pc/channel.cc
index 991e9e3..647663e 100644
--- a/pc/channel.cc
+++ b/pc/channel.cc
@@ -93,11 +93,6 @@
   }
 }
 
-static bool ValidPacket(bool rtcp, const rtc::CopyOnWriteBuffer* packet) {
-  // Check the packet size. We could check the header too if needed.
-  return packet && IsValidRtpRtcpPacketSize(rtcp, packet->size());
-}
-
 template <class Codec>
 void RtpParametersFromMediaDescription(
     const MediaContentDescriptionImpl<Codec>* desc,
@@ -402,6 +397,8 @@
 bool BaseChannel::SendPacket(bool rtcp,
                              rtc::CopyOnWriteBuffer* packet,
                              const rtc::PacketOptions& options) {
+  // Until all the code is migrated to use RtpPacketType instead of bool.
+  RtpPacketType packet_type = rtcp ? RtpPacketType::kRtcp : RtpPacketType::kRtp;
   // SendPacket gets called from MediaEngine, on a pacer or an encoder thread.
   // If the thread is not our network thread, we will post to our network
   // so that the real work happens on our network. This avoids us having to
@@ -430,9 +427,9 @@
   }
 
   // Protect ourselves against crazy data.
-  if (!ValidPacket(rtcp, packet)) {
+  if (!IsValidRtpPacketSize(packet_type, packet->size())) {
     RTC_LOG(LS_ERROR) << "Dropping outgoing " << content_name_ << " "
-                      << RtpRtcpStringLiteral(rtcp)
+                      << RtpPacketTypeToString(packet_type)
                       << " packet: wrong size=" << packet->size();
     return false;
   }
@@ -524,7 +521,9 @@
     //    for us to just eat packets here. This is all sidestepped if RTCP mux
     //    is used anyway.
     RTC_LOG(LS_WARNING)
-        << "Can't process incoming " << RtpRtcpStringLiteral(rtcp)
+        << "Can't process incoming "
+        << RtpPacketTypeToString(rtcp ? RtpPacketType::kRtcp
+                                      : RtpPacketType::kRtp)
         << " packet when SRTP is inactive and crypto is required";
     return;
   }
diff --git a/pc/rtp_transport.cc b/pc/rtp_transport.cc
index 20559e0..bd11e57 100644
--- a/pc/rtp_transport.cc
+++ b/pc/rtp_transport.cc
@@ -184,10 +184,10 @@
   return parameters_;
 }
 
-void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer* packet,
+void RtpTransport::DemuxPacket(rtc::CopyOnWriteBuffer packet,
                                int64_t packet_time_us) {
   webrtc::RtpPacketReceived parsed_packet(&header_extension_map_);
-  if (!parsed_packet.Parse(std::move(*packet))) {
+  if (!parsed_packet.Parse(std::move(packet))) {
     RTC_LOG(LS_ERROR)
         << "Failed to parse the incoming RTP packet before demuxing. Drop it.";
     return;
@@ -233,14 +233,14 @@
   SignalSentPacket(sent_packet);
 }
 
-void RtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+void RtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet,
                                        int64_t packet_time_us) {
   DemuxPacket(packet, packet_time_us);
 }
 
-void RtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+void RtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet,
                                         int64_t packet_time_us) {
-  SignalRtcpPacketReceived(packet, packet_time_us);
+  SignalRtcpPacketReceived(&packet, packet_time_us);
 }
 
 void RtpTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
@@ -252,27 +252,26 @@
 
   // When using RTCP multiplexing we might get RTCP packets on the RTP
   // transport. We check the RTP payload type to determine if it is RTCP.
-  bool rtcp =
-      transport == rtcp_packet_transport() || cricket::IsRtcpPacket(data, len);
-
+  auto array_view = rtc::MakeArrayView(data, len);
+  cricket::RtpPacketType packet_type = cricket::InferRtpPacketType(array_view);
   // Filter out the packet that is neither RTP nor RTCP.
-  if (!rtcp && !cricket::IsRtpPacket(data, len)) {
+  if (packet_type == cricket::RtpPacketType::kUnknown) {
+    return;
+  }
+
+  // Protect ourselves against crazy data.
+  if (!cricket::IsValidRtpPacketSize(packet_type, len)) {
+    RTC_LOG(LS_ERROR) << "Dropping incoming "
+                      << cricket::RtpPacketTypeToString(packet_type)
+                      << " packet: wrong size=" << len;
     return;
   }
 
   rtc::CopyOnWriteBuffer packet(data, len);
-  // Protect ourselves against crazy data.
-  if (!cricket::IsValidRtpRtcpPacketSize(rtcp, packet.size())) {
-    RTC_LOG(LS_ERROR) << "Dropping incoming "
-                      << cricket::RtpRtcpStringLiteral(rtcp)
-                      << " packet: wrong size=" << packet.size();
-    return;
-  }
-
-  if (rtcp) {
-    OnRtcpPacketReceived(&packet, packet_time_us);
+  if (packet_type == cricket::RtpPacketType::kRtcp) {
+    OnRtcpPacketReceived(std::move(packet), packet_time_us);
   } else {
-    OnRtpPacketReceived(&packet, packet_time_us);
+    OnRtpPacketReceived(std::move(packet), packet_time_us);
   }
 }
 
diff --git a/pc/rtp_transport.h b/pc/rtp_transport.h
index f188a17..dfdabbc 100644
--- a/pc/rtp_transport.h
+++ b/pc/rtp_transport.h
@@ -87,7 +87,7 @@
   RtpTransportAdapter* GetInternal() override;
 
   // These methods will be used in the subclasses.
-  void DemuxPacket(rtc::CopyOnWriteBuffer* packet, int64_t packet_time_us);
+  void DemuxPacket(rtc::CopyOnWriteBuffer packet, int64_t packet_time_us);
 
   bool SendPacket(bool rtcp,
                   rtc::CopyOnWriteBuffer* packet,
@@ -97,9 +97,9 @@
   // Overridden by SrtpTransport.
   virtual void OnNetworkRouteChanged(
       absl::optional<rtc::NetworkRoute> network_route);
-  virtual void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+  virtual void OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet,
                                    int64_t packet_time_us);
-  virtual void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+  virtual void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet,
                                     int64_t packet_time_us);
   // Overridden by SrtpTransport and DtlsSrtpTransport.
   virtual void OnWritableState(rtc::PacketTransportInternal* packet_transport);
diff --git a/pc/srtp_transport.cc b/pc/srtp_transport.cc
index c7e4f0e..20e32f5 100644
--- a/pc/srtp_transport.cc
+++ b/pc/srtp_transport.cc
@@ -13,6 +13,7 @@
 #include <stdint.h>
 #include <string.h>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "media/base/rtp_utils.h"
@@ -197,7 +198,7 @@
   return SendPacket(/*rtcp=*/true, packet, options, flags);
 }
 
-void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+void SrtpTransport::OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet,
                                         int64_t packet_time_us) {
   if (!IsSrtpActive()) {
     RTC_LOG(LS_WARNING)
@@ -205,8 +206,8 @@
     return;
   }
   TRACE_EVENT0("webrtc", "SRTP Decode");
-  char* data = packet->data<char>();
-  int len = rtc::checked_cast<int>(packet->size());
+  char* data = packet.data<char>();
+  int len = rtc::checked_cast<int>(packet.size());
   if (!UnprotectRtp(data, len, &len)) {
     int seq_num = -1;
     uint32_t ssrc = 0;
@@ -225,11 +226,11 @@
     ++decryption_failure_count_;
     return;
   }
-  packet->SetSize(len);
-  DemuxPacket(packet, packet_time_us);
+  packet.SetSize(len);
+  DemuxPacket(std::move(packet), packet_time_us);
 }
 
-void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+void SrtpTransport::OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet,
                                          int64_t packet_time_us) {
   if (!IsSrtpActive()) {
     RTC_LOG(LS_WARNING)
@@ -237,8 +238,8 @@
     return;
   }
   TRACE_EVENT0("webrtc", "SRTP Decode");
-  char* data = packet->data<char>();
-  int len = rtc::checked_cast<int>(packet->size());
+  char* data = packet.data<char>();
+  int len = rtc::checked_cast<int>(packet.size());
   if (!UnprotectRtcp(data, len, &len)) {
     int type = -1;
     cricket::GetRtcpType(data, len, &type);
@@ -246,8 +247,8 @@
                       << ", type=" << type;
     return;
   }
-  packet->SetSize(len);
-  SignalRtcpPacketReceived(packet, packet_time_us);
+  packet.SetSize(len);
+  SignalRtcpPacketReceived(&packet, packet_time_us);
 }
 
 void SrtpTransport::OnNetworkRouteChanged(
diff --git a/pc/srtp_transport.h b/pc/srtp_transport.h
index 7512711..e725733 100644
--- a/pc/srtp_transport.h
+++ b/pc/srtp_transport.h
@@ -116,9 +116,9 @@
   void ConnectToRtpTransport();
   void CreateSrtpSessions();
 
-  void OnRtpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+  void OnRtpPacketReceived(rtc::CopyOnWriteBuffer packet,
                            int64_t packet_time_us) override;
-  void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer* packet,
+  void OnRtcpPacketReceived(rtc::CopyOnWriteBuffer packet,
                             int64_t packet_time_us) override;
   void OnNetworkRouteChanged(
       absl::optional<rtc::NetworkRoute> network_route) override;