Add SharedModuleThread class to share a module thread across Call instances.

This reduces the number of threads allocated per PeerConnection when
more than one PC is needed.

Bug: webrtc:11598
Change-Id: I3c1fd71705f90c4b4bbb1bc3f0f659c94016e69a
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/175904
Commit-Queue: Tommi <tommi@webrtc.org>
Reviewed-by: Erik Språng <sprang@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#31347}
diff --git a/api/test/create_time_controller.cc b/api/test/create_time_controller.cc
index d3b046b..1a49020 100644
--- a/api/test/create_time_controller.cc
+++ b/api/test/create_time_controller.cc
@@ -35,13 +35,21 @@
     explicit TimeControllerBasedCallFactory(TimeController* time_controller)
         : time_controller_(time_controller) {}
     Call* CreateCall(const Call::Config& config) override {
-      return Call::Create(config, time_controller_->GetClock(),
-                          time_controller_->CreateProcessThread("CallModules"),
+      // Create the shared module thread through the time controller (as the
+      // pacer thread below is, and as this code did before) rather than via
+      // ProcessThread::Create(). It is reset once only our reference remains.
+      if (!module_thread_) {
+        module_thread_ = SharedModuleThread::Create(
+            time_controller_->CreateProcessThread("CallModules"),
+            [this]() { module_thread_ = nullptr; });
+      }
+      return Call::Create(config, time_controller_->GetClock(), module_thread_,
                           time_controller_->CreateProcessThread("Pacer"));
     }
 
    private:
     TimeController* time_controller_;
+    rtc::scoped_refptr<SharedModuleThread> module_thread_;
   };
   return std::make_unique<TimeControllerBasedCallFactory>(time_controller);
 }
diff --git a/call/call.cc b/call/call.cc
index 210f72d..a4e21c9 100644
--- a/call/call.cc
+++ b/call/call.cc
@@ -177,7 +177,7 @@
   Call(Clock* clock,
        const Call::Config& config,
        std::unique_ptr<RtpTransportControllerSendInterface> transport_send,
-       std::unique_ptr<ProcessThread> module_process_thread,
+       rtc::scoped_refptr<SharedModuleThread> module_process_thread,
        TaskQueueFactory* task_queue_factory);
   ~Call() override;
 
@@ -270,7 +270,7 @@
   TaskQueueFactory* const task_queue_factory_;
 
   const int num_cpu_cores_;
-  const std::unique_ptr<ProcessThread> module_process_thread_;
+  const rtc::scoped_refptr<SharedModuleThread> module_process_thread_;
   const std::unique_ptr<CallStats> call_stats_;
   const std::unique_ptr<BitrateAllocator> bitrate_allocator_;
   Call::Config config_;
@@ -407,14 +407,20 @@
 }
 
 Call* Call::Create(const Call::Config& config) {
-  return Create(config, Clock::GetRealTimeClock(),
-                ProcessThread::Create("ModuleProcessThread"),
+  rtc::scoped_refptr<SharedModuleThread> call_thread =
+      SharedModuleThread::Create("ModuleProcessThread", nullptr);
+  return Create(config, std::move(call_thread));
+}
+
+Call* Call::Create(const Call::Config& config,
+                   rtc::scoped_refptr<SharedModuleThread> call_thread) {
+  return Create(config, Clock::GetRealTimeClock(), std::move(call_thread),
                 ProcessThread::Create("PacerThread"));
 }
 
 Call* Call::Create(const Call::Config& config,
                    Clock* clock,
-                   std::unique_ptr<ProcessThread> call_thread,
+                   rtc::scoped_refptr<SharedModuleThread> call_thread,
                    std::unique_ptr<ProcessThread> pacer_thread) {
   RTC_DCHECK(config.task_queue_factory);
   return new internal::Call(
@@ -426,6 +432,104 @@
       std::move(call_thread), config.task_queue_factory);
 }
 
+class SharedModuleThread::Impl {
+ public:
+  Impl(std::unique_ptr<ProcessThread> process_thread,
+       std::function<void()> on_one_ref_remaining)
+      : module_thread_(std::move(process_thread)),
+        on_one_ref_remaining_(std::move(on_one_ref_remaining)) {}
+
+  void EnsureStarted() {
+    RTC_DCHECK_RUN_ON(&sequence_checker_);
+    if (started_)
+      return;
+    started_ = true;
+    module_thread_->Start();
+  }
+
+  ProcessThread* process_thread() {
+    RTC_DCHECK_RUN_ON(&sequence_checker_);
+    return module_thread_.get();
+  }
+
+  void AddRef() const {
+    RTC_DCHECK_RUN_ON(&sequence_checker_);
+    ++ref_count_;
+  }
+
+  rtc::RefCountReleaseStatus Release() const {
+    RTC_DCHECK_RUN_ON(&sequence_checker_);
+    --ref_count_;
+
+    if (ref_count_ == 0) {
+      module_thread_->Stop();
+      return rtc::RefCountReleaseStatus::kDroppedLastRef;
+    }
+
+    if (ref_count_ == 1 && on_one_ref_remaining_) {
+      auto moved_fn = std::move(on_one_ref_remaining_);
+      // Invoke a local copy of the callback (the member is const, so the
+      // std::move above copies): the owner of the last reference may release
+      // it from inside the callback, which is legal and recursively enters
+      // Release() above, calls Stop() and lets |this| be deleted.
+      // NOTE: once the callback returns, |this| may already be gone - do not
+      // touch any member variables after this point.
+      moved_fn();
+    }
+
+    return rtc::RefCountReleaseStatus::kOtherRefsRemained;
+  }
+
+ private:
+  SequenceChecker sequence_checker_;
+  mutable int ref_count_ RTC_GUARDED_BY(sequence_checker_) = 0;
+  std::unique_ptr<ProcessThread> const module_thread_;
+  std::function<void()> const on_one_ref_remaining_;
+  bool started_ = false;
+};
+
+SharedModuleThread::SharedModuleThread(
+    std::unique_ptr<ProcessThread> process_thread,
+    std::function<void()> on_one_ref_remaining)
+    : impl_(std::make_unique<Impl>(std::move(process_thread),
+                                   std::move(on_one_ref_remaining))) {}
+
+SharedModuleThread::~SharedModuleThread() = default;
+
+// static
+rtc::scoped_refptr<SharedModuleThread> SharedModuleThread::Create(
+    const char* name,
+    std::function<void()> on_one_ref_remaining) {
+  return new SharedModuleThread(ProcessThread::Create(name),
+                                std::move(on_one_ref_remaining));
+}
+
+rtc::scoped_refptr<SharedModuleThread> SharedModuleThread::Create(
+    std::unique_ptr<ProcessThread> process_thread,
+    std::function<void()> on_one_ref_remaining) {
+  return new SharedModuleThread(std::move(process_thread),
+                                std::move(on_one_ref_remaining));
+}
+
+void SharedModuleThread::EnsureStarted() {
+  impl_->EnsureStarted();
+}
+
+ProcessThread* SharedModuleThread::process_thread() {
+  return impl_->process_thread();
+}
+
+void SharedModuleThread::AddRef() const {
+  impl_->AddRef();
+}
+
+rtc::RefCountReleaseStatus SharedModuleThread::Release() const {
+  auto ret = impl_->Release();
+  if (ret == rtc::RefCountReleaseStatus::kDroppedLastRef)
+    delete this;
+  return ret;
+}
+
 // This method here to avoid subclasses has to implement this method.
 // Call perf test will use Internal::Call::CreateVideoSendStream() to inject
 // FecController.
@@ -441,7 +545,7 @@
 Call::Call(Clock* clock,
            const Call::Config& config,
            std::unique_ptr<RtpTransportControllerSendInterface> transport_send,
-           std::unique_ptr<ProcessThread> module_process_thread,
+           rtc::scoped_refptr<SharedModuleThread> module_process_thread,
            TaskQueueFactory* task_queue_factory)
     : clock_(clock),
       task_queue_factory_(task_queue_factory),
@@ -477,9 +581,10 @@
 
   call_stats_->RegisterStatsObserver(&receive_side_cc_);
 
-  module_process_thread_->RegisterModule(
+  module_process_thread_->process_thread()->RegisterModule(
       receive_side_cc_.GetRemoteBitrateEstimator(true), RTC_FROM_HERE);
-  module_process_thread_->RegisterModule(&receive_side_cc_, RTC_FROM_HERE);
+  module_process_thread_->process_thread()->RegisterModule(&receive_side_cc_,
+                                                           RTC_FROM_HERE);
 }
 
 Call::~Call() {
@@ -491,10 +596,9 @@
   RTC_CHECK(audio_receive_streams_.empty());
   RTC_CHECK(video_receive_streams_.empty());
 
-  module_process_thread_->Stop();
-  module_process_thread_->DeRegisterModule(
+  module_process_thread_->process_thread()->DeRegisterModule(
       receive_side_cc_.GetRemoteBitrateEstimator(true));
-  module_process_thread_->DeRegisterModule(&receive_side_cc_);
+  module_process_thread_->process_thread()->DeRegisterModule(&receive_side_cc_);
   call_stats_->DeregisterStatsObserver(&receive_side_cc_);
 
   absl::optional<Timestamp> first_sent_packet_ms =
@@ -523,7 +627,7 @@
   // off being kicked off on request rather than in the ctor.
   transport_send_ptr_->RegisterTargetTransferRateObserver(this);
 
-  module_process_thread_->Start();
+  module_process_thread_->EnsureStarted();
 }
 
 void Call::SetClientBitratePreferences(const BitrateSettings& preferences) {
@@ -632,7 +736,7 @@
 
   AudioSendStream* send_stream = new AudioSendStream(
       clock_, config, config_.audio_state, task_queue_factory_,
-      module_process_thread_.get(), transport_send_ptr_,
+      module_process_thread_->process_thread(), transport_send_ptr_,
       bitrate_allocator_.get(), event_log_, call_stats_->AsRtcpRttStats(),
       suspended_rtp_state);
   {
@@ -690,7 +794,7 @@
       CreateRtcLogStreamConfig(config)));
   AudioReceiveStream* receive_stream = new AudioReceiveStream(
       clock_, &audio_receiver_controller_, transport_send_ptr_->packet_router(),
-      module_process_thread_.get(), config_.neteq_factory, config,
+      module_process_thread_->process_thread(), config_.neteq_factory, config,
       config_.audio_state, event_log_);
   {
     WriteLockScoped write_lock(*receive_crit_);
@@ -761,8 +865,8 @@
   std::vector<uint32_t> ssrcs = config.rtp.ssrcs;
 
   VideoSendStream* send_stream = new VideoSendStream(
-      clock_, num_cpu_cores_, module_process_thread_.get(), task_queue_factory_,
-      call_stats_->AsRtcpRttStats(), transport_send_ptr_,
+      clock_, num_cpu_cores_, module_process_thread_->process_thread(),
+      task_queue_factory_, call_stats_->AsRtcpRttStats(), transport_send_ptr_,
       bitrate_allocator_.get(), video_send_delay_stats_.get(), event_log_,
       std::move(config), std::move(encoder_config), suspended_video_send_ssrcs_,
       suspended_video_payload_states_, std::move(fec_controller));
@@ -847,7 +951,7 @@
   VideoReceiveStream2* receive_stream = new VideoReceiveStream2(
       task_queue_factory_, current, &video_receiver_controller_, num_cpu_cores_,
       transport_send_ptr_->packet_router(), std::move(configuration),
-      module_process_thread_.get(), call_stats_.get(), clock_,
+      module_process_thread_->process_thread(), call_stats_.get(), clock_,
       new VCMTiming(clock_));
 
   const webrtc::VideoReceiveStream::Config& config = receive_stream->config();
@@ -921,7 +1025,8 @@
     // this locked scope.
     receive_stream = new FlexfecReceiveStreamImpl(
         clock_, &video_receiver_controller_, config, recovered_packet_receiver,
-        call_stats_->AsRtcpRttStats(), module_process_thread_.get());
+        call_stats_->AsRtcpRttStats(),
+        module_process_thread_->process_thread());
 
     RTC_DCHECK(receive_rtp_config_.find(config.remote_ssrc) ==
                receive_rtp_config_.end());
diff --git a/call/call.h b/call/call.h
index 77cd3d2..a6ce769 100644
--- a/call/call.h
+++ b/call/call.h
@@ -28,9 +28,46 @@
 #include "rtc_base/copy_on_write_buffer.h"
 #include "rtc_base/network/sent_packet.h"
 #include "rtc_base/network_route.h"
+#include "rtc_base/ref_count.h"
 
 namespace webrtc {
 
+// A restricted way to share the module process thread across multiple instances
+// of Call that are constructed on the same worker thread (which is what the
+// peer connection factory guarantees).
+// SharedModuleThread supports a callback that is issued when only one reference
+// remains, which is used to indicate to the original owner that the thread may
+// be discarded.
+class SharedModuleThread : public rtc::RefCountInterface {
+ protected:
+  SharedModuleThread(std::unique_ptr<ProcessThread> process_thread,
+                     std::function<void()> on_one_ref_remaining);
+  friend class rtc::scoped_refptr<SharedModuleThread>;
+  ~SharedModuleThread() override;
+
+ public:
+  // Instantiates a default implementation of ProcessThread.
+  static rtc::scoped_refptr<SharedModuleThread> Create(
+      const char* name,
+      std::function<void()> on_one_ref_remaining);
+
+  // Allows injection of an externally created process thread.
+  static rtc::scoped_refptr<SharedModuleThread> Create(
+      std::unique_ptr<ProcessThread> process_thread,
+      std::function<void()> on_one_ref_remaining);
+
+  void EnsureStarted();
+
+  ProcessThread* process_thread();
+
+ private:
+  void AddRef() const override;
+  rtc::RefCountReleaseStatus Release() const override;
+
+  class Impl;
+  mutable std::unique_ptr<Impl> impl_;
+};
+
 // A Call instance can contain several send and/or receive streams. All streams
 // are assumed to have the same remote endpoint and will share bitrate estimates
 // etc.
@@ -50,8 +87,10 @@
 
   static Call* Create(const Call::Config& config);
   static Call* Create(const Call::Config& config,
+                      rtc::scoped_refptr<SharedModuleThread> call_thread);
+  static Call* Create(const Call::Config& config,
                       Clock* clock,
-                      std::unique_ptr<ProcessThread> call_thread,
+                      rtc::scoped_refptr<SharedModuleThread> call_thread,
                       std::unique_ptr<ProcessThread> pacer_thread);
 
   virtual AudioSendStream* CreateAudioSendStream(
diff --git a/call/call_factory.cc b/call/call_factory.cc
index 6b4f419..a3ebc47 100644
--- a/call/call_factory.cc
+++ b/call/call_factory.cc
@@ -70,7 +70,12 @@
 }
 }  // namespace
 
+CallFactory::CallFactory() {
+  call_thread_.Detach();
+}
+
 Call* CallFactory::CreateCall(const Call::Config& config) {
+  RTC_DCHECK_RUN_ON(&call_thread_);
   absl::optional<webrtc::BuiltInNetworkBehaviorConfig> send_degradation_config =
       ParseDegradationConfig(true);
   absl::optional<webrtc::BuiltInNetworkBehaviorConfig>
@@ -82,7 +87,14 @@
                             config.task_queue_factory);
   }
 
-  return Call::Create(config);
+  if (!module_thread_) {
+    module_thread_ = SharedModuleThread::Create("SharedModThread", [this]() {
+      RTC_DCHECK_RUN_ON(&call_thread_);
+      module_thread_ = nullptr;
+    });
+  }
+
+  return Call::Create(config, module_thread_);
 }
 
 std::unique_ptr<CallFactoryInterface> CreateCallFactory() {
diff --git a/call/call_factory.h b/call/call_factory.h
index f0d695c..65c0b65 100644
--- a/call/call_factory.h
+++ b/call/call_factory.h
@@ -14,13 +14,22 @@
 #include "api/call/call_factory_interface.h"
 #include "call/call.h"
 #include "call/call_config.h"
+#include "rtc_base/synchronization/sequence_checker.h"
 
 namespace webrtc {
 
 class CallFactory : public CallFactoryInterface {
+ public:
+  CallFactory();
+
+ private:
   ~CallFactory() override {}
 
   Call* CreateCall(const CallConfig& config) override;
+
+  SequenceChecker call_thread_;
+  rtc::scoped_refptr<SharedModuleThread> module_thread_
+      RTC_GUARDED_BY(call_thread_);
 };
 
 }  // namespace webrtc
diff --git a/call/call_unittest.cc b/call/call_unittest.cc
index 8afcf25..0b05379 100644
--- a/call/call_unittest.cc
+++ b/call/call_unittest.cc
@@ -325,4 +325,58 @@
   }
 }
 
+TEST(CallTest, SharedModuleThread) {
+  class SharedModuleThreadUser : public Module {
+   public:
+    SharedModuleThreadUser(ProcessThread* expected_thread,
+                           rtc::scoped_refptr<SharedModuleThread> thread)
+        : expected_thread_(expected_thread), thread_(std::move(thread)) {
+      thread_->EnsureStarted();
+      thread_->process_thread()->RegisterModule(this, RTC_FROM_HERE);
+    }
+
+    ~SharedModuleThreadUser() override {
+      thread_->process_thread()->DeRegisterModule(this);
+      EXPECT_TRUE(thread_was_checked_);
+    }
+
+   private:
+    int64_t TimeUntilNextProcess() override { return 1000; }
+    void Process() override {}
+    void ProcessThreadAttached(ProcessThread* process_thread) override {
+      if (!process_thread) {
+        // Being detached.
+        return;
+      }
+      EXPECT_EQ(process_thread, expected_thread_);
+      thread_was_checked_ = true;
+    }
+
+    bool thread_was_checked_ = false;
+    ProcessThread* const expected_thread_;
+    rtc::scoped_refptr<SharedModuleThread> thread_;
+  };
+
+  // Create our test instance and pass a lambda to it that gets executed when
+  // the reference count goes back to 1 - meaning |shared| again is the only
+  // reference, which means we can free the variable and deallocate the thread.
+  rtc::scoped_refptr<SharedModuleThread> shared;
+  shared = SharedModuleThread::Create("MySharedProcessThread",
+                                      [&shared]() { shared = nullptr; });
+  ProcessThread* process_thread = shared->process_thread();
+
+  ASSERT_TRUE(shared.get());
+
+  {
+    // Create a couple of users of the thread.
+    // These instances are in a separate scope to trigger the callback to our
+    // lambda, which will run when these go out of scope.
+    SharedModuleThreadUser user1(process_thread, shared);
+    SharedModuleThreadUser user2(process_thread, shared);
+  }
+
+  // The thread should now have been stopped and freed.
+  EXPECT_FALSE(shared);
+}
+
 }  // namespace webrtc
diff --git a/test/scenario/call_client.cc b/test/scenario/call_client.cc
index fb888df6..0107497 100644
--- a/test/scenario/call_client.cc
+++ b/test/scenario/call_client.cc
@@ -54,7 +54,8 @@
                  RtcEventLog* event_log,
                  CallClientConfig config,
                  LoggingNetworkControllerFactory* network_controller_factory,
-                 rtc::scoped_refptr<AudioState> audio_state) {
+                 rtc::scoped_refptr<AudioState> audio_state,
+                 rtc::scoped_refptr<SharedModuleThread> call_thread) {
   CallConfig call_config(event_log);
   call_config.bitrate_config.max_bitrate_bps =
       config.transport.rates.max_rate.bps_or(-1);
@@ -67,7 +68,7 @@
   call_config.audio_state = audio_state;
   call_config.trials = config.field_trials;
   return Call::Create(call_config, time_controller->GetClock(),
-                      time_controller->CreateProcessThread("CallModules"),
+                      std::move(call_thread),
                       time_controller->CreateProcessThread("Pacer"));
 }
 
@@ -213,9 +214,14 @@
     event_log_ = CreateEventLog(time_controller_->GetTaskQueueFactory(),
                                 log_writer_factory_.get());
     fake_audio_setup_ = InitAudio(time_controller_);
+    RTC_DCHECK(!module_thread_);
+    module_thread_ = SharedModuleThread::Create(
+        time_controller_->CreateProcessThread("CallThread"),
+        [this]() { module_thread_ = nullptr; });
+
     call_.reset(CreateCall(time_controller_, event_log_.get(), config,
                            &network_controller_factory_,
-                           fake_audio_setup_.audio_state));
+                           fake_audio_setup_.audio_state, module_thread_));
     transport_ = std::make_unique<NetworkNodeTransport>(clock_, call_.get());
   });
 }
@@ -223,6 +229,7 @@
 CallClient::~CallClient() {
   SendTask([&] {
     call_.reset();
+    RTC_DCHECK(!module_thread_);  // Should be set to null in the lambda above.
     fake_audio_setup_ = {};
     rtc::Event done;
     event_log_->StopLogging([&done] { done.Set(); });
diff --git a/test/scenario/call_client.h b/test/scenario/call_client.h
index 33fa276..80814eb 100644
--- a/test/scenario/call_client.h
+++ b/test/scenario/call_client.h
@@ -157,6 +157,8 @@
   // Defined last so it's destroyed first.
   TaskQueueForTest task_queue_;
 
+  rtc::scoped_refptr<SharedModuleThread> module_thread_;
+
   const FieldTrialBasedConfig field_trials_;
 };