Replace use of SignalReadPacket in DtlsTransport
Instead use PacketTransportInternal::NotifyPacketReceived
Bug: webrtc:15368
Change-Id: I70a83865c9b564429366bd297abc7dbd50da02e4
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/340301
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Reviewed-by: Jonas Oreland <jonaso@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#41816}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index c3858cb..9d25078 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -421,10 +421,13 @@
"../rtc_base:checks",
"../rtc_base:dscp",
"../rtc_base:logging",
+ "../rtc_base:socket_address",
"../rtc_base:ssl",
"../rtc_base:stream",
"../rtc_base:stringutils",
"../rtc_base:threading",
+ "../rtc_base:timeutils",
+ "../rtc_base/network:received_packet",
"../rtc_base/system:no_unique_address",
]
absl_deps = [
diff --git a/p2p/base/dtls_transport.cc b/p2p/base/dtls_transport.cc
index a9ff9d3..6f30c6d 100644
--- a/p2p/base/dtls_transport.cc
+++ b/p2p/base/dtls_transport.cc
@@ -11,6 +11,7 @@
#include "p2p/base/dtls_transport.h"
#include <algorithm>
+#include <cstdint>
#include <memory>
#include <utility>
@@ -26,10 +27,13 @@
#include "rtc_base/checks.h"
#include "rtc_base/dscp.h"
#include "rtc_base/logging.h"
+#include "rtc_base/network/received_packet.h"
#include "rtc_base/rtc_certificate.h"
+#include "rtc_base/socket_address.h"
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/stream.h"
#include "rtc_base/thread.h"
+#include "rtc_base/time_utils.h"
namespace cricket {
@@ -50,20 +54,20 @@
static const int kMinHandshakeTimeout = 50;
static const int kMaxHandshakeTimeout = 3000;
-static bool IsDtlsPacket(const char* data, size_t len) {
- const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
- return (len >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64));
+static bool IsDtlsPacket(rtc::ArrayView<const uint8_t> payload) {
+ const uint8_t* u = payload.data();
+ return (payload.size() >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64));
}
-static bool IsDtlsClientHelloPacket(const char* data, size_t len) {
- if (!IsDtlsPacket(data, len)) {
+static bool IsDtlsClientHelloPacket(rtc::ArrayView<const uint8_t> payload) {
+ if (!IsDtlsPacket(payload)) {
return false;
}
- const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
- return len > 17 && u[0] == 22 && u[13] == 1;
+ const uint8_t* u = payload.data();
+ return payload.size() > 17 && u[0] == 22 && u[13] == 1;
}
-static bool IsRtpPacket(const char* data, size_t len) {
- const uint8_t* u = reinterpret_cast<const uint8_t*>(data);
- return (len >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
+static bool IsRtpPacket(rtc::ArrayView<const uint8_t> payload) {
+ const uint8_t* u = payload.data();
+ return (payload.size() >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80);
}
StreamInterfaceChannel::StreamInterfaceChannel(
@@ -146,7 +150,11 @@
ConnectToIceTransport();
}
-DtlsTransport::~DtlsTransport() = default;
+DtlsTransport::~DtlsTransport() {
+ if (ice_transport_) {
+ ice_transport_->DeregisterReceivedPacketCallback(this);
+ }
+}
webrtc::DtlsTransportState DtlsTransport::dtls_state() const {
return dtls_state_;
@@ -444,7 +452,8 @@
case webrtc::DtlsTransportState::kConnected:
if (flags & PF_SRTP_BYPASS) {
RTC_DCHECK(!srtp_ciphers_.empty());
- if (!IsRtpPacket(data, size)) {
+ if (!IsRtpPacket(rtc::MakeArrayView(
+ reinterpret_cast<const uint8_t*>(data), size))) {
return -1;
}
@@ -513,7 +522,12 @@
RTC_DCHECK(ice_transport_);
ice_transport_->SignalWritableState.connect(this,
&DtlsTransport::OnWritableState);
- ice_transport_->SignalReadPacket.connect(this, &DtlsTransport::OnReadPacket);
+ ice_transport_->RegisterReceivedPacketCallback(
+ this, [&](rtc::PacketTransportInternal* transport,
+ const rtc::ReceivedPacket& packet) {
+ OnReadPacket(transport, packet);
+ });
+
ice_transport_->SignalSentPacket.connect(this, &DtlsTransport::OnSentPacket);
ice_transport_->SignalReadyToSend.connect(this,
&DtlsTransport::OnReadyToSend);
@@ -590,17 +604,13 @@
}
void DtlsTransport::OnReadPacket(rtc::PacketTransportInternal* transport,
- const char* data,
- size_t size,
- const int64_t& packet_time_us,
- int flags) {
+ const rtc::ReceivedPacket& packet) {
RTC_DCHECK_RUN_ON(&thread_checker_);
RTC_DCHECK(transport == ice_transport_);
- RTC_DCHECK(flags == 0);
if (!dtls_active_) {
// Not doing DTLS.
- SignalReadPacket(this, data, size, packet_time_us, 0);
+ NotifyPacketReceived(packet);
return;
}
@@ -615,11 +625,11 @@
"doing DTLS or not.";
}
// Cache a client hello packet received before DTLS has actually started.
- if (IsDtlsClientHelloPacket(data, size)) {
+ if (IsDtlsClientHelloPacket(packet.payload())) {
RTC_LOG(LS_INFO) << ToString()
<< ": Caching DTLS ClientHello packet until DTLS is "
"started.";
- cached_client_hello_.SetData(data, size);
+ cached_client_hello_.SetData(packet.payload());
// If we haven't started setting up DTLS yet (because we don't have a
// remote fingerprint/role), we can use the client hello as a clue that
// the peer has chosen the client role, and proceed with the handshake.
@@ -638,8 +648,8 @@
case webrtc::DtlsTransportState::kConnected:
// We should only get DTLS or SRTP packets; STUN's already been demuxed.
// Is this potentially a DTLS packet?
- if (IsDtlsPacket(data, size)) {
- if (!HandleDtlsPacket(data, size)) {
+ if (IsDtlsPacket(packet.payload())) {
+ if (!HandleDtlsPacket(packet.payload())) {
RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet.";
return;
}
@@ -653,7 +663,7 @@
}
// And it had better be a SRTP packet.
- if (!IsRtpPacket(data, size)) {
+ if (!IsRtpPacket(packet.payload())) {
RTC_LOG(LS_ERROR)
<< ToString() << ": Received unexpected non-DTLS packet.";
return;
@@ -663,7 +673,8 @@
RTC_DCHECK(!srtp_ciphers_.empty());
// Signal this upwards as a bypass packet.
- SignalReadPacket(this, data, size, packet_time_us, PF_SRTP_BYPASS);
+ NotifyPacketReceived(
+ packet.CopyAndSet(rtc::ReceivedPacket::kSrtpEncrypted));
}
break;
case webrtc::DtlsTransportState::kFailed:
@@ -710,8 +721,13 @@
do {
ret = dtls_->Read(buf, read, read_error);
if (ret == rtc::SR_SUCCESS) {
- SignalReadPacket(this, reinterpret_cast<const char*>(buf), read,
- rtc::TimeMicros(), 0);
+ // TODO(bugs.webrtc.org/15368): It should be possible to use information
+ // from the original packet here to populate socket address and
+ // timestamp.
+ NotifyPacketReceived(rtc::ReceivedPacket(
+ rtc::MakeArrayView(buf, read), rtc::SocketAddress(),
+ webrtc::Timestamp::Micros(rtc::TimeMicros()),
+ rtc::ReceivedPacket::kDtlsDecrypted));
} else if (ret == rtc::SR_EOS) {
// Remote peer shut down the association with no error.
RTC_LOG(LS_INFO) << ToString() << ": DTLS transport closed by remote";
@@ -775,8 +791,7 @@
if (*dtls_role_ == rtc::SSL_SERVER) {
RTC_LOG(LS_INFO) << ToString()
<< ": Handling cached DTLS ClientHello packet.";
- if (!HandleDtlsPacket(cached_client_hello_.data<char>(),
- cached_client_hello_.size())) {
+ if (!HandleDtlsPacket(cached_client_hello_)) {
RTC_LOG(LS_ERROR) << ToString() << ": Failed to handle DTLS packet.";
}
} else {
@@ -790,11 +805,11 @@
}
// Called from OnReadPacket when a DTLS packet is received.
-bool DtlsTransport::HandleDtlsPacket(const char* data, size_t size) {
+bool DtlsTransport::HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload) {
// Sanity check we're not passing junk that
// just looks like DTLS.
- const uint8_t* tmp_data = reinterpret_cast<const uint8_t*>(data);
- size_t tmp_size = size;
+ const uint8_t* tmp_data = payload.data();
+ size_t tmp_size = payload.size();
while (tmp_size > 0) {
if (tmp_size < kDtlsRecordHeaderLen)
return false; // Too short for the header
@@ -809,7 +824,8 @@
// Looks good. Pass to the SIC which ends up being passed to
// the DTLS stack.
- return downward_->OnPacketReceived(data, size);
+ return downward_->OnPacketReceived(
+ reinterpret_cast<const char*>(payload.data()), payload.size());
}
void DtlsTransport::set_receiving(bool receiving) {
diff --git a/p2p/base/dtls_transport.h b/p2p/base/dtls_transport.h
index 9408025..f479325 100644
--- a/p2p/base/dtls_transport.h
+++ b/p2p/base/dtls_transport.h
@@ -23,6 +23,7 @@
#include "p2p/base/ice_transport_internal.h"
#include "rtc_base/buffer.h"
#include "rtc_base/buffer_queue.h"
+#include "rtc_base/network/received_packet.h"
#include "rtc_base/ssl_stream_adapter.h"
#include "rtc_base/stream.h"
#include "rtc_base/strings/string_builder.h"
@@ -216,10 +217,7 @@
void OnWritableState(rtc::PacketTransportInternal* transport);
void OnReadPacket(rtc::PacketTransportInternal* transport,
- const char* data,
- size_t size,
- const int64_t& packet_time_us,
- int flags);
+ const rtc::ReceivedPacket& packet);
void OnSentPacket(rtc::PacketTransportInternal* transport,
const rtc::SentPacket& sent_packet);
void OnReadyToSend(rtc::PacketTransportInternal* transport);
@@ -228,7 +226,7 @@
void OnNetworkRouteChanged(absl::optional<rtc::NetworkRoute> network_route);
bool SetupDtls();
void MaybeStartDtls();
- bool HandleDtlsPacket(const char* data, size_t size);
+ bool HandleDtlsPacket(rtc::ArrayView<const uint8_t> payload);
void OnDtlsHandshakeError(rtc::SSLHandshakeError error);
void ConfigureHandshakeTimeout();
diff --git a/p2p/base/dtls_transport_unittest.cc b/p2p/base/dtls_transport_unittest.cc
index e338ab6..ddf1874 100644
--- a/p2p/base/dtls_transport_unittest.cc
+++ b/p2p/base/dtls_transport_unittest.cc
@@ -11,6 +11,8 @@
#include "p2p/base/dtls_transport.h"
#include <algorithm>
+#include <cstddef>
+#include <cstdint>
#include <memory>
#include <set>
#include <utility>
@@ -23,6 +25,7 @@
#include "rtc_base/dscp.h"
#include "rtc_base/gunit.h"
#include "rtc_base/helpers.h"
+#include "rtc_base/network/received_packet.h"
#include "rtc_base/rtc_certificate.h"
#include "rtc_base/ssl_adapter.h"
#include "rtc_base/ssl_identity.h"
@@ -82,6 +85,9 @@
}
// Set up fake ICE transport and real DTLS transport under test.
void SetupTransports(IceRole role, int async_delay_ms = 0) {
+ dtls_transport_ = nullptr;
+ fake_ice_transport_ = nullptr;
+
fake_ice_transport_.reset(new FakeIceTransport("fake", 0));
fake_ice_transport_->SetAsync(true);
fake_ice_transport_->SetAsyncDelay(async_delay_ms);
@@ -89,8 +95,11 @@
fake_ice_transport_->SetIceTiebreaker((role == ICEROLE_CONTROLLING) ? 1
: 2);
// Hook the raw packets so that we can verify they are encrypted.
- fake_ice_transport_->SignalReadPacket.connect(
- this, &DtlsTestClient::OnFakeIceTransportReadPacket);
+ fake_ice_transport_->RegisterReceivedPacketCallback(
+ this, [&](rtc::PacketTransportInternal* transport,
+ const rtc::ReceivedPacket& packet) {
+ OnFakeIceTransportReadPacket(transport, packet);
+ });
dtls_transport_ = std::make_unique<DtlsTransport>(
fake_ice_transport_.get(), webrtc::CryptoOptions(),
@@ -200,14 +209,14 @@
size_t NumPacketsReceived() { return received_.size(); }
// Inverse of SendPackets.
- bool VerifyPacket(const char* data, size_t size, uint32_t* out_num) {
+ bool VerifyPacket(const uint8_t* data, size_t size, uint32_t* out_num) {
if (size != packet_size_ ||
(data[0] != 0 && static_cast<uint8_t>(data[0]) != 0x80)) {
return false;
}
uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset);
for (size_t i = kPacketHeaderLen; i < size; ++i) {
- if (static_cast<uint8_t>(data[i]) != (packet_num & 0xff)) {
+ if (data[i] != (packet_num & 0xff)) {
return false;
}
}
@@ -216,7 +225,7 @@
}
return true;
}
- bool VerifyEncryptedPacket(const char* data, size_t size) {
+ bool VerifyEncryptedPacket(const uint8_t* data, size_t size) {
// This is an encrypted data packet; let's make sure it's mostly random;
// less than 10% of the bytes should be equal to the cleartext packet.
if (size <= packet_size_) {
@@ -225,7 +234,7 @@
uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset);
int num_matches = 0;
for (size_t i = kPacketNumOffset; i < size; ++i) {
- if (static_cast<uint8_t>(data[i]) == (packet_num & 0xff)) {
+ if (data[i] == (packet_num & 0xff)) {
++num_matches;
}
}
@@ -244,7 +253,8 @@
const int64_t& /* packet_time_us */,
int flags) {
uint32_t packet_num = 0;
- ASSERT_TRUE(VerifyPacket(data, size, &packet_num));
+ ASSERT_TRUE(VerifyPacket(reinterpret_cast<const uint8_t*>(data), size,
+ &packet_num));
received_.insert(packet_num);
// Only DTLS-SRTP packets should have the bypass flag set.
int expected_flags =
@@ -261,15 +271,14 @@
// Hook into the raw packet stream to make sure DTLS packets are encrypted.
void OnFakeIceTransportReadPacket(rtc::PacketTransportInternal* transport,
- const char* data,
- size_t size,
- const int64_t& /* packet_time_us */,
- int flags) {
- // Flags shouldn't be set on the underlying Transport packets.
- ASSERT_EQ(0, flags);
+ const rtc::ReceivedPacket& packet) {
+ // Packets should not be decrypted on the underlying Transport packets.
+ ASSERT_EQ(packet.decryption_info(), rtc::ReceivedPacket::kNotDecrypted);
// Look at the handshake packets to see what role we played.
// Check that non-handshake packets are DTLS data or SRTP bypass.
+ const uint8_t* data = packet.payload().data();
+ size_t size = packet.payload().size();
if (data[0] == 22 && size > 17) {
if (data[13] == 1) {
++received_dtls_client_hellos_;