When SDES is used, pass pre-shared key to media transport.

This allows to use secure, end to end communication if SDES cryptos are
passed. MediaTransport can use a derived key to secure its own
communication.

Bug: webrtc:9719
Change-Id: If1a20b136b3b4af0cb24f10b52fc5ce1eb31daa2
Reviewed-on: https://webrtc-review.googlesource.com/c/108504
Commit-Queue: Peter Slatala <psla@webrtc.org>
Reviewed-by: Seth Hampson <shampson@webrtc.org>
Reviewed-by: Qingsi Wang <qingsi@webrtc.org>
Reviewed-by: Benjamin Wright <benwright@webrtc.org>
Reviewed-by: Bjorn Mellem <mellem@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#25452}
diff --git a/api/media_transport_interface.h b/api/media_transport_interface.h
index ca26851..1323635 100644
--- a/api/media_transport_interface.h
+++ b/api/media_transport_interface.h
@@ -48,6 +48,8 @@
   bool is_caller;
 
   // Must be set if a pre-shared key is used for the call.
+  // TODO(bugs.webrtc.org/9944): This should become zero buffer in the distant
+  // future.
   absl::optional<std::string> pre_shared_key;
 };
 
diff --git a/api/test/fake_media_transport.h b/api/test/fake_media_transport.h
index 956316d..5609126 100644
--- a/api/test/fake_media_transport.h
+++ b/api/test/fake_media_transport.h
@@ -12,6 +12,7 @@
 #define API_TEST_FAKE_MEDIA_TRANSPORT_H_
 
 #include <memory>
+#include <string>
 #include <utility>
 
 #include "absl/memory/memory.h"
@@ -25,7 +26,8 @@
 // could unit test audio / video integration.
 class FakeMediaTransport : public MediaTransportInterface {
  public:
-  explicit FakeMediaTransport(bool is_caller) : is_caller_(is_caller) {}
+  explicit FakeMediaTransport(const MediaTransportSettings& settings)
+      : settings_(settings) {}
   ~FakeMediaTransport() = default;
 
   RTCError SendAudioFrame(uint64_t channel_id,
@@ -46,14 +48,17 @@
   void SetReceiveAudioSink(MediaTransportAudioSinkInterface* sink) override {}
   void SetReceiveVideoSink(MediaTransportVideoSinkInterface* sink) override {}
 
-  // Returns true if fake media trasport was created as a caller.
-  bool is_caller() const { return is_caller_; }
+  // Returns true if fake media transport was created as a caller.
+  bool is_caller() const { return settings_.is_caller; }
+  absl::optional<std::string> pre_shared_key() const {
+    return settings_.pre_shared_key;
+  }
 
   void SetTargetTransferRateObserver(
       webrtc::TargetTransferRateObserver* observer) override {}
 
  private:
-  const bool is_caller_;
+  const MediaTransportSettings settings_;
 };
 
 // Fake media transport factory creates fake media transport.
@@ -66,9 +71,17 @@
       rtc::PacketTransportInternal* packet_transport,
       rtc::Thread* network_thread,
       bool is_caller) override {
-    std::unique_ptr<MediaTransportInterface> media_transport =
-        absl::make_unique<FakeMediaTransport>(is_caller);
+    MediaTransportSettings settings;
+    settings.is_caller = is_caller;
+    return CreateMediaTransport(packet_transport, network_thread, settings);
+  }
 
+  RTCErrorOr<std::unique_ptr<MediaTransportInterface>> CreateMediaTransport(
+      rtc::PacketTransportInternal* packet_transport,
+      rtc::Thread* network_thread,
+      const MediaTransportSettings& settings) override {
+    std::unique_ptr<MediaTransportInterface> media_transport =
+        absl::make_unique<FakeMediaTransport>(settings);
     return std::move(media_transport);
   }
 };
diff --git a/pc/jseptransportcontroller.cc b/pc/jseptransportcontroller.cc
index f00c22d..19b5025 100644
--- a/pc/jseptransportcontroller.cc
+++ b/pc/jseptransportcontroller.cc
@@ -15,8 +15,10 @@
 #include <utility>
 
 #include "p2p/base/port.h"
+#include "pc/srtpfilter.h"
 #include "rtc_base/bind.h"
 #include "rtc_base/checks.h"
+#include "rtc_base/key_derivation.h"
 #include "rtc_base/thread.h"
 
 using webrtc::SdpType;
@@ -940,16 +942,82 @@
         CreateDtlsTransport(content_info.name, /*rtcp =*/true);
   }
 
+  absl::optional<cricket::CryptoParams> selected_crypto_for_media_transport;
+  if (content_info.media_description() &&
+      !content_info.media_description()->cryptos().empty()) {
+    // Order of cryptos is deterministic (rfc4568, 5.1.1), so we just select the
+    // first one (in fact the first one should be the most preferred one.) We
+    // ignore the HMAC size, as media transport crypto settings currently don't
+    // expose HMAC size, nor crypto protocol for that matter.
+    selected_crypto_for_media_transport =
+        content_info.media_description()->cryptos()[0];
+  }
+
   if (config_.media_transport_factory != nullptr) {
-    auto media_transport_result =
-        config_.media_transport_factory->CreateMediaTransport(
-            rtp_dtls_transport->ice_transport(), network_thread_,
-            /*is_caller=*/local);
+    if (!selected_crypto_for_media_transport.has_value()) {
+      RTC_LOG(LS_WARNING) << "a=cryto line was not found in the offer. Most "
+                             "likely you did not enable SDES. "
+                             "Make sure to pass config.enable_dtls_srtp=false "
+                             "to RTCConfiguration. "
+                             "Cannot continue with media transport. Falling "
+                             "back to RTP. is_local="
+                          << local;
 
-    // TODO(sukhanov): Proper error handling.
-    RTC_CHECK(media_transport_result.ok());
+      // Remove media_transport_factory from config, because we don't want to
+      // use it on the subsequent call (for the other side of the offer).
+      config_.media_transport_factory = nullptr;
+    } else {
+      // Note that we ignore here lifetime and length.
+      // In fact we take those bits (inline, lifetime and length) and keep it as
+      // part of key derivation.
+      //
+      // Technically, we are also not following rfc4568, which requires us to
+      // send and answer with the key that we chose. In practice, for media
+      // transport, the current approach should be sufficient (we take the key
+      // that sender offered, and caller assumes we will use it. We are not
+      // signaling back that we indeed used it.)
+      std::unique_ptr<rtc::KeyDerivation> key_derivation =
+          rtc::KeyDerivation::Create(rtc::KeyDerivationAlgorithm::HKDF_SHA256);
+      const std::string label = "MediaTransportLabel";
+      constexpr int kDerivedKeyByteSize = 32;
 
-    media_transport = std::move(media_transport_result.value());
+      int key_len, salt_len;
+      if (!rtc::GetSrtpKeyAndSaltLengths(
+              rtc::SrtpCryptoSuiteFromName(
+                  selected_crypto_for_media_transport.value().cipher_suite),
+              &key_len, &salt_len)) {
+        RTC_CHECK(false) << "Cannot set up secure media transport";
+      }
+      rtc::ZeroOnFreeBuffer<uint8_t> raw_key(key_len + salt_len);
+
+      cricket::SrtpFilter::ParseKeyParams(
+          selected_crypto_for_media_transport.value().key_params,
+          raw_key.data(), raw_key.size());
+      absl::optional<rtc::ZeroOnFreeBuffer<uint8_t>> key =
+          key_derivation->DeriveKey(
+              raw_key,
+              /*salt=*/nullptr,
+              rtc::ArrayView<const uint8_t>(
+                  reinterpret_cast<const uint8_t*>(label.data()), label.size()),
+              kDerivedKeyByteSize);
+
+      // We want to crash the app if we don't have a key, and not silently fall
+      // back to the unsecure communication.
+      RTC_CHECK(key.has_value());
+      MediaTransportSettings settings;
+      settings.is_caller = local;
+      settings.pre_shared_key =
+          std::string(reinterpret_cast<const char*>(key.value().data()),
+                      key.value().size());
+      auto media_transport_result =
+          config_.media_transport_factory->CreateMediaTransport(
+              rtp_dtls_transport->ice_transport(), network_thread_, settings);
+
+      // TODO(sukhanov): Proper error handling.
+      RTC_CHECK(media_transport_result.ok());
+
+      media_transport = std::move(media_transport_result.value());
+    }
   }
 
   // TODO(sukhanov): Do not create RTP/RTCP transports if media transport is
diff --git a/pc/jseptransportcontroller_unittest.cc b/pc/jseptransportcontroller_unittest.cc
index ba9b72d..08b1f9b 100644
--- a/pc/jseptransportcontroller_unittest.cc
+++ b/pc/jseptransportcontroller_unittest.cc
@@ -41,6 +41,20 @@
 
 namespace webrtc {
 
+namespace {
+
+// Media transport factory requires crypto settings to be present in order to
+// create media transport.
+void AddCryptoSettings(cricket::SessionDescription* description) {
+  for (auto& content : description->contents()) {
+    content.media_description()->AddCrypto(cricket::CryptoParams(
+        /*t=*/0, std::string(rtc::CS_AES_CM_128_HMAC_SHA1_80),
+        "inline:YUJDZGVmZ2hpSktMbW9QUXJzVHVWd3l6MTIzNDU2", ""));
+  }
+}
+
+}  // namespace
+
 class FakeTransportFactory : public cricket::TransportFactoryInterface {
  public:
   std::unique_ptr<cricket::IceTransportInternal> CreateIceTransport(
@@ -380,6 +394,8 @@
   config.media_transport_factory = &fake_media_transport_factory;
   CreateJsepTransportController(config);
   auto description = CreateSessionDescriptionWithoutBundle();
+  AddCryptoSettings(description.get());
+
   EXPECT_TRUE(transport_controller_
                   ->SetLocalDescription(SdpType::kOffer, description.get())
                   .ok());
@@ -391,6 +407,7 @@
 
   // After SetLocalDescription, media transport should be created as caller.
   EXPECT_TRUE(media_transport->is_caller());
+  EXPECT_TRUE(media_transport->pre_shared_key().has_value());
 
   // Return nullptr for non-existing mids.
   EXPECT_EQ(nullptr, transport_controller_->GetMediaTransport(kVideoMid2));
@@ -404,6 +421,7 @@
   config.media_transport_factory = &fake_media_transport_factory;
   CreateJsepTransportController(config);
   auto description = CreateSessionDescriptionWithoutBundle();
+  AddCryptoSettings(description.get());
   EXPECT_TRUE(transport_controller_
                   ->SetRemoteDescription(SdpType::kOffer, description.get())
                   .ok());
@@ -415,11 +433,70 @@
 
   // After SetRemoteDescription, media transport should be created as callee.
   EXPECT_FALSE(media_transport->is_caller());
+  EXPECT_TRUE(media_transport->pre_shared_key().has_value());
 
   // Return nullptr for non-existing mids.
   EXPECT_EQ(nullptr, transport_controller_->GetMediaTransport(kVideoMid2));
 }
 
+TEST_F(JsepTransportControllerTest, GetMediaTransportIsNotSetIfNoSdes) {
+  FakeMediaTransportFactory fake_media_transport_factory;
+  JsepTransportController::Config config;
+
+  config.rtcp_mux_policy = PeerConnectionInterface::kRtcpMuxPolicyNegotiate;
+  config.media_transport_factory = &fake_media_transport_factory;
+  CreateJsepTransportController(config);
+  auto description = CreateSessionDescriptionWithoutBundle();
+  EXPECT_TRUE(transport_controller_
+                  ->SetRemoteDescription(SdpType::kOffer, description.get())
+                  .ok());
+
+  EXPECT_EQ(nullptr, transport_controller_->GetMediaTransport(kAudioMid1));
+
+  // Even if we set local description with crypto now (after the remote offer
+  // was set), media transport won't be provided.
+  auto description2 = CreateSessionDescriptionWithoutBundle();
+  AddCryptoSettings(description2.get());
+  EXPECT_TRUE(transport_controller_
+                  ->SetLocalDescription(SdpType::kAnswer, description2.get())
+                  .ok());
+
+  EXPECT_EQ(nullptr, transport_controller_->GetMediaTransport(kAudioMid1));
+}
+
+TEST_F(JsepTransportControllerTest,
+       AfterSettingAnswerTheSameMediaTransportIsReturned) {
+  FakeMediaTransportFactory fake_media_transport_factory;
+  JsepTransportController::Config config;
+
+  config.rtcp_mux_policy = PeerConnectionInterface::kRtcpMuxPolicyNegotiate;
+  config.media_transport_factory = &fake_media_transport_factory;
+  CreateJsepTransportController(config);
+  auto description = CreateSessionDescriptionWithoutBundle();
+  AddCryptoSettings(description.get());
+  EXPECT_TRUE(transport_controller_
+                  ->SetRemoteDescription(SdpType::kOffer, description.get())
+                  .ok());
+
+  FakeMediaTransport* media_transport = static_cast<FakeMediaTransport*>(
+      transport_controller_->GetMediaTransport(kAudioMid1));
+  EXPECT_NE(nullptr, media_transport);
+  EXPECT_TRUE(media_transport->pre_shared_key().has_value());
+
+  // Even if we set local description with crypto now (after the remote offer
+  // was set), media transport won't be provided.
+  auto description2 = CreateSessionDescriptionWithoutBundle();
+  AddCryptoSettings(description2.get());
+
+  RTCError result = transport_controller_->SetLocalDescription(
+      SdpType::kAnswer, description2.get());
+  EXPECT_TRUE(result.ok()) << result.message();
+
+  // Media transport did not change.
+  EXPECT_EQ(media_transport,
+            transport_controller_->GetMediaTransport(kAudioMid1));
+}
+
 TEST_F(JsepTransportControllerTest, SetIceConfig) {
   CreateJsepTransportController(JsepTransportController::Config());
   auto description = CreateSessionDescriptionWithoutBundle();
diff --git a/pc/peerconnection_media_unittest.cc b/pc/peerconnection_media_unittest.cc
index 6b50091..bb4592d 100644
--- a/pc/peerconnection_media_unittest.cc
+++ b/pc/peerconnection_media_unittest.cc
@@ -1092,6 +1092,9 @@
   // Setup PeerConnection to use media transport.
   config.use_media_transport = true;
 
+  // Force SDES.
+  config.enable_dtls_srtp = false;
+
   auto caller = CreatePeerConnectionWithAudioVideo(config);
   auto callee = CreatePeerConnectionWithAudioVideo(config);
 
diff --git a/pc/srtpfilter.h b/pc/srtpfilter.h
index a4dd54f..4ab0dd7 100644
--- a/pc/srtpfilter.h
+++ b/pc/srtpfilter.h
@@ -78,6 +78,10 @@
 
   bool ResetParams();
 
+  static bool ParseKeyParams(const std::string& params,
+                             uint8_t* key,
+                             size_t len);
+
   absl::optional<int> send_cipher_suite() { return send_cipher_suite_; }
   absl::optional<int> recv_cipher_suite() { return recv_cipher_suite_; }
 
@@ -104,10 +108,6 @@
 
   bool ApplyRecvParams(const CryptoParams& recv_params);
 
-  static bool ParseKeyParams(const std::string& params,
-                             uint8_t* key,
-                             size_t len);
-
   enum State {
     ST_INIT,                    // SRTP filter unused.
     ST_SENTOFFER,               // Offer with SRTP parameters sent.