Cleanup implemenation of AudioState SetRecording/SetPlayout vs. Add/Remove {Send/Recv}stream

So that they behave in the most obvious ways:
Set{Recording/Playout} = TRUE
  - Enables {Recording/Playout} is there are {Send/Recv} streams
  - Set state variable

Set{Recording/Plaout} = FALSE
  - Disable {Recording/Playout}
  - Set state variable

Add {Send/Recv} stream
  - Enables {Recording/Playout} if state variable is TRUE
  - Otherwise does nothing

Remove {Send/Recv} stream
  - Disable {Recording/Playout} if last stream
  - Otherwise does nothing

---

Before this patch the behavior was hard to non obvious,
e.g SetRecording(false) followed by SetRecording(true)
did not work (same for playout).


BUG=b/397376626

Change-Id: I530497d4a46ad73334fcb3d73f4b87264bd18486
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/378740
Reviewed-by: Jakob Ivarsson‎ <jakobi@webrtc.org>
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#44025}
diff --git a/audio/audio_receive_stream.h b/audio/audio_receive_stream.h
index 63bd062..ccabb3d 100644
--- a/audio/audio_receive_stream.h
+++ b/audio/audio_receive_stream.h
@@ -106,6 +106,7 @@
   bool SetBaseMinimumPlayoutDelayMs(int delay_ms) override;
   int GetBaseMinimumPlayoutDelayMs() const override;
   std::vector<webrtc::RtpSource> GetSources() const override;
+  AudioMixer::Source* source() override { return this; }
 
   // AudioMixer::Source
   AudioFrameInfo GetAudioFrameWithInfo(int sample_rate_hz,
diff --git a/audio/audio_state.cc b/audio/audio_state.cc
index 3c96802..c810398 100644
--- a/audio/audio_state.cc
+++ b/audio/audio_state.cc
@@ -51,28 +51,47 @@
   return &audio_transport_;
 }
 
+void AudioState::SetPlayout(bool enabled) {
+  RTC_LOG(LS_INFO) << "SetPlayout(" << enabled << ")";
+  RTC_DCHECK_RUN_ON(&thread_checker_);
+  auto* adm = config_.audio_device_module.get();
+  if (enabled) {
+    if (!receiving_streams_.empty()) {
+      if (!adm->Playing()) {
+        if (adm->InitPlayout() == 0) {
+          adm->StartPlayout();
+        }
+      }
+    }
+  } else {
+    // Disable playout.
+    config_.audio_device_module->StopPlayout();
+  }
+  playout_enabled_ = enabled;
+  UpdateNullAudioPollerState();
+}
+
 void AudioState::AddReceivingStream(
     webrtc::AudioReceiveStreamInterface* stream) {
   RTC_DCHECK_RUN_ON(&thread_checker_);
   RTC_DCHECK_EQ(0, receiving_streams_.count(stream));
   receiving_streams_.insert(stream);
-  if (!config_.audio_mixer->AddSource(
-          static_cast<AudioReceiveStreamImpl*>(stream))) {
+  if (!config_.audio_mixer->AddSource(stream->source())) {
     RTC_DLOG(LS_ERROR) << "Failed to add source to mixer.";
   }
 
   // Make sure playback is initialized; start playing if enabled.
-  UpdateNullAudioPollerState();
-  auto* adm = config_.audio_device_module.get();
-  if (!adm->Playing()) {
-    if (adm->InitPlayout() == 0) {
-      if (playout_enabled_) {
+  if (playout_enabled_) {
+    auto* adm = config_.audio_device_module.get();
+    if (!adm->Playing()) {
+      if (adm->InitPlayout() == 0) {
         adm->StartPlayout();
       }
     } else {
       RTC_DLOG_F(LS_ERROR) << "Failed to initialize playout.";
     }
   }
+  UpdateNullAudioPollerState();
 }
 
 void AudioState::RemoveReceivingStream(
@@ -80,12 +99,30 @@
   RTC_DCHECK_RUN_ON(&thread_checker_);
   auto count = receiving_streams_.erase(stream);
   RTC_DCHECK_EQ(1, count);
-  config_.audio_mixer->RemoveSource(
-      static_cast<AudioReceiveStreamImpl*>(stream));
-  UpdateNullAudioPollerState();
+  config_.audio_mixer->RemoveSource(stream->source());
   if (receiving_streams_.empty()) {
     config_.audio_device_module->StopPlayout();
   }
+  UpdateNullAudioPollerState();
+}
+
+void AudioState::SetRecording(bool enabled) {
+  RTC_LOG(LS_INFO) << "SetRecording(" << enabled << ")";
+  RTC_DCHECK_RUN_ON(&thread_checker_);
+  auto* adm = config_.audio_device_module.get();
+  if (enabled) {
+    if (!sending_streams_.empty()) {
+      if (!adm->Recording()) {
+        if (adm->InitRecording() == 0) {
+          adm->StartRecording();
+        }
+      }
+    }
+  } else {
+    // Disable recording.
+    adm->StopRecording();
+  }
+  recording_enabled_ = enabled;
 }
 
 void AudioState::AddSendingStream(webrtc::AudioSendStream* stream,
@@ -99,9 +136,9 @@
 
   // Make sure recording is initialized; start recording if enabled.
   auto* adm = config_.audio_device_module.get();
-  if (!adm->Recording()) {
-    if (adm->InitRecording() == 0) {
-      if (recording_enabled_) {
+  if (recording_enabled_) {
+    if (!adm->Recording()) {
+      if (adm->InitRecording() == 0) {
         adm->StartRecording();
       }
     } else {
@@ -120,46 +157,6 @@
   }
 }
 
-void AudioState::SetPlayout(bool enabled) {
-  RTC_LOG(LS_INFO) << "SetPlayout(" << enabled << ")";
-  RTC_DCHECK_RUN_ON(&thread_checker_);
-  if (playout_enabled_ != enabled) {
-    playout_enabled_ = enabled;
-    if (enabled) {
-      UpdateNullAudioPollerState();
-      if (!receiving_streams_.empty()) {
-        config_.audio_device_module->StartPlayout();
-      }
-    } else {
-      config_.audio_device_module->StopPlayout();
-      UpdateNullAudioPollerState();
-    }
-  }
-}
-
-void AudioState::SetRecording(bool enabled) {
-  RTC_LOG(LS_INFO) << "SetRecording(" << enabled << ")";
-  RTC_DCHECK_RUN_ON(&thread_checker_);
-  auto* adm = config_.audio_device_module.get();
-  if (recording_enabled_ != enabled) {
-    auto* adm = config_.audio_device_module.get();
-    recording_enabled_ = enabled;
-    if (enabled) {
-      if (!sending_streams_.empty()) {
-        if (adm->InitRecording() == 0) {
-          adm->StartRecording();
-        }
-      }
-    } else {
-      adm->StopRecording();
-    }
-  } else if (!enabled && adm->RecordingIsInitialized()) {
-    // The recording can also be initialized by WebRtcVoiceSendChannel
-    // options_.init_recording_on_send.
-    adm->StopRecording();
-  }
-}
-
 void AudioState::SetStereoChannelSwapping(bool enable) {
   RTC_DCHECK(thread_checker_.IsCurrent());
   audio_transport_.SetStereoChannelSwapping(enable);
diff --git a/audio/audio_state_unittest.cc b/audio/audio_state_unittest.cc
index c2b55bb..6707f61 100644
--- a/audio/audio_state_unittest.cc
+++ b/audio/audio_state_unittest.cc
@@ -15,10 +15,12 @@
 #include <vector>
 
 #include "api/task_queue/test/mock_task_queue_base.h"
+#include "call/test/mock_audio_receive_stream.h"
 #include "call/test/mock_audio_send_stream.h"
 #include "modules/audio_device/include/mock_audio_device.h"
 #include "modules/audio_mixer/audio_mixer_impl.h"
 #include "modules/audio_processing/include/mock_audio_processing.h"
+#include "rtc_base/thread.h"
 #include "test/gtest.h"
 
 namespace webrtc {
@@ -358,6 +360,38 @@
       audio_buffer, n_samples_out, &elapsed_time_ms, &ntp_time_ms);
 }
 
+TEST_P(AudioStateTest, StartRecordingDoesNothingWithoutStream) {
+  ConfigHelper helper(GetParam());
+  rtc::scoped_refptr<internal::AudioState> audio_state(
+      rtc::make_ref_counted<internal::AudioState>(helper.config()));
+
+  auto* adm = reinterpret_cast<MockAudioDeviceModule*>(
+      helper.config().audio_device_module.get());
+
+  EXPECT_CALL(*adm, InitRecording()).Times(0);
+  EXPECT_CALL(*adm, StartRecording()).Times(0);
+  EXPECT_CALL(*adm, StopRecording()).Times(1);
+  audio_state->SetRecording(false);
+  audio_state->SetRecording(true);
+}
+
+TEST_P(AudioStateTest, AddStreamDoesNothingIfRecordingDisabled) {
+  ConfigHelper helper(GetParam());
+  rtc::scoped_refptr<internal::AudioState> audio_state(
+      rtc::make_ref_counted<internal::AudioState>(helper.config()));
+
+  auto* adm = reinterpret_cast<MockAudioDeviceModule*>(
+      helper.config().audio_device_module.get());
+
+  EXPECT_CALL(*adm, StopRecording()).Times(2);
+  audio_state->SetRecording(false);
+
+  MockAudioSendStream stream;
+  EXPECT_CALL(*adm, StartRecording).Times(0);
+  audio_state->AddSendingStream(&stream, kSampleRate, kNumberOfChannels);
+  audio_state->RemoveSendingStream(&stream);
+}
+
 TEST_P(AudioStateTest, AlwaysCallInitRecordingBeforeStartRecording) {
   ConfigHelper helper(GetParam());
   rtc::scoped_refptr<internal::AudioState> audio_state(
@@ -401,11 +435,97 @@
 
   audio_state->SetRecording(false);
 
-  EXPECT_CALL(*adm, RecordingIsInitialized()).WillOnce(testing::Return(true));
   EXPECT_CALL(*adm, StopRecording());
   audio_state->SetRecording(false);
 }
 
+TEST_P(AudioStateTest, StartPlayoutDoesNothingWithoutStream) {
+  ConfigHelper helper(GetParam());
+  rtc::scoped_refptr<internal::AudioState> audio_state(
+      rtc::make_ref_counted<internal::AudioState>(helper.config()));
+
+  auto* adm = reinterpret_cast<MockAudioDeviceModule*>(
+      helper.config().audio_device_module.get());
+
+  EXPECT_CALL(*adm, InitPlayout()).Times(0);
+  EXPECT_CALL(*adm, StartPlayout()).Times(0);
+  EXPECT_CALL(*adm, StopPlayout()).Times(1);
+  audio_state->SetPlayout(false);
+
+  audio_state->SetPlayout(true);
+}
+
+TEST_P(AudioStateTest, AlwaysCallInitPlayoutBeforeStartPlayout) {
+  ConfigHelper helper(GetParam());
+  rtc::scoped_refptr<internal::AudioState> audio_state(
+      rtc::make_ref_counted<internal::AudioState>(helper.config()));
+
+  auto* adm = reinterpret_cast<MockAudioDeviceModule*>(
+      helper.config().audio_device_module.get());
+
+  MockAudioReceiveStream stream;
+  {
+    InSequence s;
+    EXPECT_CALL(*adm, InitPlayout());
+    EXPECT_CALL(*adm, StartPlayout());
+    audio_state->AddReceivingStream(&stream);
+  }
+
+  // SetPlayout(false) starts the NullAudioPoller...which needs a thread.
+  rtc::ThreadManager::Instance()->WrapCurrentThread();
+
+  EXPECT_CALL(*adm, StopPlayout());
+  audio_state->SetPlayout(false);
+
+  {
+    InSequence s;
+    EXPECT_CALL(*adm, InitPlayout());
+    EXPECT_CALL(*adm, StartPlayout());
+    audio_state->SetPlayout(true);
+  }
+
+  // Playout without streams starts the NullAudioPoller...
+  // which needs a thread.
+  rtc::ThreadManager::Instance()->WrapCurrentThread();
+
+  EXPECT_CALL(*adm, StopPlayout());
+  audio_state->RemoveReceivingStream(&stream);
+}
+
+TEST_P(AudioStateTest, CallStopPlayoutIfPlayoutIsInitialized) {
+  ConfigHelper helper(GetParam());
+  rtc::scoped_refptr<internal::AudioState> audio_state(
+      rtc::make_ref_counted<internal::AudioState>(helper.config()));
+
+  auto* adm = reinterpret_cast<MockAudioDeviceModule*>(
+      helper.config().audio_device_module.get());
+
+  audio_state->SetPlayout(false);
+
+  EXPECT_CALL(*adm, StopPlayout());
+  audio_state->SetPlayout(false);
+}
+
+TEST_P(AudioStateTest, AddStreamDoesNothingIfPlayoutDisabled) {
+  ConfigHelper helper(GetParam());
+  rtc::scoped_refptr<internal::AudioState> audio_state(
+      rtc::make_ref_counted<internal::AudioState>(helper.config()));
+
+  auto* adm = reinterpret_cast<MockAudioDeviceModule*>(
+      helper.config().audio_device_module.get());
+
+  EXPECT_CALL(*adm, StopPlayout()).Times(2);
+  audio_state->SetPlayout(false);
+
+  // AddReceivingStream with playout disabled start the NullAudioPoller...
+  // which needs a thread.
+  rtc::ThreadManager::Instance()->WrapCurrentThread();
+
+  MockAudioReceiveStream stream;
+  audio_state->AddReceivingStream(&stream);
+  audio_state->RemoveReceivingStream(&stream);
+}
+
 INSTANTIATE_TEST_SUITE_P(AudioStateTest,
                          AudioStateTest,
                          Values(ConfigHelper::Params({false, false}),
diff --git a/call/BUILD.gn b/call/BUILD.gn
index 8833fe9..09cd9eb 100644
--- a/call/BUILD.gn
+++ b/call/BUILD.gn
@@ -768,9 +768,13 @@
   rtc_source_set("mock_call_interfaces") {
     testonly = true
 
-    sources = [ "test/mock_audio_send_stream.h" ]
+    sources = [
+      "test/mock_audio_receive_stream.h",
+      "test/mock_audio_send_stream.h",
+    ]
     deps = [
       ":call_interfaces",
+      "../api/audio:audio_mixer_api",
       "../test:test_support",
     ]
   }
diff --git a/call/audio_receive_stream.h b/call/audio_receive_stream.h
index 9e64521..42502f5 100644
--- a/call/audio_receive_stream.h
+++ b/call/audio_receive_stream.h
@@ -17,6 +17,7 @@
 #include <optional>
 #include <string>
 
+#include "api/audio/audio_mixer.h"
 #include "api/audio_codecs/audio_codec_pair_id.h"
 #include "api/audio_codecs/audio_decoder_factory.h"
 #include "api/audio_codecs/audio_format.h"
@@ -210,6 +211,10 @@
   // post initialization.
   virtual uint32_t remote_ssrc() const = 0;
 
+  // Get the object suitable to inject into the AudioMixer
+  // (normally "this").
+  virtual AudioMixer::Source* source() = 0;
+
  protected:
   virtual ~AudioReceiveStreamInterface() {}
 };
diff --git a/call/test/mock_audio_receive_stream.h b/call/test/mock_audio_receive_stream.h
new file mode 100644
index 0000000..3613d15
--- /dev/null
+++ b/call/test/mock_audio_receive_stream.h
@@ -0,0 +1,72 @@
+/*
+ *  Copyright (c) 2025 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef CALL_TEST_MOCK_AUDIO_RECEIVE_STREAM_H_
+#define CALL_TEST_MOCK_AUDIO_RECEIVE_STREAM_H_
+
+#include <map>
+#include <vector>
+
+#include "api/audio/audio_mixer.h"
+#include "call/audio_receive_stream.h"
+#include "test/gmock.h"
+
+namespace webrtc {
+namespace test {
+
+class MockAudioReceiveStream : public AudioReceiveStreamInterface,
+                               public AudioMixer::Source {
+ public:
+  MOCK_METHOD(uint32_t, remote_ssrc, (), (const override));
+  MOCK_METHOD(void, Start, (), (override));
+  MOCK_METHOD(void, Stop, (), (override));
+  MOCK_METHOD(bool, IsRunning, (), (const override));
+  MOCK_METHOD(void,
+              SetDepacketizerToDecoderFrameTransformer,
+              (rtc::scoped_refptr<webrtc::FrameTransformerInterface>),
+              (override));
+  MOCK_METHOD(void,
+              SetDecoderMap,
+              ((std::map<int, webrtc::SdpAudioFormat>)),
+              (override));
+  MOCK_METHOD(void, SetNackHistory, (int), (override));
+  MOCK_METHOD(void, SetRtcpMode, (webrtc::RtcpMode), (override));
+  MOCK_METHOD(void, SetNonSenderRttMeasurement, (bool), (override));
+  MOCK_METHOD(void,
+              SetFrameDecryptor,
+              (rtc::scoped_refptr<webrtc::FrameDecryptorInterface>),
+              (override));
+
+  MOCK_METHOD(webrtc::AudioReceiveStreamInterface::Stats,
+              GetStats,
+              (bool),
+              (const override));
+  MOCK_METHOD(void, SetSink, (webrtc::AudioSinkInterface*), (override));
+  MOCK_METHOD(void, SetGain, (float), (override));
+  MOCK_METHOD(bool, SetBaseMinimumPlayoutDelayMs, (int), (override));
+  MOCK_METHOD(int, GetBaseMinimumPlayoutDelayMs, (), (const override));
+  MOCK_METHOD(std::vector<webrtc::RtpSource>, GetSources, (), (const override));
+
+  // TODO (b/397376626): Create a MockAudioMixerSource, and instead
+  // have a member variable here.
+  AudioMixer::Source* source() override { return this; }
+
+  MOCK_METHOD(AudioFrameInfo,
+              GetAudioFrameWithInfo,
+              (int, AudioFrame*),
+              (override));
+  MOCK_METHOD(int, Ssrc, (), (const override));
+  MOCK_METHOD(int, PreferredSampleRate, (), (const override));
+};
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // CALL_TEST_MOCK_AUDIO_RECEIVE_STREAM_H_
diff --git a/media/BUILD.gn b/media/BUILD.gn
index e88ef6d..75fd8bd 100644
--- a/media/BUILD.gn
+++ b/media/BUILD.gn
@@ -831,6 +831,7 @@
       "../api:scoped_refptr",
       "../api/adaptation:resource_adaptation_api",
       "../api/audio:audio_frame_api",
+      "../api/audio:audio_mixer_api",
       "../api/audio:audio_processing",
       "../api/audio_codecs:audio_codecs_api",
       "../api/crypto:frame_decryptor_interface",
diff --git a/media/engine/fake_webrtc_call.h b/media/engine/fake_webrtc_call.h
index 28dbe29..4e7f68f 100644
--- a/media/engine/fake_webrtc_call.h
+++ b/media/engine/fake_webrtc_call.h
@@ -32,6 +32,7 @@
 #include "absl/strings/string_view.h"
 #include "api/adaptation/resource.h"
 #include "api/audio/audio_frame.h"
+#include "api/audio/audio_mixer.h"
 #include "api/audio_codecs/audio_format.h"
 #include "api/crypto/frame_decryptor_interface.h"
 #include "api/environment/environment.h"
@@ -171,6 +172,10 @@
   std::vector<webrtc::RtpSource> GetSources() const override {
     return std::vector<webrtc::RtpSource>();
   }
+  webrtc::AudioMixer::Source* source() override {
+    // TODO(b/397376626): Add a Fake AudioMixer::Source
+    return nullptr;
+  }
 
  private:
   int id_ = -1;