Modify MediaEngine's GetRtpHeaderExtensions() call to use field trials from the caller rather than from the media engine.

This is required in order to have field trials select different header extensions for different PeerConnections from the same PeerConnectionFactory.

This will allow us to have immutable field trials, yay!
see (not yet complete) https://webrtc-review.googlesource.com/c/src/+/409040

Bug: b/444370738
Change-Id: I9ddedbeaf7b5d19fddf96d62d73020e4910305f3
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/409540
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Jonas Oreland <jonaso@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#45650}
diff --git a/media/base/fake_media_engine.cc b/media/base/fake_media_engine.cc
index f4edea9..8acc173 100644
--- a/media/base/fake_media_engine.cc
+++ b/media/base/fake_media_engine.cc
@@ -624,7 +624,8 @@
 void FakeVoiceEngine::StopAecDump() {}
 
 std::vector<RtpHeaderExtensionCapability>
-FakeVoiceEngine::GetRtpHeaderExtensions() const {
+FakeVoiceEngine::GetRtpHeaderExtensions(
+    const FieldTrialsView* field_trials) const {
   return header_extensions_;
 }
 
@@ -699,7 +700,8 @@
   return true;
 }
 std::vector<RtpHeaderExtensionCapability>
-FakeVideoEngine::GetRtpHeaderExtensions() const {
+FakeVideoEngine::GetRtpHeaderExtensions(
+    const FieldTrialsView* field_trials) const {
   return header_extensions_;
 }
 void FakeVideoEngine::SetRtpHeaderExtensions(
diff --git a/media/base/fake_media_engine.h b/media/base/fake_media_engine.h
index f06e485..b4f40da 100644
--- a/media/base/fake_media_engine.h
+++ b/media/base/fake_media_engine.h
@@ -822,8 +822,8 @@
   bool StartAecDump(FileWrapper file, int64_t max_size_bytes) override;
   void StopAecDump() override;
   std::optional<AudioDeviceModule::Stats> GetAudioDeviceStats() override;
-  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions()
-      const override;
+  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions(
+      const FieldTrialsView* field_trials) const override;
   void SetRtpHeaderExtensions(
       std::vector<RtpHeaderExtensionCapability> header_extensions);
 
@@ -924,8 +924,8 @@
   void SetSendCodecs(const std::vector<Codec>& codecs);
   void SetRecvCodecs(const std::vector<Codec>& codecs);
   bool SetCapture(bool capture);
-  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions()
-      const override;
+  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions(
+      const FieldTrialsView* field_trials) const override;
   void SetRtpHeaderExtensions(
       std::vector<RtpHeaderExtensionCapability> header_extensions);
 
diff --git a/media/base/media_engine.cc b/media/base/media_engine.cc
index 752573c..8c64710 100644
--- a/media/base/media_engine.cc
+++ b/media/base/media_engine.cc
@@ -77,9 +77,11 @@
 }
 
 std::vector<RtpExtension> GetDefaultEnabledRtpHeaderExtensions(
-    const RtpHeaderExtensionQueryInterface& query_interface) {
+    const RtpHeaderExtensionQueryInterface& query_interface,
+    const webrtc::FieldTrialsView* field_trials) {
   std::vector<RtpExtension> extensions;
-  for (const auto& entry : query_interface.GetRtpHeaderExtensions()) {
+  for (const auto& entry :
+       query_interface.GetRtpHeaderExtensions(field_trials)) {
     if (entry.direction != RtpTransceiverDirection::kStopped)
       extensions.emplace_back(entry.uri, *entry.preferred_id);
   }
diff --git a/media/base/media_engine.h b/media/base/media_engine.h
index f18a5a0..651d944 100644
--- a/media/base/media_engine.h
+++ b/media/base/media_engine.h
@@ -77,8 +77,8 @@
 
   // Returns a vector of RtpHeaderExtensionCapability, whose direction is
   // kStopped if the extension is stopped (not used) by default.
-  virtual std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions()
-      const = 0;
+  virtual std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions(
+      const webrtc::FieldTrialsView* field_trials) const = 0;
 };
 
 class VoiceEngineInterface : public RtpHeaderExtensionQueryInterface {
@@ -238,7 +238,8 @@
 // offered by default, i.e. the list of extensions returned from
 // GetRtpHeaderExtensions() that are not kStopped.
 std::vector<RtpExtension> GetDefaultEnabledRtpHeaderExtensions(
-    const RtpHeaderExtensionQueryInterface& query_interface);
+    const RtpHeaderExtensionQueryInterface& query_interface,
+    const webrtc::FieldTrialsView* field_trials);
 
 }  //  namespace webrtc
 
diff --git a/media/base/media_engine_unittest.cc b/media/base/media_engine_unittest.cc
index 54cf713..ae2f7a8 100644
--- a/media/base/media_engine_unittest.cc
+++ b/media/base/media_engine_unittest.cc
@@ -40,7 +40,7 @@
  public:
   MOCK_METHOD(std::vector<RtpHeaderExtensionCapability>,
               GetRtpHeaderExtensions,
-              (),
+              (const FieldTrialsView*),
               (const, override));
 };
 
@@ -60,7 +60,7 @@
        RtpHeaderExtensionCapability("uri5", 5,
                                     RtpTransceiverDirection::kRecvOnly)});
   EXPECT_CALL(mock, GetRtpHeaderExtensions).WillOnce(Return(extensions));
-  EXPECT_THAT(GetDefaultEnabledRtpHeaderExtensions(mock),
+  EXPECT_THAT(GetDefaultEnabledRtpHeaderExtensions(mock, nullptr),
               ElementsAre(Field(&RtpExtension::uri, StrEq("uri1")),
                           Field(&RtpExtension::uri, StrEq("uri2")),
                           Field(&RtpExtension::uri, StrEq("uri4")),
@@ -74,7 +74,7 @@
  public:
   MOCK_METHOD(std::vector<RtpHeaderExtensionCapability>,
               GetRtpHeaderExtensions,
-              (),
+              (const FieldTrialsView*),
               (const, override));
   MOCK_METHOD(void, Init, (), (override));
   MOCK_METHOD(scoped_refptr<AudioState>, GetAudioState, (), (const, override));
diff --git a/media/engine/webrtc_video_engine.cc b/media/engine/webrtc_video_engine.cc
index bc14a7b..3d7c4bf 100644
--- a/media/engine/webrtc_video_engine.cc
+++ b/media/engine/webrtc_video_engine.cc
@@ -858,7 +858,13 @@
 }
 
 std::vector<RtpHeaderExtensionCapability>
-WebRtcVideoEngine::GetRtpHeaderExtensions() const {
+WebRtcVideoEngine::GetRtpHeaderExtensions(
+    const webrtc::FieldTrialsView* field_trials) const {
+  // Use field trials from PeerConnection `field_trials` or from
+  // PeerConnectionFactory `trials_`.
+  const webrtc::FieldTrialsView& trials =
+      (field_trials != nullptr ? *field_trials : trials_);
+
   std::vector<RtpHeaderExtensionCapability> result;
   // id is *not* incremented for non-default extensions, UsedIds needs to
   // resolve conflicts.
@@ -880,21 +886,20 @@
     result.emplace_back(uri, id, RtpTransceiverDirection::kStopped);
   }
   result.emplace_back(RtpExtension::kGenericFrameDescriptorUri00, id,
-                      trials_.IsEnabled("WebRTC-GenericDescriptorAdvertised")
+                      trials.IsEnabled("WebRTC-GenericDescriptorAdvertised")
                           ? RtpTransceiverDirection::kSendRecv
                           : RtpTransceiverDirection::kStopped);
   result.emplace_back(RtpExtension::kDependencyDescriptorUri, id,
-                      trials_.IsEnabled("WebRTC-DependencyDescriptorAdvertised")
+                      trials.IsEnabled("WebRTC-DependencyDescriptorAdvertised")
                           ? RtpTransceiverDirection::kSendRecv
                           : RtpTransceiverDirection::kStopped);
-  result.emplace_back(
-      RtpExtension::kVideoLayersAllocationUri, id,
-      trials_.IsEnabled("WebRTC-VideoLayersAllocationAdvertised")
-          ? RtpTransceiverDirection::kSendRecv
-          : RtpTransceiverDirection::kStopped);
+  result.emplace_back(RtpExtension::kVideoLayersAllocationUri, id,
+                      trials.IsEnabled("WebRTC-VideoLayersAllocationAdvertised")
+                          ? RtpTransceiverDirection::kSendRecv
+                          : RtpTransceiverDirection::kStopped);
 
   // VideoFrameTrackingId is a test-only extension.
-  if (trials_.IsEnabled("WebRTC-VideoFrameTrackingIdAdvertised")) {
+  if (trials.IsEnabled("WebRTC-VideoFrameTrackingIdAdvertised")) {
     result.emplace_back(RtpExtension::kVideoFrameTrackingIdUri, id,
                         RtpTransceiverDirection::kSendRecv);
   }
diff --git a/media/engine/webrtc_video_engine.h b/media/engine/webrtc_video_engine.h
index a3276bb..456c6df 100644
--- a/media/engine/webrtc_video_engine.h
+++ b/media/engine/webrtc_video_engine.h
@@ -125,15 +125,18 @@
   }
   std::vector<Codec> LegacySendCodecs(bool include_rtx) const override;
   std::vector<Codec> LegacyRecvCodecs(bool include_rtx) const override;
-  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions()
-      const override;
+
+  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions(
+      /* optional field trials from PeerConnection that override those from
+         PeerConnectionFactory */
+      const webrtc::FieldTrialsView* field_trials) const override;
 
  private:
   const std::unique_ptr<VideoDecoderFactory> decoder_factory_;
   const std::unique_ptr<VideoEncoderFactory> encoder_factory_;
   const std::unique_ptr<VideoBitrateAllocatorFactory>
       bitrate_allocator_factory_;
-  const FieldTrialsView& trials_;
+  const FieldTrialsView& trials_;  // from PeerConnectionFactory
 };
 
 struct VideoCodecSettings {
diff --git a/media/engine/webrtc_video_engine_unittest.cc b/media/engine/webrtc_video_engine_unittest.cc
index be0245a..52d8f22 100644
--- a/media/engine/webrtc_video_engine_unittest.cc
+++ b/media/engine/webrtc_video_engine_unittest.cc
@@ -991,7 +991,8 @@
 void WebRtcVideoEngineTest::ExpectRtpCapabilitySupport(const char* uri,
                                                        bool supported) const {
   const std::vector<RtpExtension> header_extensions =
-      GetDefaultEnabledRtpHeaderExtensions(engine_);
+      GetDefaultEnabledRtpHeaderExtensions(engine_,
+                                           /* field_trials= */ nullptr);
   if (supported) {
     EXPECT_THAT(header_extensions, Contains(Field(&RtpExtension::uri, uri)));
   } else {
diff --git a/media/engine/webrtc_voice_engine.cc b/media/engine/webrtc_voice_engine.cc
index a5b9c52..2cecaed 100644
--- a/media/engine/webrtc_voice_engine.cc
+++ b/media/engine/webrtc_voice_engine.cc
@@ -765,7 +765,8 @@
 }
 
 std::vector<RtpHeaderExtensionCapability>
-WebRtcVoiceEngine::GetRtpHeaderExtensions() const {
+WebRtcVoiceEngine::GetRtpHeaderExtensions(
+    const webrtc::FieldTrialsView* field_trials) const {
   RTC_DCHECK(signal_thread_checker_.IsCurrent());
   std::vector<RtpHeaderExtensionCapability> result;
   // id is *not* incremented for non-default extensions, UsedIds needs to
diff --git a/media/engine/webrtc_voice_engine.h b/media/engine/webrtc_voice_engine.h
index 25c1c14..b73e443 100644
--- a/media/engine/webrtc_voice_engine.h
+++ b/media/engine/webrtc_voice_engine.h
@@ -120,8 +120,8 @@
   AudioDecoderFactory* decoder_factory() const override {
     return decoder_factory_.get();
   }
-  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions()
-      const override;
+  std::vector<RtpHeaderExtensionCapability> GetRtpHeaderExtensions(
+      const webrtc::FieldTrialsView* field_trials) const override;
 
   // Starts AEC dump using an existing file. A maximum file size in bytes can be
   // specified. When the maximum file size is reached, logging is stopped and
diff --git a/media/engine/webrtc_voice_engine_unittest.cc b/media/engine/webrtc_voice_engine_unittest.cc
index fff9b04..ce5627a 100644
--- a/media/engine/webrtc_voice_engine_unittest.cc
+++ b/media/engine/webrtc_voice_engine_unittest.cc
@@ -2365,7 +2365,8 @@
 TEST_P(WebRtcVoiceEngineTestFake,
        SupportsTransportSequenceNumberHeaderExtension) {
   const std::vector<webrtc::RtpExtension> header_extensions =
-      webrtc::GetDefaultEnabledRtpHeaderExtensions(*engine_);
+      webrtc::GetDefaultEnabledRtpHeaderExtensions(*engine_,
+                                                   /* field_trials= */ nullptr);
   EXPECT_THAT(header_extensions,
               Contains(::testing::Field(
                   "uri", &webrtc::RtpExtension::uri,
@@ -3603,7 +3604,8 @@
 
   // Set up receive extensions.
   const std::vector<webrtc::RtpExtension> header_extensions =
-      webrtc::GetDefaultEnabledRtpHeaderExtensions(*engine_);
+      webrtc::GetDefaultEnabledRtpHeaderExtensions(*engine_,
+                                                   /* field_trials= */ nullptr);
   webrtc::AudioReceiverParameters recv_parameters;
   recv_parameters.extensions = header_extensions;
   receive_channel_->SetReceiverParameters(recv_parameters);
diff --git a/pc/peer_connection_factory.cc b/pc/peer_connection_factory.cc
index 6e706b0..a0b8848 100644
--- a/pc/peer_connection_factory.cc
+++ b/pc/peer_connection_factory.cc
@@ -141,15 +141,15 @@
     case MediaType::AUDIO: {
       Codecs cricket_codecs;
       cricket_codecs = codec_vendor_.audio_send_codecs().codecs();
-      auto extensions =
-          GetDefaultEnabledRtpHeaderExtensions(media_engine()->voice());
+      auto extensions = GetDefaultEnabledRtpHeaderExtensions(
+          media_engine()->voice(), /* field_trials= */ nullptr);
       return ToRtpCapabilities(cricket_codecs, extensions);
     }
     case MediaType::VIDEO: {
       Codecs cricket_codecs;
       cricket_codecs = codec_vendor_.video_send_codecs().codecs();
-      auto extensions =
-          GetDefaultEnabledRtpHeaderExtensions(media_engine()->video());
+      auto extensions = GetDefaultEnabledRtpHeaderExtensions(
+          media_engine()->video(), /* field_trials= */ nullptr);
       return ToRtpCapabilities(cricket_codecs, extensions);
     }
     default:
@@ -166,14 +166,14 @@
     case MediaType::AUDIO: {
       Codecs cricket_codecs;
       cricket_codecs = codec_vendor_.audio_recv_codecs().codecs();
-      auto extensions =
-          GetDefaultEnabledRtpHeaderExtensions(media_engine()->voice());
+      auto extensions = GetDefaultEnabledRtpHeaderExtensions(
+          media_engine()->voice(), /* field_trials= */ nullptr);
       return ToRtpCapabilities(cricket_codecs, extensions);
     }
     case MediaType::VIDEO: {
       Codecs cricket_codecs = codec_vendor_.video_recv_codecs().codecs();
-      auto extensions =
-          GetDefaultEnabledRtpHeaderExtensions(media_engine()->video());
+      auto extensions = GetDefaultEnabledRtpHeaderExtensions(
+          media_engine()->video(), /* field_trials= */ nullptr);
       return ToRtpCapabilities(cricket_codecs, extensions);
     }
     default:
diff --git a/pc/peer_connection_integrationtest.cc b/pc/peer_connection_integrationtest.cc
index 7a8d086..8d4b54c 100644
--- a/pc/peer_connection_integrationtest.cc
+++ b/pc/peer_connection_integrationtest.cc
@@ -100,6 +100,7 @@
 namespace {
 
 using ::testing::AtLeast;
+using ::testing::Contains;
 using ::testing::Eq;
 using ::testing::Field;
 using ::testing::Gt;
@@ -108,6 +109,7 @@
 using ::testing::IsTrue;
 using ::testing::MockFunction;
 using ::testing::NiceMock;
+using ::testing::Not;
 using ::testing::NotNull;
 using ::testing::Return;
 using ::testing::WithParamInterface;
@@ -5027,6 +5029,49 @@
 
 #endif  // WEBRTC_HAVE_SCTP
 
+TEST_F(PeerConnectionIntegrationTestUnifiedPlan,
+       PerPeerConnectionHeaderExtensions) {
+  SetFieldTrials("caller", "WebRTC-VideoFrameTrackingIdAdvertised/Enabled/");
+  SetFieldTrials("callee", "WebRTC-VideoFrameTrackingIdAdvertised/Disabled/");
+  PeerConnectionInterface::RTCConfiguration config;
+  PeerConnectionFactoryInterface::Options options;
+  options.ssl_max_version = SSL_PROTOCOL_DTLS_13;
+
+  const bool create_media_engine = true;
+  SetCallerPcWrapperAndReturnCurrent(CreatePeerConnectionWrapper(
+      "caller", &options, &config, PeerConnectionDependencies(nullptr),
+      /* event_log_factory= */ nullptr,
+      /* reset_encoder_factory= */ false,
+      /* reset_decoder_factory= */ false, create_media_engine));
+  SetCalleePcWrapperAndReturnCurrent(CreatePeerConnectionWrapper(
+      "callee", &options, &config, PeerConnectionDependencies(nullptr),
+      /* event_log_factory= */ nullptr,
+      /* reset_encoder_factory= */ false,
+      /* reset_decoder_factory= */ false, create_media_engine));
+
+  const std::string uri =
+      "http://www.webrtc.org/experiments/rtp-hdrext/video-frame-tracking-id";
+  {
+    caller()->pc()->AddTransceiver(MediaType::VIDEO);
+    auto session_description = caller()->CreateOfferAndWait();
+    EXPECT_THAT(session_description->description()
+                    ->contents()[0]
+                    .media_description()
+                    ->rtp_header_extensions(),
+                Contains(Field(&RtpExtension::uri, uri)));
+  }
+
+  {
+    callee()->pc()->AddTransceiver(MediaType::VIDEO);
+    auto session_description = callee()->CreateOfferAndWait();
+    EXPECT_THAT(session_description->description()
+                    ->contents()[0]
+                    .media_description()
+                    ->rtp_header_extensions(),
+                Not(Contains(Field(&RtpExtension::uri, uri))));
+  }
+}
+
 }  // namespace
 
 }  // namespace webrtc
diff --git a/pc/rtp_transceiver_unittest.cc b/pc/rtp_transceiver_unittest.cc
index 2165c32..ef209d3 100644
--- a/pc/rtp_transceiver_unittest.cc
+++ b/pc/rtp_transceiver_unittest.cc
@@ -195,7 +195,8 @@
         RtpReceiverProxyWithInternal<RtpReceiverInternal>::Create(
             Thread::Current(), Thread::Current(), std::move(receiver)),
         context(), codec_lookup_helper(),
-        media_engine()->voice().GetRtpHeaderExtensions(),
+        media_engine()->voice().GetRtpHeaderExtensions(
+            &context()->env().field_trials()),
         /* on_negotiation_needed= */ [] {});
   }
 
diff --git a/pc/rtp_transmission_manager.cc b/pc/rtp_transmission_manager.cc
index 4b9dc59..c8a6a08 100644
--- a/pc/rtp_transmission_manager.cc
+++ b/pc/rtp_transmission_manager.cc
@@ -335,9 +335,11 @@
   }
   if (header_extensions.empty()) {
     if (sender->media_type() == MediaType::AUDIO) {
-      header_extensions = media_engine()->voice().GetRtpHeaderExtensions();
+      header_extensions =
+          media_engine()->voice().GetRtpHeaderExtensions(&env_.field_trials());
     } else {
-      header_extensions = media_engine()->video().GetRtpHeaderExtensions();
+      header_extensions =
+          media_engine()->video().GetRtpHeaderExtensions(&env_.field_trials());
     }
   }
 
diff --git a/pc/sdp_offer_answer.cc b/pc/sdp_offer_answer.cc
index 506ec61..1ccccc2 100644
--- a/pc/sdp_offer_answer.cc
+++ b/pc/sdp_offer_answer.cc
@@ -4400,7 +4400,8 @@
           MediaType::AUDIO, GetDefaultMidForPlanB(MediaType::AUDIO),
           RtpTransceiverDirectionFromSendRecv(send_audio, recv_audio), false);
       options.header_extensions =
-          media_engine()->voice().GetRtpHeaderExtensions();
+          media_engine()->voice().GetRtpHeaderExtensions(
+              &context_->env().field_trials());
       session_options->media_description_options.push_back(options);
       audio_index = session_options->media_description_options.size() - 1;
     }
@@ -4409,7 +4410,8 @@
           MediaType::VIDEO, GetDefaultMidForPlanB(MediaType::VIDEO),
           RtpTransceiverDirectionFromSendRecv(send_video, recv_video), false);
       options.header_extensions =
-          media_engine()->video().GetRtpHeaderExtensions();
+          media_engine()->video().GetRtpHeaderExtensions(
+              &context_->env().field_trials());
       session_options->media_description_options.push_back(options);
       video_index = session_options->media_description_options.size() - 1;
     }
@@ -5543,7 +5545,8 @@
         *audio_index = session_options->media_description_options.size() - 1;
       }
       session_options->media_description_options.back().header_extensions =
-          media_engine()->voice().GetRtpHeaderExtensions();
+          media_engine()->voice().GetRtpHeaderExtensions(
+              &context_->env().field_trials());
     } else if (IsVideoContent(&content)) {
       // If we already have an video m= section, reject this extra one.
       if (*video_index) {
@@ -5559,7 +5562,8 @@
         *video_index = session_options->media_description_options.size() - 1;
       }
       session_options->media_description_options.back().header_extensions =
-          media_engine()->video().GetRtpHeaderExtensions();
+          media_engine()->video().GetRtpHeaderExtensions(
+              &context_->env().field_trials());
     } else if (IsUnsupportedContent(&content)) {
       session_options->media_description_options.push_back(
           MediaDescriptionOptions(MediaType::UNSUPPORTED, content.mid(),