Replace interfaces for sending RTCP with std::functions in ReceiveSideCongestionController

Logic for throttling how often REMB messages are sent is added to ReceiveSideCongestionController as well as a new method SetMaxDesiredReceiveBitrate. These are based on the logic in PacketRouter. The logic for throttling REMB and setting the max REMB will be removed from PacketRouter in a follow up cl.
The purpose is to eventually decouple PacketRouter from sending RTCP messages when RtcpTransceiver is used.

Bug: webrtc:12693
Change-Id: I9fb5cbcd14bb17d977e76d329a906fc0a9abc276
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/215685
Reviewed-by: Philip Eliasson <philipel@webrtc.org>
Reviewed-by: Christoffer Rodbro <crodbro@webrtc.org>
Reviewed-by: Danil Chapovalov <danilchap@webrtc.org>
Commit-Queue: Per Kjellander <perkj@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33801}
diff --git a/modules/congestion_controller/BUILD.gn b/modules/congestion_controller/BUILD.gn
index 3e1e8c0..c0b064d 100644
--- a/modules/congestion_controller/BUILD.gn
+++ b/modules/congestion_controller/BUILD.gn
@@ -22,12 +22,17 @@
   sources = [
     "include/receive_side_congestion_controller.h",
     "receive_side_congestion_controller.cc",
+    "remb_throttler.cc",
+    "remb_throttler.h",
   ]
 
   deps = [
     "..:module_api",
     "../../api/transport:field_trial_based_config",
     "../../api/transport:network_control",
+    "../../api/units:data_rate",
+    "../../api/units:time_delta",
+    "../../api/units:timestamp",
     "../../rtc_base/synchronization:mutex",
     "../pacing",
     "../remote_bitrate_estimator",
@@ -43,11 +48,17 @@
   rtc_library("congestion_controller_unittests") {
     testonly = true
 
-    sources = [ "receive_side_congestion_controller_unittest.cc" ]
+    sources = [
+      "receive_side_congestion_controller_unittest.cc",
+      "remb_throttler_unittest.cc",
+    ]
     deps = [
       ":congestion_controller",
       "../../api/test/network_emulation",
       "../../api/test/network_emulation:create_cross_traffic",
+      "../../api/units:data_rate",
+      "../../api/units:time_delta",
+      "../../api/units:timestamp",
       "../../system_wrappers",
       "../../test:test_support",
       "../../test/scenario",
diff --git a/modules/congestion_controller/include/receive_side_congestion_controller.h b/modules/congestion_controller/include/receive_side_congestion_controller.h
index 034f2e9..b46cd8d 100644
--- a/modules/congestion_controller/include/receive_side_congestion_controller.h
+++ b/modules/congestion_controller/include/receive_side_congestion_controller.h
@@ -16,7 +16,10 @@
 
 #include "api/transport/field_trial_based_config.h"
 #include "api/transport/network_control.h"
+#include "api/units/data_rate.h"
+#include "modules/congestion_controller/remb_throttler.h"
 #include "modules/include/module.h"
+#include "modules/pacing/packet_router.h"
 #include "modules/remote_bitrate_estimator/remote_estimator_proxy.h"
 #include "rtc_base/synchronization/mutex.h"
 
@@ -32,12 +35,20 @@
 class ReceiveSideCongestionController : public CallStatsObserver,
                                         public Module {
  public:
+  // TODO(bugs.webrtc.org/12693): Deprecate
   ReceiveSideCongestionController(Clock* clock, PacketRouter* packet_router);
+  // TODO(bugs.webrtc.org/12693): Deprecate
   ReceiveSideCongestionController(
       Clock* clock,
       PacketRouter* packet_router,
       NetworkStateEstimator* network_state_estimator);
 
+  ReceiveSideCongestionController(
+      Clock* clock,
+      RemoteEstimatorProxy::TransportFeedbackSender feedback_sender,
+      RembThrottler::RembSender remb_sender,
+      NetworkStateEstimator* network_state_estimator);
+
   ~ReceiveSideCongestionController() override {}
 
   virtual void OnReceivedPacket(int64_t arrival_time_ms,
@@ -56,6 +67,10 @@
   // This is send bitrate, used to control the rate of feedback messages.
   void OnBitrateChanged(int bitrate_bps);
 
+  // Ensures the remote party is notified of the receive bitrate no larger than
+  // |bitrate| using RTCP REMB.
+  void SetMaxDesiredReceiveBitrate(DataRate bitrate);
+
   // Implements Module.
   int64_t TimeUntilNextProcess() override;
   void Process() override;
@@ -103,6 +118,7 @@
   };
 
   const FieldTrialBasedConfig field_trial_config_;
+  RembThrottler remb_throttler_;
   WrappingBitrateEstimator remote_bitrate_estimator_;
   RemoteEstimatorProxy remote_estimator_proxy_;
 };
diff --git a/modules/congestion_controller/receive_side_congestion_controller.cc b/modules/congestion_controller/receive_side_congestion_controller.cc
index 638cb2d..e4e6cc9 100644
--- a/modules/congestion_controller/receive_side_congestion_controller.cc
+++ b/modules/congestion_controller/receive_side_congestion_controller.cc
@@ -10,6 +10,7 @@
 
 #include "modules/congestion_controller/include/receive_side_congestion_controller.h"
 
+#include "api/units/data_rate.h"
 #include "modules/pacing/packet_router.h"
 #include "modules/remote_bitrate_estimator/include/bwe_defines.h"
 #include "modules/remote_bitrate_estimator/remote_bitrate_estimator_abs_send_time.h"
@@ -127,9 +128,26 @@
     Clock* clock,
     PacketRouter* packet_router,
     NetworkStateEstimator* network_state_estimator)
-    : remote_bitrate_estimator_(packet_router, clock),
+    : remb_throttler_([](auto...) {}, clock),
+      remote_bitrate_estimator_(packet_router, clock),
+      remote_estimator_proxy_(
+          clock,
+          [packet_router](
+              std::vector<std::unique_ptr<rtcp::RtcpPacket>> packets) {
+            packet_router->SendCombinedRtcpPacket(std::move(packets));
+          },
+          &field_trial_config_,
+          network_state_estimator) {}
+
+ReceiveSideCongestionController::ReceiveSideCongestionController(
+    Clock* clock,
+    RemoteEstimatorProxy::TransportFeedbackSender feedback_sender,
+    RembThrottler::RembSender remb_sender,
+    NetworkStateEstimator* network_state_estimator)
+    : remb_throttler_(std::move(remb_sender), clock),
+      remote_bitrate_estimator_(&remb_throttler_, clock),
       remote_estimator_proxy_(clock,
-                              packet_router,
+                              std::move(feedback_sender),
                               &field_trial_config_,
                               network_state_estimator) {}
 
@@ -186,4 +204,9 @@
   remote_bitrate_estimator_.Process();
 }
 
+void ReceiveSideCongestionController::SetMaxDesiredReceiveBitrate(
+    DataRate bitrate) {
+  remb_throttler_.SetMaxDesiredReceiveBitrate(bitrate);
+}
+
 }  // namespace webrtc
diff --git a/modules/congestion_controller/receive_side_congestion_controller_unittest.cc b/modules/congestion_controller/receive_side_congestion_controller_unittest.cc
index 5622c84..5e03179 100644
--- a/modules/congestion_controller/receive_side_congestion_controller_unittest.cc
+++ b/modules/congestion_controller/receive_side_congestion_controller_unittest.cc
@@ -20,10 +20,8 @@
 
 using ::testing::_;
 using ::testing::AtLeast;
-using ::testing::NiceMock;
-using ::testing::Return;
-using ::testing::SaveArg;
-using ::testing::StrictMock;
+using ::testing::ElementsAre;
+using ::testing::MockFunction;
 
 namespace webrtc {
 
@@ -37,34 +35,28 @@
   return (((t << 18) + (denom >> 1)) / denom) & 0x00fffffful;
 }
 
-class MockPacketRouter : public PacketRouter {
- public:
-  MOCK_METHOD(void,
-              OnReceiveBitrateChanged,
-              (const std::vector<uint32_t>& ssrcs, uint32_t bitrate),
-              (override));
-};
-
 const uint32_t kInitialBitrateBps = 60000;
 
 }  // namespace
 
 namespace test {
 
-TEST(ReceiveSideCongestionControllerTest, OnReceivedPacketWithAbsSendTime) {
-  StrictMock<MockPacketRouter> packet_router;
+TEST(ReceiveSideCongestionControllerTest, SendsRembWithAbsSendTime) {
+  MockFunction<void(std::vector<std::unique_ptr<rtcp::RtcpPacket>>)>
+      feedback_sender;
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
   SimulatedClock clock_(123456);
 
-  ReceiveSideCongestionController controller(&clock_, &packet_router);
+  ReceiveSideCongestionController controller(
+      &clock_, feedback_sender.AsStdFunction(), remb_sender.AsStdFunction(),
+      nullptr);
 
   size_t payload_size = 1000;
   RTPHeader header;
   header.ssrc = 0x11eb21c;
   header.extension.hasAbsoluteSendTime = true;
 
-  std::vector<unsigned int> ssrcs;
-  EXPECT_CALL(packet_router, OnReceiveBitrateChanged(_, _))
-      .WillRepeatedly(SaveArg<0>(&ssrcs));
+  EXPECT_CALL(remb_sender, Call(_, ElementsAre(header.ssrc))).Times(AtLeast(1));
 
   for (int i = 0; i < 10; ++i) {
     clock_.AdvanceTimeMilliseconds((1000 * payload_size) / kInitialBitrateBps);
@@ -72,9 +64,20 @@
     header.extension.absoluteSendTime = AbsSendTime(now_ms, 1000);
     controller.OnReceivedPacket(now_ms, payload_size, header);
   }
+}
 
-  ASSERT_EQ(1u, ssrcs.size());
-  EXPECT_EQ(header.ssrc, ssrcs[0]);
+TEST(ReceiveSideCongestionControllerTest,
+     SendsRembAfterSetMaxDesiredReceiveBitrate) {
+  MockFunction<void(std::vector<std::unique_ptr<rtcp::RtcpPacket>>)>
+      feedback_sender;
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  SimulatedClock clock_(123456);
+
+  ReceiveSideCongestionController controller(
+      &clock_, feedback_sender.AsStdFunction(), remb_sender.AsStdFunction(),
+      nullptr);
+  EXPECT_CALL(remb_sender, Call(123, _));
+  controller.SetMaxDesiredReceiveBitrate(DataRate::BitsPerSec(123));
 }
 
 TEST(ReceiveSideCongestionControllerTest, ConvergesToCapacity) {
diff --git a/modules/congestion_controller/remb_throttler.cc b/modules/congestion_controller/remb_throttler.cc
new file mode 100644
index 0000000..fcc30af
--- /dev/null
+++ b/modules/congestion_controller/remb_throttler.cc
@@ -0,0 +1,63 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/congestion_controller/remb_throttler.h"
+
+#include <algorithm>
+#include <utility>
+
+namespace webrtc {
+
+namespace {
+constexpr TimeDelta kRembSendInterval = TimeDelta::Millis(200);
+}  // namespace
+
+RembThrottler::RembThrottler(RembSender remb_sender, Clock* clock)
+    : remb_sender_(std::move(remb_sender)),
+      clock_(clock),
+      last_remb_time_(Timestamp::MinusInfinity()),
+      last_send_remb_bitrate_(DataRate::PlusInfinity()),
+      max_remb_bitrate_(DataRate::PlusInfinity()) {}
+
+void RembThrottler::OnReceiveBitrateChanged(const std::vector<uint32_t>& ssrcs,
+                                            uint32_t bitrate_bps) {
+  DataRate receive_bitrate = DataRate::BitsPerSec(bitrate_bps);
+  Timestamp now = clock_->CurrentTime();
+  {
+    MutexLock lock(&mutex_);
+    // % threshold for if we should send a new REMB asap.
+    const int64_t kSendThresholdPercent = 103;
+    if (receive_bitrate * kSendThresholdPercent / 100 >
+            last_send_remb_bitrate_ &&
+        now < last_remb_time_ + kRembSendInterval) {
+      return;
+    }
+    last_remb_time_ = now;
+    last_send_remb_bitrate_ = receive_bitrate;
+    receive_bitrate = std::min(last_send_remb_bitrate_, max_remb_bitrate_);
+  }
+  remb_sender_(receive_bitrate.bps(), ssrcs);
+}
+
+void RembThrottler::SetMaxDesiredReceiveBitrate(DataRate bitrate) {
+  Timestamp now = clock_->CurrentTime();
+  {
+    MutexLock lock(&mutex_);
+    max_remb_bitrate_ = bitrate;
+    if (now - last_remb_time_ < kRembSendInterval &&
+        !last_send_remb_bitrate_.IsZero() &&
+        last_send_remb_bitrate_ <= max_remb_bitrate_) {
+      return;
+    }
+  }
+  remb_sender_(bitrate.bps(), /*ssrcs=*/{});
+}
+
+}  // namespace webrtc
diff --git a/modules/congestion_controller/remb_throttler.h b/modules/congestion_controller/remb_throttler.h
new file mode 100644
index 0000000..67c0280
--- /dev/null
+++ b/modules/congestion_controller/remb_throttler.h
@@ -0,0 +1,54 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef MODULES_CONGESTION_CONTROLLER_REMB_THROTTLER_H_
+#define MODULES_CONGESTION_CONTROLLER_REMB_THROTTLER_H_
+
+#include <functional>
+#include <vector>
+
+#include "api/units/data_rate.h"
+#include "api/units/time_delta.h"
+#include "api/units/timestamp.h"
+#include "modules/remote_bitrate_estimator/remote_estimator_proxy.h"
+#include "rtc_base/synchronization/mutex.h"
+
+namespace webrtc {
+
+// RembThrottler is a helper class used for throttling RTCP REMB messages.
+// Throttles small changes to the received BWE within 200ms.
+class RembThrottler : public RemoteBitrateObserver {
+ public:
+  using RembSender =
+      std::function<void(int64_t bitrate_bps, std::vector<uint32_t> ssrcs)>;
+  RembThrottler(RembSender remb_sender, Clock* clock);
+
+  // Ensures the remote party is notified of the receive bitrate no larger than
+  // |bitrate| using RTCP REMB.
+  void SetMaxDesiredReceiveBitrate(DataRate bitrate);
+
+  // Implements RemoteBitrateObserver;
+  // Called every time there is a new bitrate estimate for a receive channel
+  // group. This call will trigger a new RTCP REMB packet if the bitrate
+  // estimate has decreased or if no RTCP REMB packet has been sent for
+  // a certain time interval.
+  void OnReceiveBitrateChanged(const std::vector<uint32_t>& ssrcs,
+                               uint32_t bitrate_bps) override;
+
+ private:
+  const RembSender remb_sender_;
+  Clock* const clock_;
+  mutable Mutex mutex_;
+  Timestamp last_remb_time_ RTC_GUARDED_BY(mutex_);
+  DataRate last_send_remb_bitrate_ RTC_GUARDED_BY(mutex_);
+  DataRate max_remb_bitrate_ RTC_GUARDED_BY(mutex_);
+};
+
+}  // namespace webrtc
+#endif  // MODULES_CONGESTION_CONTROLLER_REMB_THROTTLER_H_
diff --git a/modules/congestion_controller/remb_throttler_unittest.cc b/modules/congestion_controller/remb_throttler_unittest.cc
new file mode 100644
index 0000000..3f8df8a
--- /dev/null
+++ b/modules/congestion_controller/remb_throttler_unittest.cc
@@ -0,0 +1,100 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+#include "modules/congestion_controller/remb_throttler.h"
+
+#include <vector>
+
+#include "api/units/data_rate.h"
+#include "api/units/time_delta.h"
+#include "system_wrappers/include/clock.h"
+#include "test/gmock.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+
+using ::testing::_;
+using ::testing::MockFunction;
+
+TEST(RembThrottlerTest, CallRembSenderOnFirstReceiveBitrateChange) {
+  SimulatedClock clock(Timestamp::Zero());
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock);
+
+  EXPECT_CALL(remb_sender, Call(12345, std::vector<uint32_t>({1, 2, 3})));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12345);
+}
+
+TEST(RembThrottlerTest, ThrottlesSmallReceiveBitrateDecrease) {
+  SimulatedClock clock(Timestamp::Zero());
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock);
+
+  EXPECT_CALL(remb_sender, Call);
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12346);
+  clock.AdvanceTime(TimeDelta::Millis(100));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12345);
+
+  EXPECT_CALL(remb_sender, Call(12345, _));
+  clock.AdvanceTime(TimeDelta::Millis(101));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/12345);
+}
+
+TEST(RembThrottlerTest, DoNotThrottleLargeReceiveBitrateDecrease) {
+  SimulatedClock clock(Timestamp::Zero());
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock);
+
+  EXPECT_CALL(remb_sender, Call(2345, _));
+  EXPECT_CALL(remb_sender, Call(1234, _));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/2345);
+  clock.AdvanceTime(TimeDelta::Millis(1));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/1234);
+}
+
+TEST(RembThrottlerTest, ThrottlesReceiveBitrateIncrease) {
+  SimulatedClock clock(Timestamp::Zero());
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock);
+
+  EXPECT_CALL(remb_sender, Call);
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/1234);
+  clock.AdvanceTime(TimeDelta::Millis(100));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/2345);
+
+  // Updates 200ms after previous callback is not throttled.
+  EXPECT_CALL(remb_sender, Call(2345, _));
+  clock.AdvanceTime(TimeDelta::Millis(101));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/2345);
+}
+
+TEST(RembThrottlerTest, CallRembSenderOnSetMaxDesiredReceiveBitrate) {
+  SimulatedClock clock(Timestamp::Zero());
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock);
+  EXPECT_CALL(remb_sender, Call(1234, _));
+  remb_throttler.SetMaxDesiredReceiveBitrate(DataRate::BitsPerSec(1234));
+}
+
+TEST(RembThrottlerTest, CallRembSenderWithMinOfMaxDesiredAndOnReceivedBitrate) {
+  SimulatedClock clock(Timestamp::Zero());
+  MockFunction<void(uint64_t, std::vector<uint32_t>)> remb_sender;
+  RembThrottler remb_throttler(remb_sender.AsStdFunction(), &clock);
+
+  EXPECT_CALL(remb_sender, Call(1234, _));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/1234);
+  clock.AdvanceTime(TimeDelta::Millis(1));
+  remb_throttler.SetMaxDesiredReceiveBitrate(DataRate::BitsPerSec(4567));
+
+  clock.AdvanceTime(TimeDelta::Millis(200));
+  EXPECT_CALL(remb_sender, Call(4567, _));
+  remb_throttler.OnReceiveBitrateChanged({1, 2, 3}, /*bitrate_bps=*/5678);
+}
+
+}  // namespace webrtc
diff --git a/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h b/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h
index c60c030..a9edfb3 100644
--- a/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h
+++ b/modules/remote_bitrate_estimator/include/remote_bitrate_estimator.h
@@ -38,6 +38,7 @@
   virtual ~RemoteBitrateObserver() {}
 };
 
+// TODO(bugs.webrtc.org/12693): Deprecate
 class TransportFeedbackSenderInterface {
  public:
   virtual ~TransportFeedbackSenderInterface() = default;
diff --git a/modules/remote_bitrate_estimator/remote_estimator_proxy.cc b/modules/remote_bitrate_estimator/remote_estimator_proxy.cc
index a9cc170..a336109 100644
--- a/modules/remote_bitrate_estimator/remote_estimator_proxy.cc
+++ b/modules/remote_bitrate_estimator/remote_estimator_proxy.cc
@@ -33,11 +33,11 @@
 
 RemoteEstimatorProxy::RemoteEstimatorProxy(
     Clock* clock,
-    TransportFeedbackSenderInterface* feedback_sender,
+    TransportFeedbackSender feedback_sender,
     const WebRtcKeyValueConfig* key_value_config,
     NetworkStateEstimator* network_state_estimator)
     : clock_(clock),
-      feedback_sender_(feedback_sender),
+      feedback_sender_(std::move(feedback_sender)),
       send_config_(key_value_config),
       last_process_time_ms_(-1),
       network_state_estimator_(network_state_estimator),
@@ -217,7 +217,7 @@
     }
     packets.push_back(std::move(feedback_packet));
 
-    feedback_sender_->SendCombinedRtcpPacket(std::move(packets));
+    feedback_sender_(std::move(packets));
     // Note: Don't erase items from packet_arrival_times_ after sending, in case
     // they need to be re-sent after a reordering. Removal will be handled
     // by OnPacketArrival once packets are too old.
@@ -250,7 +250,7 @@
   RTC_DCHECK(feedback_sender_ != nullptr);
   std::vector<std::unique_ptr<rtcp::RtcpPacket>> packets;
   packets.push_back(std::move(feedback_packet));
-  feedback_sender_->SendCombinedRtcpPacket(std::move(packets));
+  feedback_sender_(std::move(packets));
 }
 
 int64_t RemoteEstimatorProxy::BuildFeedbackPacket(
diff --git a/modules/remote_bitrate_estimator/remote_estimator_proxy.h b/modules/remote_bitrate_estimator/remote_estimator_proxy.h
index a4adefc..5aabfe1 100644
--- a/modules/remote_bitrate_estimator/remote_estimator_proxy.h
+++ b/modules/remote_bitrate_estimator/remote_estimator_proxy.h
@@ -11,7 +11,9 @@
 #ifndef MODULES_REMOTE_BITRATE_ESTIMATOR_REMOTE_ESTIMATOR_PROXY_H_
 #define MODULES_REMOTE_BITRATE_ESTIMATOR_REMOTE_ESTIMATOR_PROXY_H_
 
+#include <functional>
 #include <map>
+#include <memory>
 #include <vector>
 
 #include "api/transport/network_control.h"
@@ -24,7 +26,6 @@
 namespace webrtc {
 
 class Clock;
-class PacketRouter;
 namespace rtcp {
 class TransportFeedback;
 }
@@ -32,11 +33,14 @@
 // Class used when send-side BWE is enabled: This proxy is instantiated on the
 // receive side. It buffers a number of receive timestamps and then sends
 // transport feedback messages back too the send side.
-
 class RemoteEstimatorProxy : public RemoteBitrateEstimator {
  public:
+  // Used for sending transport feedback messages when send side
+  // BWE is used.
+  using TransportFeedbackSender = std::function<void(
+      std::vector<std::unique_ptr<rtcp::RtcpPacket>> packets)>;
   RemoteEstimatorProxy(Clock* clock,
-                       TransportFeedbackSenderInterface* feedback_sender,
+                       TransportFeedbackSender feedback_sender,
                        const WebRtcKeyValueConfig* key_value_config,
                        NetworkStateEstimator* network_state_estimator);
   ~RemoteEstimatorProxy() override;
@@ -88,7 +92,7 @@
       rtcp::TransportFeedback* feedback_packet);
 
   Clock* const clock_;
-  TransportFeedbackSenderInterface* const feedback_sender_;
+  const TransportFeedbackSender feedback_sender_;
   const TransportWideFeedbackConfig send_config_;
   int64_t last_process_time_ms_;
 
diff --git a/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc b/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc
index da99592..296724f 100644
--- a/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc
+++ b/modules/remote_bitrate_estimator/remote_estimator_proxy_unittest.cc
@@ -16,8 +16,8 @@
 #include "api/transport/field_trial_based_config.h"
 #include "api/transport/network_types.h"
 #include "api/transport/test/mock_network_control.h"
-#include "modules/pacing/packet_router.h"
 #include "modules/rtp_rtcp/source/rtcp_packet/transport_feedback.h"
+#include "modules/rtp_rtcp/source/rtp_header_extensions.h"
 #include "system_wrappers/include/clock.h"
 #include "test/gmock.h"
 #include "test/gtest.h"
@@ -25,6 +25,7 @@
 using ::testing::_;
 using ::testing::ElementsAre;
 using ::testing::Invoke;
+using ::testing::MockFunction;
 using ::testing::Return;
 using ::testing::SizeIs;
 
@@ -63,20 +64,12 @@
   return timestamps;
 }
 
-class MockTransportFeedbackSender : public TransportFeedbackSenderInterface {
- public:
-  MOCK_METHOD(bool,
-              SendCombinedRtcpPacket,
-              (std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets),
-              (override));
-};
-
 class RemoteEstimatorProxyTest : public ::testing::Test {
  public:
   RemoteEstimatorProxyTest()
       : clock_(0),
         proxy_(&clock_,
-               &router_,
+               feedback_sender_.AsStdFunction(),
                &field_trial_config_,
                &network_state_estimator_) {}
 
@@ -113,7 +106,8 @@
 
   FieldTrialBasedConfig field_trial_config_;
   SimulatedClock clock_;
-  ::testing::StrictMock<MockTransportFeedbackSender> router_;
+  MockFunction<void(std::vector<std::unique_ptr<rtcp::RtcpPacket>>)>
+      feedback_sender_;
   ::testing::NiceMock<MockNetworkStateEstimator> network_state_estimator_;
   RemoteEstimatorProxy proxy_;
 };
@@ -121,7 +115,7 @@
 TEST_F(RemoteEstimatorProxyTest, SendsSinglePacketFeedback) {
   IncomingPacket(kBaseSeq, kBaseTimeMs);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -134,7 +128,6 @@
                         ElementsAre(kBaseSeq));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs));
-            return true;
           }));
 
   Process();
@@ -144,7 +137,7 @@
   IncomingPacket(kBaseSeq, kBaseTimeMs);
   IncomingPacket(kBaseSeq, kBaseTimeMs + 1000);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -167,13 +160,13 @@
   // First feedback.
   IncomingPacket(kBaseSeq, kBaseTimeMs);
   IncomingPacket(kBaseSeq + 1, kBaseTimeMs + 1000);
-  EXPECT_CALL(router_, SendCombinedRtcpPacket).WillOnce(Return(true));
+  EXPECT_CALL(feedback_sender_, Call);
   Process();
 
   // Second feedback starts with a missing packet (DROP kBaseSeq + 2).
   IncomingPacket(kBaseSeq + 3, kBaseTimeMs + 3000);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -186,7 +179,6 @@
                         ElementsAre(kBaseSeq + 3));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + 3000));
-            return true;
           }));
 
   Process();
@@ -197,7 +189,7 @@
   IncomingPacket(kBaseSeq + 1, kBaseTimeMs + kMaxSmallDeltaMs);
   IncomingPacket(kBaseSeq + 2, kBaseTimeMs + (2 * kMaxSmallDeltaMs) + 1);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -211,7 +203,6 @@
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs, kBaseTimeMs + kMaxSmallDeltaMs,
                                     kBaseTimeMs + (2 * kMaxSmallDeltaMs) + 1));
-            return true;
           }));
 
   Process();
@@ -224,7 +215,7 @@
   IncomingPacket(kBaseSeq, kBaseTimeMs);
   IncomingPacket(kBaseSeq + 1, kBaseTimeMs + kTooLargeDelta);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -237,7 +228,6 @@
                         ElementsAre(kBaseSeq));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs));
-            return true;
           }))
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
@@ -251,7 +241,6 @@
                         ElementsAre(kBaseSeq + 1));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + kTooLargeDelta));
-            return true;
           }));
 
   Process();
@@ -263,7 +252,7 @@
   IncomingPacket(kBaseSeq, kBaseTimeMs);
   IncomingPacket(kLargeSeq, kBaseTimeMs + kDeltaMs);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [&](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -274,7 +263,6 @@
 
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + kDeltaMs, kBaseTimeMs));
-            return true;
           }));
 
   Process();
@@ -293,7 +281,7 @@
   }
 
   // Only expect feedback for the last two packets.
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [&](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -306,7 +294,6 @@
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + 28 * kDeltaMs,
                                     kBaseTimeMs + 29 * kDeltaMs));
-            return true;
           }));
 
   Process();
@@ -324,7 +311,7 @@
   }
 
   // Only expect feedback for the first two packets.
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [&](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -336,7 +323,6 @@
                         ElementsAre(kBaseSeq + 40000, kBaseSeq));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + kDeltaMs, kBaseTimeMs));
-            return true;
           }));
 
   Process();
@@ -346,7 +332,7 @@
   IncomingPacket(kBaseSeq, kBaseTimeMs);
   IncomingPacket(kBaseSeq + 2, kBaseTimeMs + 2);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -359,14 +345,13 @@
                         ElementsAre(kBaseSeq, kBaseSeq + 2));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs, kBaseTimeMs + 2));
-            return true;
           }));
 
   Process();
 
   IncomingPacket(kBaseSeq + 1, kBaseTimeMs + 1);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -379,7 +364,6 @@
                         ElementsAre(kBaseSeq + 1, kBaseSeq + 2));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + 1, kBaseTimeMs + 2));
-            return true;
           }));
 
   Process();
@@ -390,7 +374,7 @@
 
   IncomingPacket(kBaseSeq + 2, kBaseTimeMs);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -400,14 +384,13 @@
 
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs));
-            return true;
           }));
 
   Process();
 
   IncomingPacket(kBaseSeq + 3, kTimeoutTimeMs);  // kBaseSeq + 2 times out here.
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [&](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -417,7 +400,6 @@
 
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kTimeoutTimeMs));
-            return true;
           }));
 
   Process();
@@ -427,7 +409,7 @@
   IncomingPacket(kBaseSeq, kBaseTimeMs - 1);
   IncomingPacket(kBaseSeq + 1, kTimeoutTimeMs - 1);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [&](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -440,7 +422,6 @@
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs - 1, kTimeoutTimeMs - 1,
                                     kTimeoutTimeMs));
-            return true;
           }));
 
   Process();
@@ -496,7 +477,7 @@
 TEST_F(RemoteEstimatorProxyOnRequestTest, ProcessDoesNotSendFeedback) {
   proxy_.SetSendPeriodicFeedback(false);
   IncomingPacket(kBaseSeq, kBaseTimeMs);
-  EXPECT_CALL(router_, SendCombinedRtcpPacket).Times(0);
+  EXPECT_CALL(feedback_sender_, Call).Times(0);
   Process();
 }
 
@@ -506,7 +487,7 @@
   IncomingPacket(kBaseSeq + 1, kBaseTimeMs + kMaxSmallDeltaMs);
   IncomingPacket(kBaseSeq + 2, kBaseTimeMs + 2 * kMaxSmallDeltaMs);
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -519,7 +500,6 @@
                         ElementsAre(kBaseSeq + 3));
             EXPECT_THAT(TimestampsMs(*feedback_packet),
                         ElementsAre(kBaseTimeMs + 3 * kMaxSmallDeltaMs));
-            return true;
           }));
 
   constexpr FeedbackRequest kSinglePacketFeedbackRequest = {
@@ -535,7 +515,7 @@
     IncomingPacket(kBaseSeq + i, kBaseTimeMs + i * kMaxSmallDeltaMs);
   }
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -553,7 +533,6 @@
                                     kBaseTimeMs + 8 * kMaxSmallDeltaMs,
                                     kBaseTimeMs + 9 * kMaxSmallDeltaMs,
                                     kBaseTimeMs + 10 * kMaxSmallDeltaMs));
-            return true;
           }));
 
   constexpr FeedbackRequest kFivePacketsFeedbackRequest = {
@@ -571,7 +550,7 @@
       IncomingPacket(kBaseSeq + i, kBaseTimeMs + i * kMaxSmallDeltaMs);
   }
 
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
+  EXPECT_CALL(feedback_sender_, Call)
       .WillOnce(Invoke(
           [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
             rtcp::TransportFeedback* feedback_packet =
@@ -586,7 +565,6 @@
                         ElementsAre(kBaseTimeMs + 6 * kMaxSmallDeltaMs,
                                     kBaseTimeMs + 8 * kMaxSmallDeltaMs,
                                     kBaseTimeMs + 10 * kMaxSmallDeltaMs));
-            return true;
           }));
 
   constexpr FeedbackRequest kFivePacketsFeedbackRequest = {
@@ -658,13 +636,7 @@
                    AbsoluteSendTime::MsTo24Bits(kBaseTimeMs - 1)));
   EXPECT_CALL(network_state_estimator_, GetCurrentEstimate())
       .WillOnce(Return(NetworkStateEstimate()));
-  EXPECT_CALL(router_, SendCombinedRtcpPacket)
-      .WillOnce(
-          [](std::vector<std::unique_ptr<rtcp::RtcpPacket>> feedback_packets) {
-            EXPECT_THAT(feedback_packets, SizeIs(2));
-            return true;
-          });
-
+  EXPECT_CALL(feedback_sender_, Call(SizeIs(2)));
   Process();
 }