dcsctp: Add Heartbeat Handler

It's responsible for answering incoming Heartbeat Requests, and to
send requests itself when a connection is idle. When it receives
a response, it will measure the RTT and if it doesn't receive a response
in time, that will result in a TX error, which will eventually close
the connection.

Bug: webrtc:12614
Change-Id: I08371d9072ff0461f60e0a2f7696c0fd7ccb57c5
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/214129
Reviewed-by: Tommi <tommi@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33828}
diff --git a/net/dcsctp/BUILD.gn b/net/dcsctp/BUILD.gn
index af7082b..8406994 100644
--- a/net/dcsctp/BUILD.gn
+++ b/net/dcsctp/BUILD.gn
@@ -17,6 +17,7 @@
       "packet:dcsctp_packet_unittests",
       "public:dcsctp_public_unittests",
       "rx:dcsctp_rx_unittests",
+      "socket:dcsctp_socket_unittests",
       "timer:dcsctp_timer_unittests",
       "tx:dcsctp_tx_unittests",
     ]
diff --git a/net/dcsctp/socket/BUILD.gn b/net/dcsctp/socket/BUILD.gn
index 45e911a..8d5bdd8 100644
--- a/net/dcsctp/socket/BUILD.gn
+++ b/net/dcsctp/socket/BUILD.gn
@@ -18,6 +18,21 @@
   ]
 }
 
+rtc_library("heartbeat_handler") {
+  deps = [
+    ":context",
+    "../../../api:array_view",
+    "../../../rtc_base",
+    "../../../rtc_base:checks",
+    "../../../rtc_base:rtc_base_approved",
+    "../public:types",
+  ]
+  sources = [
+    "heartbeat_handler.cc",
+    "heartbeat_handler.h",
+  ]
+}
+
 if (rtc_include_tests) {
   rtc_source_set("mock_callbacks") {
     testonly = true
@@ -37,4 +52,18 @@
       "../public:types",
     ]
   }
+
+  rtc_library("dcsctp_socket_unittests") {
+    testonly = true
+
+    deps = [
+      ":heartbeat_handler",
+      "../../../api:array_view",
+      "../../../rtc_base:checks",
+      "../../../rtc_base:gunit_helpers",
+      "../../../rtc_base:rtc_base_approved",
+      "../../../test:test_support",
+    ]
+    sources = [ "heartbeat_handler_test.cc" ]
+  }
 }
diff --git a/net/dcsctp/socket/heartbeat_handler.cc b/net/dcsctp/socket/heartbeat_handler.cc
new file mode 100644
index 0000000..30a0001
--- /dev/null
+++ b/net/dcsctp/socket/heartbeat_handler.cc
@@ -0,0 +1,189 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/heartbeat_handler.h"
+
+#include <stddef.h>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/packet/bounded_byte_reader.h"
+#include "net/dcsctp/packet/bounded_byte_writer.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/timer/timer.h"
+#include "rtc_base/logging.h"
+
+namespace dcsctp {
+
+// This is stored (in serialized form) as HeartbeatInfoParameter sent in
+// HeartbeatRequestChunk and received back in HeartbeatAckChunk. It should be
+// well understood that this data may be modified by the peer, so it can't
+// be trusted.
+//
+// It currently only stores a timestamp, in millisecond precision, to allow for
+// RTT measurements. If that would be manipulated by the peer, it would just
+// result in incorrect RTT measurements, which isn't an issue.
+class HeartbeatInfo {
+ public:
+  static constexpr size_t kBufferSize = sizeof(uint64_t);
+  static_assert(kBufferSize == 8, "Unexpected buffer size");
+
+  explicit HeartbeatInfo(TimeMs created_at) : created_at_(created_at) {}
+
+  std::vector<uint8_t> Serialize() {
+    uint32_t high_bits = static_cast<uint32_t>(*created_at_ >> 32);
+    uint32_t low_bits = static_cast<uint32_t>(*created_at_);
+
+    std::vector<uint8_t> data(kBufferSize);
+    BoundedByteWriter<kBufferSize> writer(data);
+    writer.Store32<0>(high_bits);
+    writer.Store32<4>(low_bits);
+    return data;
+  }
+
+  static absl::optional<HeartbeatInfo> Deserialize(
+      rtc::ArrayView<const uint8_t> data) {
+    if (data.size() != kBufferSize) {
+      RTC_LOG(LS_WARNING) << "Invalid heartbeat info: " << data.size()
+                          << " bytes";
+      return absl::nullopt;
+    }
+
+    BoundedByteReader<kBufferSize> reader(data);
+    uint32_t high_bits = reader.Load32<0>();
+    uint32_t low_bits = reader.Load32<4>();
+
+    uint64_t created_at = static_cast<uint64_t>(high_bits) << 32 | low_bits;
+    return HeartbeatInfo(TimeMs(created_at));
+  }
+
+  TimeMs created_at() const { return created_at_; }
+
+ private:
+  const TimeMs created_at_;
+};
+
+HeartbeatHandler::HeartbeatHandler(absl::string_view log_prefix,
+                                   const DcSctpOptions& options,
+                                   Context* context,
+                                   TimerManager* timer_manager)
+    : log_prefix_(std::string(log_prefix) + "heartbeat: "),
+      ctx_(context),
+      timer_manager_(timer_manager),
+      interval_duration_(options.heartbeat_interval),
+      interval_duration_should_include_rtt_(
+          options.heartbeat_interval_include_rtt),
+      interval_timer_(timer_manager_->CreateTimer(
+          "heartbeat-interval",
+          [this]() { return OnIntervalTimerExpiry(); },
+          TimerOptions(interval_duration_, TimerBackoffAlgorithm::kFixed))),
+      timeout_timer_(timer_manager_->CreateTimer(
+          "heartbeat-timeout",
+          [this]() { return OnTimeoutTimerExpiry(); },
+          TimerOptions(options.rto_initial,
+                       TimerBackoffAlgorithm::kExponential,
+                       /*max_restarts=*/0))) {
+  // The interval timer must always be running as long as the association is up.
+  interval_timer_->Start();
+}
+
+void HeartbeatHandler::RestartTimer() {
+  if (interval_duration_should_include_rtt_) {
+    // The RTT should be used, but it's not easy accessible. The RTO will
+    // suffice.
+    interval_timer_->set_duration(interval_duration_ + ctx_->current_rto());
+  } else {
+    interval_timer_->set_duration(interval_duration_);
+  }
+
+  interval_timer_->Start();
+}
+
+void HeartbeatHandler::HandleHeartbeatRequest(HeartbeatRequestChunk chunk) {
+  // https://tools.ietf.org/html/rfc4960#section-8.3
+  // "The receiver of the HEARTBEAT should immediately respond with a
+  // HEARTBEAT ACK that contains the Heartbeat Information TLV, together with
+  // any other received TLVs, copied unchanged from the received HEARTBEAT
+  // chunk."
+  ctx_->Send(ctx_->PacketBuilder().Add(
+      HeartbeatAckChunk(std::move(chunk).extract_parameters())));
+}
+
+void HeartbeatHandler::HandleHeartbeatAck(HeartbeatAckChunk chunk) {
+  timeout_timer_->Stop();
+  absl::optional<HeartbeatInfoParameter> info_param = chunk.info();
+  if (!info_param.has_value()) {
+    ctx_->callbacks().OnError(
+        ErrorKind::kParseFailed,
+        "Failed to parse HEARTBEAT-ACK; No Heartbeat Info parameter");
+    return;
+  }
+  absl::optional<HeartbeatInfo> info =
+      HeartbeatInfo::Deserialize(info_param->info());
+  if (!info.has_value()) {
+    ctx_->callbacks().OnError(ErrorKind::kParseFailed,
+                              "Failed to parse HEARTBEAT-ACK; Failed to "
+                              "deserialized Heartbeat info parameter");
+    return;
+  }
+
+  DurationMs duration(*ctx_->callbacks().TimeMillis() - *info->created_at());
+
+  ctx_->ObserveRTT(duration);
+
+  // https://tools.ietf.org/html/rfc4960#section-8.1
+  // "The counter shall be reset each time ... a HEARTBEAT ACK is received from
+  // the peer endpoint."
+  ctx_->ClearTxErrorCounter();
+}
+
+absl::optional<DurationMs> HeartbeatHandler::OnIntervalTimerExpiry() {
+  if (ctx_->is_connection_established()) {
+    HeartbeatInfo info(ctx_->callbacks().TimeMillis());
+    timeout_timer_->set_duration(ctx_->current_rto());
+    timeout_timer_->Start();
+    RTC_DLOG(LS_INFO) << log_prefix_ << "Sending HEARTBEAT with timeout "
+                      << *timeout_timer_->duration();
+
+    Parameters parameters = Parameters::Builder()
+                                .Add(HeartbeatInfoParameter(info.Serialize()))
+                                .Build();
+
+    ctx_->Send(ctx_->PacketBuilder().Add(
+        HeartbeatRequestChunk(std::move(parameters))));
+  } else {
+    RTC_DLOG(LS_VERBOSE)
+        << log_prefix_
+        << "Will not send HEARTBEAT when connection not established";
+  }
+  return absl::nullopt;
+}
+
+absl::optional<DurationMs> HeartbeatHandler::OnTimeoutTimerExpiry() {
+  // Note that the timeout timer is not restarted. It will be started again when
+  // the interval timer expires.
+  RTC_DCHECK(!timeout_timer_->is_running());
+  ctx_->IncrementTxErrorCounter("HEARTBEAT timeout");
+  return absl::nullopt;
+}
+}  // namespace dcsctp
diff --git a/net/dcsctp/socket/heartbeat_handler.h b/net/dcsctp/socket/heartbeat_handler.h
new file mode 100644
index 0000000..14c3109
--- /dev/null
+++ b/net/dcsctp/socket/heartbeat_handler.h
@@ -0,0 +1,69 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_
+#define NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/timer/timer.h"
+
+namespace dcsctp {
+
+// HeartbeatHandler handles all logic around sending heartbeats and receiving
+// the responses, as well as receiving incoming heartbeat requests.
+//
+// Heartbeats are sent on idle connections to ensure that the connection is
+// still healthy and to measure the RTT. If a number of heartbeats time out,
+// the connection will eventually be closed.
+class HeartbeatHandler {
+ public:
+  HeartbeatHandler(absl::string_view log_prefix,
+                   const DcSctpOptions& options,
+                   Context* context,
+                   TimerManager* timer_manager);
+
+  // Called when the heartbeat interval timer should be restarted. This is
+  // generally done every time data is sent, which makes the timer expire when
+  // the connection is idle.
+  void RestartTimer();
+
+  // Called on received HeartbeatRequestChunk chunks.
+  void HandleHeartbeatRequest(HeartbeatRequestChunk chunk);
+
+  // Called on received HeartbeatRequestChunk chunks.
+  void HandleHeartbeatAck(HeartbeatAckChunk chunk);
+
+ private:
+  absl::optional<DurationMs> OnIntervalTimerExpiry();
+  absl::optional<DurationMs> OnTimeoutTimerExpiry();
+
+  const std::string log_prefix_;
+  Context* ctx_;
+  TimerManager* timer_manager_;
+  // The time for a connection to be idle before a heartbeat is sent.
+  const DurationMs interval_duration_;
+  // Adding RTT to the duration will add some jitter, which is good in
+  // production, but less good in unit tests, which is why it can be disabled.
+  const bool interval_duration_should_include_rtt_;
+  const std::unique_ptr<Timer> interval_timer_;
+  const std::unique_ptr<Timer> timeout_timer_;
+};
+}  // namespace dcsctp
+
+#endif  // NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_
diff --git a/net/dcsctp/socket/heartbeat_handler_test.cc b/net/dcsctp/socket/heartbeat_handler_test.cc
new file mode 100644
index 0000000..58dbcff
--- /dev/null
+++ b/net/dcsctp/socket/heartbeat_handler_test.cc
@@ -0,0 +1,123 @@
+/*
+ *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/heartbeat_handler.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/socket/mock_context.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+namespace {
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::SizeIs;
+
+DcSctpOptions MakeOptions() {
+  DcSctpOptions options;
+  options.heartbeat_interval_include_rtt = false;
+  options.heartbeat_interval = DurationMs(30'000);
+  return options;
+}
+
+class HeartbeatHandlerTest : public testing::Test {
+ protected:
+  HeartbeatHandlerTest()
+      : options_(MakeOptions()),
+        context_(&callbacks_),
+        timer_manager_([this]() { return callbacks_.CreateTimeout(); }),
+        handler_("log: ", options_, &context_, &timer_manager_) {}
+
+  const DcSctpOptions options_;
+  NiceMock<MockDcSctpSocketCallbacks> callbacks_;
+  NiceMock<MockContext> context_;
+  TimerManager timer_manager_;
+  HeartbeatHandler handler_;
+};
+
+TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) {
+  uint8_t info_data[] = {1, 2, 3, 4, 5};
+  HeartbeatRequestChunk request(
+      Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build());
+
+  handler_.HandleHeartbeatRequest(std::move(request));
+
+  std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+  ASSERT_THAT(packet.descriptors(), SizeIs(1));
+
+  ASSERT_HAS_VALUE_AND_ASSIGN(
+      HeartbeatAckChunk response,
+      HeartbeatAckChunk::Parse(packet.descriptors()[0].data));
+
+  ASSERT_HAS_VALUE_AND_ASSIGN(
+      HeartbeatInfoParameter param,
+      response.parameters().get<HeartbeatInfoParameter>());
+
+  EXPECT_THAT(param.info(), ElementsAre(1, 2, 3, 4, 5));
+}
+
+TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) {
+  callbacks_.AdvanceTime(options_.heartbeat_interval);
+  for (TimeoutID id : callbacks_.RunTimers()) {
+    timer_manager_.HandleTimeout(id);
+  }
+
+  // Grab the request, and make a response.
+  std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+  ASSERT_THAT(packet.descriptors(), SizeIs(1));
+
+  ASSERT_HAS_VALUE_AND_ASSIGN(
+      HeartbeatRequestChunk req,
+      HeartbeatRequestChunk::Parse(packet.descriptors()[0].data));
+
+  HeartbeatAckChunk ack(std::move(req).extract_parameters());
+
+  // Respond a while later. This RTT will be measured by the handler
+  constexpr DurationMs rtt(313);
+
+  EXPECT_CALL(context_, ObserveRTT(rtt)).Times(1);
+
+  callbacks_.AdvanceTime(rtt);
+  handler_.HandleHeartbeatAck(std::move(ack));
+}
+
+TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) {
+  callbacks_.AdvanceTime(options_.heartbeat_interval);
+
+  DurationMs rto(105);
+  EXPECT_CALL(context_, current_rto).WillOnce(Return(rto));
+  for (TimeoutID id : callbacks_.RunTimers()) {
+    timer_manager_.HandleTimeout(id);
+  }
+
+  // Validate that a request was sent.
+  EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty()));
+
+  EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1);
+  callbacks_.AdvanceTime(rto);
+  for (TimeoutID id : callbacks_.RunTimers()) {
+    timer_manager_.HandleTimeout(id);
+  }
+}
+
+}  // namespace
+}  // namespace dcsctp