dcsctp: Add timer safeguards and sanity checks

Ensuring that timer durations never go beyond a safe maximum duration
and that timer IDs are not re-used.

Bug: webrtc:12614
Change-Id: I227a2e9933da16669dc6ea0a39c570892010ba2c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/215063
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Reviewed-by: Tommi <tommi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33860}
diff --git a/net/dcsctp/timer/BUILD.gn b/net/dcsctp/timer/BUILD.gn
index d92aca8..8eec923 100644
--- a/net/dcsctp/timer/BUILD.gn
+++ b/net/dcsctp/timer/BUILD.gn
@@ -14,6 +14,7 @@
     "../../../rtc_base",
     "../../../rtc_base:checks",
     "../../../rtc_base:rtc_base_approved",
+    "../public:strong_alias",
     "../public:types",
   ]
   sources = [
diff --git a/net/dcsctp/timer/timer.cc b/net/dcsctp/timer/timer.cc
index 2376e7a..f3c33ea 100644
--- a/net/dcsctp/timer/timer.cc
+++ b/net/dcsctp/timer/timer.cc
@@ -9,7 +9,9 @@
  */
 #include "net/dcsctp/timer/timer.h"
 
+#include <algorithm>
 #include <cstdint>
+#include <limits>
 #include <memory>
 #include <unordered_map>
 #include <utility>
@@ -17,11 +19,12 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/string_view.h"
 #include "net/dcsctp/public/timeout.h"
+#include "rtc_base/checks.h"
 
 namespace dcsctp {
 namespace {
-TimeoutID MakeTimeoutId(uint32_t timer_id, uint32_t generation) {
-  return TimeoutID(static_cast<uint64_t>(timer_id) << 32 | generation);
+TimeoutID MakeTimeoutId(TimerID timer_id, TimerGeneration generation) {
+  return TimeoutID(static_cast<uint64_t>(*timer_id) << 32 | *generation);
 }
 
 DurationMs GetBackoffDuration(TimerBackoffAlgorithm algorithm,
@@ -30,13 +33,23 @@
   switch (algorithm) {
     case TimerBackoffAlgorithm::kFixed:
       return base_duration;
-    case TimerBackoffAlgorithm::kExponential:
-      return DurationMs(*base_duration * (1 << expiration_count));
+    case TimerBackoffAlgorithm::kExponential: {
+      int32_t duration_ms = *base_duration;
+
+      while (expiration_count > 0 && duration_ms < *Timer::kMaxTimerDuration) {
+        duration_ms *= 2;
+        --expiration_count;
+      }
+
+      return DurationMs(std::min(duration_ms, *Timer::kMaxTimerDuration));
+    }
   }
 }
 }  // namespace
 
-Timer::Timer(uint32_t id,
+constexpr DurationMs Timer::kMaxTimerDuration;
+
+Timer::Timer(TimerID id,
              absl::string_view name,
              OnExpired on_expired,
              UnregisterHandler unregister_handler,
@@ -59,11 +72,13 @@
   expiration_count_ = 0;
   if (!is_running()) {
     is_running_ = true;
-    timeout_->Start(duration_, MakeTimeoutId(id_, ++generation_));
+    generation_ = TimerGeneration(*generation_ + 1);
+    timeout_->Start(duration_, MakeTimeoutId(id_, generation_));
   } else {
     // Timer was running - stop and restart it, to make it expire in `duration_`
     // from now.
-    timeout_->Restart(duration_, MakeTimeoutId(id_, ++generation_));
+    generation_ = TimerGeneration(*generation_ + 1);
+    timeout_->Restart(duration_, MakeTimeoutId(id_, generation_));
   }
 }
 
@@ -75,7 +90,7 @@
   }
 }
 
-void Timer::Trigger(uint32_t generation) {
+void Timer::Trigger(TimerGeneration generation) {
   if (is_running_ && generation == generation_) {
     ++expiration_count_;
     if (options_.max_restarts >= 0 &&
@@ -92,14 +107,15 @@
       // Restart it with new duration.
       DurationMs duration = GetBackoffDuration(options_.backoff_algorithm,
                                                duration_, expiration_count_);
-      timeout_->Start(duration, MakeTimeoutId(id_, ++generation_));
+      generation_ = TimerGeneration(*generation_ + 1);
+      timeout_->Start(duration, MakeTimeoutId(id_, generation_));
     }
   }
 }
 
 void TimerManager::HandleTimeout(TimeoutID timeout_id) {
-  uint32_t timer_id = *timeout_id >> 32;
-  uint32_t generation = *timeout_id;
+  TimerID timer_id(*timeout_id >> 32);
+  TimerGeneration generation(*timeout_id);
   auto it = timers_.find(timer_id);
   if (it != timers_.end()) {
     it->second->Trigger(generation);
@@ -109,7 +125,12 @@
 std::unique_ptr<Timer> TimerManager::CreateTimer(absl::string_view name,
                                                  Timer::OnExpired on_expired,
                                                  const TimerOptions& options) {
-  uint32_t id = ++next_id_;
+  next_id_ = TimerID(*next_id_ + 1);
+  TimerID id = next_id_;
+  // This would overflow after 4 billion timers created, which in SCTP would be
+  // after 800 million reconnections on a single socket. Ensure this will never
+  // happen.
+  RTC_CHECK_NE(*id, std::numeric_limits<uint32_t>::max());
   auto timer = absl::WrapUnique(new Timer(
       id, name, std::move(on_expired), [this, id]() { timers_.erase(id); },
       create_timeout_(), options));
diff --git a/net/dcsctp/timer/timer.h b/net/dcsctp/timer/timer.h
index 6b68c98..bf923ea 100644
--- a/net/dcsctp/timer/timer.h
+++ b/net/dcsctp/timer/timer.h
@@ -12,6 +12,7 @@
 
 #include <stdint.h>
 
+#include <algorithm>
 #include <functional>
 #include <memory>
 #include <string>
@@ -20,10 +21,14 @@
 
 #include "absl/strings/string_view.h"
 #include "absl/types/optional.h"
+#include "net/dcsctp/public/strong_alias.h"
 #include "net/dcsctp/public/timeout.h"
 
 namespace dcsctp {
 
+using TimerID = StrongAlias<class TimerIDTag, uint32_t>;
+using TimerGeneration = StrongAlias<class TimerGenerationTag, uint32_t>;
+
 enum class TimerBackoffAlgorithm {
   // The base duration will be used for any restart.
   kFixed,
@@ -68,6 +73,9 @@
 // backoff algorithm).
 class Timer {
  public:
+  // The maximum timer duration - one day.
+  static constexpr DurationMs kMaxTimerDuration = DurationMs(24 * 3600 * 1000);
+
   // When expired, the timer handler can optionally return a new duration which
   // will be set as `duration` and used as base duration when the timer is
   // restarted and as input to the backoff algorithm.
@@ -89,7 +97,9 @@
 
   // Sets the base duration. The actual timer duration may be larger depending
   // on the backoff algorithm.
-  void set_duration(DurationMs duration) { duration_ = duration; }
+  void set_duration(DurationMs duration) {
+    duration_ = std::min(duration, kMaxTimerDuration);
+  }
 
   // Retrieves the base duration. The actual timer duration may be larger
   // depending on the backoff algorithm.
@@ -110,7 +120,7 @@
  private:
   friend class TimerManager;
   using UnregisterHandler = std::function<void()>;
-  Timer(uint32_t id,
+  Timer(TimerID id,
         absl::string_view name,
         OnExpired on_expired,
         UnregisterHandler unregister,
@@ -122,9 +132,9 @@
   // duration as decided by the backoff algorithm, unless the
   // `TimerOptions::max_restarts` has been reached and then it will be stopped
   // and `is_running()` will return false.
-  void Trigger(uint32_t generation);
+  void Trigger(TimerGeneration generation);
 
-  const uint32_t id_;
+  const TimerID id_;
   const std::string name_;
   const TimerOptions options_;
   const OnExpired on_expired_;
@@ -133,8 +143,16 @@
 
   DurationMs duration_;
 
-  // Increased on each start, and is matched on Trigger, to avoid races.
-  uint32_t generation_ = 0;
+  // Increased on each start, and is matched on Trigger, to avoid races. And by
+  // race, meaning that a timeout - which may be evaluated/expired on a
+  // different thread while this thread has stopped that timer already. Note
+  // that the entire socket is not thread-safe, so `TimerManager::HandleTimeout`
+  // is never executed concurrently with any timer starting/stopping.
+  //
+  // This will wrap around after 4 billion timer restarts, and if it wraps
+  // around, it would just trigger _this_ timer in advance (but it's hard to
+  // restart it 4 billion times within its duration).
+  TimerGeneration generation_ = TimerGeneration(0);
   bool is_running_ = false;
   // Incremented each time time has expired and reset when stopped or restarted.
   int expiration_count_ = 0;
@@ -158,8 +176,8 @@
 
  private:
   const std::function<std::unique_ptr<Timeout>()> create_timeout_;
-  std::unordered_map<int, Timer*> timers_;
-  uint32_t next_id_ = 0;
+  std::unordered_map<TimerID, Timer*, TimerID::Hasher> timers_;
+  TimerID next_id_ = TimerID(0);
 };
 
 }  // namespace dcsctp
diff --git a/net/dcsctp/timer/timer_test.cc b/net/dcsctp/timer/timer_test.cc
index 9533234..719d73e 100644
--- a/net/dcsctp/timer/timer_test.cc
+++ b/net/dcsctp/timer/timer_test.cc
@@ -310,5 +310,41 @@
   AdvanceTimeAndRunTimers(DurationMs(1000));
 }
 
+TEST_F(TimerTest, TimersHaveMaximumDuration) {
+  std::unique_ptr<Timer> t1 = manager_.CreateTimer(
+      "t1", on_expired_.AsStdFunction(),
+      TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential));
+
+  t1->set_duration(DurationMs(2 * *Timer::kMaxTimerDuration));
+  EXPECT_EQ(t1->duration(), Timer::kMaxTimerDuration);
+}
+
+TEST_F(TimerTest, TimersHaveMaximumBackoffDuration) {
+  std::unique_ptr<Timer> t1 = manager_.CreateTimer(
+      "t1", on_expired_.AsStdFunction(),
+      TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential));
+
+  t1->Start();
+
+  int max_exponent = static_cast<int>(log2(*Timer::kMaxTimerDuration / 1000));
+  for (int i = 0; i < max_exponent; ++i) {
+    EXPECT_CALL(on_expired_, Call).Times(1);
+    AdvanceTimeAndRunTimers(DurationMs(1000 * (1 << i)));
+  }
+
+  // Reached the maximum duration.
+  EXPECT_CALL(on_expired_, Call).Times(1);
+  AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
+
+  EXPECT_CALL(on_expired_, Call).Times(1);
+  AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
+
+  EXPECT_CALL(on_expired_, Call).Times(1);
+  AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
+
+  EXPECT_CALL(on_expired_, Call).Times(1);
+  AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration);
+}
+
 }  // namespace
 }  // namespace dcsctp