Make sure that "current" rtc::Thread instances are always current for TaskQueueBase.

This is a necessary part of fulfilling the TaskQueueBase
interface. If a thread does not register as the current TQ, yet offers
the TQ interface, TQ 'current' checks will not work as expected and
code that relies them (TaskQueueBase::Current() and IsCurrent())
will run in unexpected ways.

Bug: webrtc:11572
Change-Id: Iab747bc474e74e6ce4f9e914cfd5b0578b19d19c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/175080
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Commit-Queue: Tommi <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#31254}
diff --git a/api/task_queue/task_queue_test.cc b/api/task_queue/task_queue_test.cc
index a8a799f..3f638b7 100644
--- a/api/task_queue/task_queue_test.cc
+++ b/api/task_queue/task_queue_test.cc
@@ -37,9 +37,11 @@
   rtc::Event event;
   auto queue = CreateTaskQueue(factory, "PostAndCheckCurrent");
 
-  // We're not running a task, so there shouldn't be a current queue.
+  // We're not running a task, so |queue| shouldn't be current.
+  // Note that because rtc::Thread also supports the TQ interface and
+  // TestMainImpl::Init wraps the main test thread (bugs.webrtc.org/9714), that
+  // means that TaskQueueBase::Current() will still return a valid value.
   EXPECT_FALSE(queue->IsCurrent());
-  EXPECT_FALSE(TaskQueueBase::Current());
 
   queue->PostTask(ToQueuedTask([&event, &queue] {
     EXPECT_TRUE(queue->IsCurrent());
diff --git a/rtc_base/thread.cc b/rtc_base/thread.cc
index 0fb2e81..5e48e4b 100644
--- a/rtc_base/thread.cc
+++ b/rtc_base/thread.cc
@@ -296,6 +296,21 @@
     RTC_DLOG(LS_ERROR) << "SetCurrentThread: Overwriting an existing value?";
   }
 #endif  // RTC_DLOG_IS_ON
+
+  if (thread) {
+    thread->EnsureIsCurrentTaskQueue();
+  } else {
+    Thread* current = CurrentThread();
+    if (current) {
+      // The current thread is being cleared, e.g. as a result of
+      // UnwrapCurrent() being called or when a thread is being stopped
+      // (see PreRun()). This signals that the Thread instance is being detached
+      // from the thread, which also means that TaskQueue::Current() must not
+      // return a pointer to the Thread instance.
+      current->ClearCurrentTaskQueue();
+    }
+  }
+
   SetCurrentThreadInternal(thread);
 }
 
@@ -824,7 +839,6 @@
   Thread* thread = static_cast<Thread*>(pv);
   ThreadManager::Instance()->SetCurrentThread(thread);
   rtc::SetCurrentThreadName(thread->name_.c_str());
-  CurrentTaskQueueSetter set_current_task_queue(thread);
 #if defined(WEBRTC_MAC)
   ScopedAutoReleasePool pool;
 #endif
@@ -935,6 +949,17 @@
   Send(posted_from, &handler);
 }
 
+// Called by the ThreadManager when being set as the current thread.
+void Thread::EnsureIsCurrentTaskQueue() {
+  task_queue_registration_ =
+      std::make_unique<TaskQueueBase::CurrentTaskQueueSetter>(this);
+}
+
+// Called by the ThreadManager when being set as the current thread.
+void Thread::ClearCurrentTaskQueue() {
+  task_queue_registration_.reset();
+}
+
 void Thread::QueuedTaskHandler::OnMessage(Message* msg) {
   RTC_DCHECK(msg);
   auto* data = static_cast<ScopedMessageData<webrtc::QueuedTask>*>(msg->pdata);
diff --git a/rtc_base/thread.h b/rtc_base/thread.h
index 74aab62..e25ed4e 100644
--- a/rtc_base/thread.h
+++ b/rtc_base/thread.h
@@ -551,6 +551,12 @@
   void InvokeInternal(const Location& posted_from,
                       rtc::FunctionView<void()> functor);
 
+  // Called by the ThreadManager when being set as the current thread.
+  void EnsureIsCurrentTaskQueue();
+
+  // Called by the ThreadManager when being unset as the current thread.
+  void ClearCurrentTaskQueue();
+
   // Returns a static-lifetime MessageHandler which runs message with
   // MessageLikeTask payload data.
   static MessageHandler* GetPostTaskMessageHandler();
@@ -595,6 +601,8 @@
 
   // Runs webrtc::QueuedTask posted to the Thread.
   QueuedTaskHandler queued_task_handler_;
+  std::unique_ptr<TaskQueueBase::CurrentTaskQueueSetter>
+      task_queue_registration_;
 
   friend class ThreadManager;
 
diff --git a/rtc_base/thread_unittest.cc b/rtc_base/thread_unittest.cc
index d53a387..e1011f4 100644
--- a/rtc_base/thread_unittest.cc
+++ b/rtc_base/thread_unittest.cc
@@ -1148,6 +1148,18 @@
   EXPECT_TRUE(fourth.Wait(0));
 }
 
+TEST(ThreadPostDelayedTaskTest, IsCurrentTaskQueue) {
+  auto current_tq = webrtc::TaskQueueBase::Current();
+  {
+    std::unique_ptr<rtc::Thread> thread(rtc::Thread::Create());
+    thread->WrapCurrent();
+    EXPECT_EQ(webrtc::TaskQueueBase::Current(),
+              static_cast<webrtc::TaskQueueBase*>(thread.get()));
+    thread->UnwrapCurrent();
+  }
+  EXPECT_EQ(webrtc::TaskQueueBase::Current(), current_tq);
+}
+
 class ThreadFactory : public webrtc::TaskQueueFactory {
  public:
   std::unique_ptr<webrtc::TaskQueueBase, webrtc::TaskQueueDeleter>
diff --git a/test/run_loop_unittest.cc b/test/run_loop_unittest.cc
index a356cc2..160aba0 100644
--- a/test/run_loop_unittest.cc
+++ b/test/run_loop_unittest.cc
@@ -17,7 +17,6 @@
 namespace webrtc {
 
 TEST(RunLoopTest, TaskQueueOnThread) {
-  EXPECT_EQ(TaskQueueBase::Current(), nullptr);
   test::RunLoop loop;
   EXPECT_EQ(TaskQueueBase::Current(), loop.task_queue());
   EXPECT_TRUE(loop.task_queue()->IsCurrent());