[MessageHandler] Remove rtc::MessageHandler inheritance from StunRequest
This removes MessageHandler and Thread dependencies from StunRequest
and StunRequestManager. Instead the TaskQueueBase abstraction is
used for async posting and synchronous Clear() operations removed by
using a pending task safety flag.
Bug: webrtc:9702
Change-Id: I6e9ed5e1b4c446fd1f91af06e3ab36bccb5d7320
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/265060
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Reviewed-by: Niels Moller <nisse@webrtc.org>
Commit-Queue: Tomas Gunnarsson <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37218}
diff --git a/p2p/BUILD.gn b/p2p/BUILD.gn
index 1579f06..4874350 100644
--- a/p2p/BUILD.gn
+++ b/p2p/BUILD.gn
@@ -101,6 +101,7 @@
"../api/transport:enums",
"../api/transport:field_trial_based_config",
"../api/transport:stun_types",
+ "../api/units:time_delta",
"../logging:ice_log",
"../rtc_base",
"../rtc_base:async_resolver_interface",
diff --git a/p2p/base/stun_request.cc b/p2p/base/stun_request.cc
index 2a1ad65..c4d586c 100644
--- a/p2p/base/stun_request.cc
+++ b/p2p/base/stun_request.cc
@@ -20,12 +20,11 @@
#include "rtc_base/helpers.h"
#include "rtc_base/logging.h"
#include "rtc_base/string_encode.h"
+#include "rtc_base/task_utils/to_queued_task.h"
#include "rtc_base/time_utils.h" // For TimeMillis
namespace cricket {
-const uint32_t MSG_STUN_SEND = 1;
-
// RFC 5389 says SHOULD be 500ms.
// For years, this was 100ms, but for networks that
// experience moments of high RTT (such as 2G networks), this doesn't
@@ -44,7 +43,7 @@
const int STUN_MAX_RTO = 8000; // milliseconds, or 5 doublings
StunRequestManager::StunRequestManager(
- rtc::Thread* thread,
+ webrtc::TaskQueueBase* thread,
std::function<void(const void*, size_t, StunRequest*)> send_packet)
: thread_(thread), send_packet_(std::move(send_packet)) {}
@@ -60,20 +59,21 @@
auto [iter, was_inserted] =
requests_.emplace(request->id(), absl::WrapUnique(request));
RTC_DCHECK(was_inserted);
- if (delay > 0) {
- thread_->PostDelayed(RTC_FROM_HERE, delay, iter->second.get(),
- MSG_STUN_SEND, NULL);
- } else {
- thread_->Send(RTC_FROM_HERE, iter->second.get(), MSG_STUN_SEND, NULL);
- }
+ request->Send(webrtc::TimeDelta::Millis(delay));
}
void StunRequestManager::FlushForTest(int msg_type) {
RTC_DCHECK_RUN_ON(thread_);
for (const auto& [unused, request] : requests_) {
if (msg_type == kAllRequests || msg_type == request->type()) {
- thread_->Clear(request.get(), MSG_STUN_SEND);
- thread_->Send(RTC_FROM_HERE, request.get(), MSG_STUN_SEND, NULL);
+ // Calling `Send` implies starting the send operation which may be posted
+ // on a timer and be repeated on a timer until timeout. To make sure that
+ // a call to `Send` doesn't conflict with a previously started `Send`
+ // operation, we reset the `task_safety_` flag here, which has the effect
+ // of canceling any outstanding tasks and prepare a new flag for
+ // operations related to this call to `Send`.
+ request->ResetTasksForTest();
+ request->Send(webrtc::TimeDelta::Millis(0));
}
}
}
@@ -96,11 +96,8 @@
bool StunRequestManager::CheckResponse(StunMessage* msg) {
RTC_DCHECK_RUN_ON(thread_);
RequestMap::iterator iter = requests_.find(msg->transaction_id());
- if (iter == requests_.end()) {
- // TODO(pthatcher): Log unknown responses without being too spammy
- // in the logs.
+ if (iter == requests_.end())
return false;
- }
StunRequest* request = iter->second.get();
@@ -159,11 +156,8 @@
id.append(data + kStunTransactionIdOffset, kStunTransactionIdLength);
RequestMap::iterator iter = requests_.find(id);
- if (iter == requests_.end()) {
- // TODO(pthatcher): Log unknown responses without being too spammy
- // in the logs.
+ if (iter == requests_.end())
return false;
- }
// Parse the STUN message and continue processing as usual.
@@ -195,7 +189,9 @@
msg_(new StunMessage(STUN_INVALID_MESSAGE_TYPE)),
tstamp_(0),
count_(0),
- timeout_(false) {}
+ timeout_(false) {
+ RTC_DCHECK_RUN_ON(network_thread());
+}
StunRequest::StunRequest(StunRequestManager& manager,
std::unique_ptr<StunMessage> message)
@@ -204,12 +200,11 @@
tstamp_(0),
count_(0),
timeout_(false) {
+ RTC_DCHECK_RUN_ON(network_thread());
RTC_DCHECK(!msg_->transaction_id().empty());
}
-StunRequest::~StunRequest() {
- manager_.network_thread()->Clear(this);
-}
+StunRequest::~StunRequest() {}
int StunRequest::type() {
RTC_DCHECK(msg_ != NULL);
@@ -225,10 +220,8 @@
return static_cast<int>(rtc::TimeMillis() - tstamp_);
}
-void StunRequest::OnMessage(rtc::Message* pmsg) {
+void StunRequest::SendInternal() {
RTC_DCHECK_RUN_ON(network_thread());
- RTC_DCHECK(pmsg->message_id == MSG_STUN_SEND);
-
if (timeout_) {
OnTimeout();
manager_.OnRequestTimedOut(this);
@@ -242,8 +235,30 @@
manager_.SendPacket(buf.Data(), buf.Length(), this);
OnSent();
- manager_.network_thread()->PostDelayed(RTC_FROM_HERE, resend_delay(), this,
- MSG_STUN_SEND, NULL);
+ SendDelayed(webrtc::TimeDelta::Millis(resend_delay()));
+}
+
+void StunRequest::SendDelayed(webrtc::TimeDelta delay) {
+ network_thread()->PostDelayedTask(
+ webrtc::ToQueuedTask(task_safety_, [this]() { SendInternal(); }),
+ delay.ms());
+}
+
+void StunRequest::Send(webrtc::TimeDelta delay) {
+ RTC_DCHECK_RUN_ON(network_thread());
+ RTC_DCHECK_GE(delay.ms(), 0);
+
+ RTC_DCHECK(!task_safety_.flag()->alive()) << "Send already called?";
+ task_safety_.flag()->SetAlive();
+
+ delay.IsZero() ? SendInternal() : SendDelayed(delay);
+}
+
+void StunRequest::ResetTasksForTest() {
+ RTC_DCHECK_RUN_ON(network_thread());
+ task_safety_.reset(webrtc::PendingTaskSafetyFlag::CreateDetachedInactive());
+ count_ = 0;
+ RTC_DCHECK(!timeout_);
}
void StunRequest::OnSent() {
diff --git a/p2p/base/stun_request.h b/p2p/base/stun_request.h
index c6d1076..56d2597 100644
--- a/p2p/base/stun_request.h
+++ b/p2p/base/stun_request.h
@@ -19,9 +19,10 @@
#include <memory>
#include <string>
+#include "api/task_queue/task_queue_base.h"
#include "api/transport/stun.h"
-#include "rtc_base/message_handler.h"
-#include "rtc_base/thread.h"
+#include "api/units/time_delta.h"
+#include "rtc_base/task_utils/pending_task_safety_flag.h"
namespace cricket {
@@ -39,7 +40,7 @@
class StunRequestManager {
public:
StunRequestManager(
- rtc::Thread* thread,
+ webrtc::TaskQueueBase* thread,
std::function<void(const void*, size_t, StunRequest*)> send_packet);
~StunRequestManager();
@@ -50,10 +51,14 @@
// If `msg_type` is kAllRequests, sends all pending requests right away.
// Otherwise, sends those that have a matching type right away.
// Only for testing.
+ // TODO(tommi): Remove this method and update tests that use it to simulate
+ // production code.
void FlushForTest(int msg_type);
// Returns true if at least one request with `msg_type` is scheduled for
// transmission. For testing only.
+ // TODO(tommi): Remove this method and update tests that use it to simulate
+ // production code.
bool HasRequestForTest(int msg_type);
// Removes all stun requests that were added previously.
@@ -69,27 +74,26 @@
bool empty() const;
- // TODO(tommi): Use TaskQueueBase* instead of rtc::Thread.
- rtc::Thread* network_thread() const { return thread_; }
+ webrtc::TaskQueueBase* network_thread() const { return thread_; }
void SendPacket(const void* data, size_t size, StunRequest* request);
private:
typedef std::map<std::string, std::unique_ptr<StunRequest>> RequestMap;
- rtc::Thread* const thread_;
+ webrtc::TaskQueueBase* const thread_;
RequestMap requests_ RTC_GUARDED_BY(thread_);
const std::function<void(const void*, size_t, StunRequest*)> send_packet_;
};
// Represents an individual request to be sent. The STUN message can either be
// constructed beforehand or built on demand.
-class StunRequest : public rtc::MessageHandler {
+class StunRequest {
public:
explicit StunRequest(StunRequestManager& manager);
StunRequest(StunRequestManager& manager,
std::unique_ptr<StunMessage> message);
- ~StunRequest() override;
+ virtual ~StunRequest();
// The manager handling this request (if it has been scheduled for sending).
StunRequestManager* manager() { return &manager_; }
@@ -114,6 +118,13 @@
protected:
friend class StunRequestManager;
+ // Called by StunRequestManager.
+ void Send(webrtc::TimeDelta delay);
+
+ // Called from FlushForTest.
+ // TODO(tommi): Remove when FlushForTest gets removed.
+ void ResetTasksForTest();
+
StunMessage* mutable_msg() { return msg_.get(); }
// Called when the message receives a response or times out.
@@ -122,7 +133,7 @@
virtual void OnTimeout() {}
// Called when the message is sent.
virtual void OnSent();
- // Returns the next delay for resends.
+ // Returns the next delay for resends in milliseconds.
virtual int resend_delay();
webrtc::TaskQueueBase* network_thread() const {
@@ -132,14 +143,18 @@
void set_timed_out();
private:
- // Handles messages for sending and timeout.
- void OnMessage(rtc::Message* pmsg) override;
+ void SendInternal();
+ // Calls `PostDelayedTask` to queue up a call to SendInternal after the
+ // specified timeout.
+ void SendDelayed(webrtc::TimeDelta delay);
StunRequestManager& manager_;
const std::unique_ptr<StunMessage> msg_;
int64_t tstamp_ RTC_GUARDED_BY(network_thread());
int count_ RTC_GUARDED_BY(network_thread());
bool timeout_ RTC_GUARDED_BY(network_thread());
+ webrtc::ScopedTaskSafety task_safety_{
+ webrtc::PendingTaskSafetyFlag::CreateDetachedInactive()};
};
} // namespace cricket
diff --git a/rtc_base/task_utils/pending_task_safety_flag.h b/rtc_base/task_utils/pending_task_safety_flag.h
index 58772bc..604e058 100644
--- a/rtc_base/task_utils/pending_task_safety_flag.h
+++ b/rtc_base/task_utils/pending_task_safety_flag.h
@@ -119,11 +119,20 @@
class ScopedTaskSafety final {
public:
ScopedTaskSafety() = default;
+ explicit ScopedTaskSafety(rtc::scoped_refptr<PendingTaskSafetyFlag> flag)
+ : flag_(std::move(flag)) {}
~ScopedTaskSafety() { flag_->SetNotAlive(); }
// Returns a new reference to the safety flag.
rtc::scoped_refptr<PendingTaskSafetyFlag> flag() const { return flag_; }
+ // Marks the current flag as not-alive and attaches to a new one.
+ void reset(rtc::scoped_refptr<PendingTaskSafetyFlag> new_flag =
+ PendingTaskSafetyFlag::Create()) {
+ flag_->SetNotAlive();
+ flag_ = std::move(new_flag);
+ }
+
private:
rtc::scoped_refptr<PendingTaskSafetyFlag> flag_ =
PendingTaskSafetyFlag::Create();