dcsctp: implement socket handover in the DcSctpSocket class and expose the functionality in the API
Bug: webrtc:13154
Change-Id: Idf4f4028c8e65943cb6b41fab0baef1b3584205d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/232126
Reviewed-by: Victor Boivie <boivie@webrtc.org>
Commit-Queue: Sergey Sukhanov <sergeysu@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35029}
diff --git a/net/dcsctp/public/dcsctp_handover_state.h b/net/dcsctp/public/dcsctp_handover_state.h
index 8907669..3ad4ab7 100644
--- a/net/dcsctp/public/dcsctp_handover_state.h
+++ b/net/dcsctp/public/dcsctp_handover_state.h
@@ -24,6 +24,25 @@
// for serialization. Serialization is not provided by dcSCTP. If needed it has
// to be implemented in the calling client.
struct DcSctpSocketHandoverState {
+ enum class SocketState {
+ kClosed,
+ kConnected,
+ };
+ SocketState socket_state = SocketState::kClosed;
+
+ uint32_t my_verification_tag = 0;
+ uint32_t my_initial_tsn = 0;
+ uint32_t peer_verification_tag = 0;
+ uint32_t peer_initial_tsn = 0;
+ uint64_t tie_tag = 0;
+
+ struct Capabilities {
+ bool partial_reliability = false;
+ bool message_interleaving = false;
+ bool reconfig = false;
+ };
+ Capabilities capabilities;
+
struct Transmission {
uint32_t next_tsn = 0;
uint32_t next_reset_req_sn = 0;
@@ -98,6 +117,7 @@
value() |= status.value();
return *this;
}
+ std::string ToString() const;
};
} // namespace dcsctp
diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h
index 248646e..583d037 100644
--- a/net/dcsctp/public/dcsctp_socket.h
+++ b/net/dcsctp/public/dcsctp_socket.h
@@ -17,6 +17,7 @@
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "api/array_view.h"
+#include "net/dcsctp/public/dcsctp_handover_state.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/packet_observer.h"
@@ -355,6 +356,14 @@
// `DcSctpSocketCallbacks::OnConnected` will be called on success.
virtual void Connect() = 0;
+ // Puts this socket to the state in which the original socket was when its
+ // `DcSctpSocketHandoverState` was captured by `GetHandoverStateAndClose`.
+ // `RestoreFromState` is allowed only on the closed socket.
+ // `DcSctpSocketCallbacks::OnConnected` will be called if a connected socket
+ // state is restored.
+ // `DcSctpSocketCallbacks::OnError` will be called on error.
+ virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0;
+
// Gracefully shutdowns the socket and sends all outstanding data. This is an
// asynchronous operation and `DcSctpSocketCallbacks::OnClosed` will be called
// on success.
@@ -417,6 +426,20 @@
// Retrieves the latest metrics.
virtual Metrics GetMetrics() const = 0;
+
+ // Returns empty bitmask if the socket is in the state in which a snapshot of
+ // the state can be made by `GetHandoverStateAndClose()`. Return value is
+ // invalidated by a call to any non-const method.
+ virtual HandoverReadinessStatus GetHandoverReadiness() const = 0;
+
+ // Collects a snapshot of the socket state that can be used to reconstruct
+ // this socket in another process. On success this socket object is closed
+ // synchronously and no callbacks will be made after the method has returned.
+ // The method fails if the socket is not in a state ready for handover.
+ // nullopt indicates the failure. `DcSctpSocketCallbacks::OnClosed` will be
+ // called on success.
+ virtual absl::optional<DcSctpSocketHandoverState>
+ GetHandoverStateAndClose() = 0;
};
} // namespace dcsctp
diff --git a/net/dcsctp/public/mock_dcsctp_socket.h b/net/dcsctp/public/mock_dcsctp_socket.h
index b382773..eb1e8cc 100644
--- a/net/dcsctp/public/mock_dcsctp_socket.h
+++ b/net/dcsctp/public/mock_dcsctp_socket.h
@@ -26,6 +26,11 @@
MOCK_METHOD(void, Connect, (), (override));
+ MOCK_METHOD(void,
+ RestoreFromState,
+ (const DcSctpSocketHandoverState&),
+ (override));
+
MOCK_METHOD(void, Shutdown, (), (override));
MOCK_METHOD(void, Close, (), (override));
@@ -59,6 +64,15 @@
(override));
MOCK_METHOD(Metrics, GetMetrics, (), (const, override));
+
+ MOCK_METHOD(HandoverReadinessStatus,
+ GetHandoverReadiness,
+ (),
+ (const, override));
+ MOCK_METHOD(absl::optional<DcSctpSocketHandoverState>,
+ GetHandoverStateAndClose,
+ (),
+ (override));
};
} // namespace dcsctp
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 5211cca..afc30f6 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -139,8 +139,57 @@
static_cast<uint64_t>(tie_tag_lower));
}
+constexpr absl::string_view HandoverUnreadinessReasonToString(
+ HandoverUnreadinessReason reason) {
+ switch (reason) {
+ case HandoverUnreadinessReason::kWrongConnectionState:
+ return "WRONG_CONNECTION_STATE";
+ case HandoverUnreadinessReason::kSendQueueNotEmpty:
+ return "SEND_QUEUE_NOT_EMPTY";
+ case HandoverUnreadinessReason::kDataTrackerTsnBlocksPending:
+ return "DATA_TRACKER_TSN_BLOCKS_PENDING";
+ case HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap:
+ return "REASSEMBLY_QUEUE_DELIVERED_TSN_GAP";
+ case HandoverUnreadinessReason::kStreamResetDeferred:
+ return "STREAM_RESET_DEFERRED";
+ case HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks:
+ return "ORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS";
+ case HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks:
+ return "UNORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS";
+ case HandoverUnreadinessReason::kRetransmissionQueueOutstandingData:
+ return "RETRANSMISSION_QUEUE_OUTSTANDING_DATA";
+ case HandoverUnreadinessReason::kRetransmissionQueueFastRecovery:
+ return "RETRANSMISSION_QUEUE_FAST_RECOVERY";
+ case HandoverUnreadinessReason::kRetransmissionQueueNotEmpty:
+ return "RETRANSMISSION_QUEUE_NOT_EMPTY";
+ case HandoverUnreadinessReason::kPendingStreamReset:
+ return "PENDING_STREAM_RESET";
+ case HandoverUnreadinessReason::kPendingStreamResetRequest:
+ return "PENDING_STREAM_RESET_REQUEST";
+ }
+}
} // namespace
+std::string HandoverReadinessStatus::ToString() const {
+ std::string result;
+ for (uint32_t bit = 1;
+ bit <= static_cast<uint32_t>(HandoverUnreadinessReason::kMax);
+ bit *= 2) {
+ auto flag = static_cast<HandoverUnreadinessReason>(bit);
+ if (Contains(flag)) {
+ if (!result.empty()) {
+ result.append(",");
+ }
+ absl::string_view s = HandoverUnreadinessReasonToString(flag);
+ result.append(s.data(), s.size());
+ }
+ }
+ if (result.empty()) {
+ result = "READY";
+ }
+ return result;
+}
+
DcSctpSocket::DcSctpSocket(absl::string_view log_prefix,
DcSctpSocketCallbacks& callbacks,
std::unique_ptr<PacketObserver> packet_observer,
@@ -286,6 +335,42 @@
callbacks_.TriggerDeferred();
}
+void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
+ if (state_ != State::kClosed) {
+ callbacks_.OnError(ErrorKind::kUnsupportedOperation,
+ "Only closed socket can be restored from state");
+ } else {
+ if (state.socket_state ==
+ DcSctpSocketHandoverState::SocketState::kConnected) {
+ VerificationTag my_verification_tag =
+ VerificationTag(state.my_verification_tag);
+ connect_params_.verification_tag = my_verification_tag;
+
+ Capabilities capabilities;
+ capabilities.partial_reliability = state.capabilities.partial_reliability;
+ capabilities.message_interleaving =
+ state.capabilities.message_interleaving;
+ capabilities.reconfig = state.capabilities.reconfig;
+
+ tcb_ = std::make_unique<TransmissionControlBlock>(
+ timer_manager_, log_prefix_, options_, capabilities, callbacks_,
+ send_queue_, my_verification_tag, TSN(state.my_initial_tsn),
+ VerificationTag(state.peer_verification_tag),
+ TSN(state.peer_initial_tsn), static_cast<size_t>(0),
+ TieTag(state.tie_tag), packet_sender_,
+ [this]() { return state_ == State::kEstablished; }, &state);
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created peer TCB from state: "
+ << tcb_->ToString();
+
+ SetState(State::kEstablished, "restored from handover state");
+ callbacks_.OnConnected();
+ }
+ }
+
+ RTC_DCHECK(IsConsistent());
+ callbacks_.TriggerDeferred();
+}
+
void DcSctpSocket::Shutdown() {
if (tcb_ != nullptr) {
// https://tools.ietf.org/html/rfc4960#section-9.2
@@ -1579,4 +1664,38 @@
t2_shutdown_->Start();
}
+HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
+ HandoverReadinessStatus status;
+ if (state_ != State::kClosed && state_ != State::kEstablished) {
+ status.Add(HandoverUnreadinessReason::kWrongConnectionState);
+ }
+ if (!send_queue_.IsEmpty()) {
+ status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty);
+ }
+ if (tcb_) {
+ status.Add(tcb_->GetHandoverReadiness());
+ }
+ return status;
+}
+
+absl::optional<DcSctpSocketHandoverState>
+DcSctpSocket::GetHandoverStateAndClose() {
+ if (!GetHandoverReadiness().IsReady()) {
+ return absl::nullopt;
+ }
+
+ DcSctpSocketHandoverState state;
+
+ if (state_ == State::kClosed) {
+ state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed;
+ } else if (state_ == State::kEstablished) {
+ state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected;
+ tcb_->AddHandoverState(state);
+ InternalClose(ErrorKind::kNoError, "handover");
+ callbacks_.TriggerDeferred();
+ }
+
+ return std::move(state);
+}
+
} // namespace dcsctp
diff --git a/net/dcsctp/socket/dcsctp_socket.h b/net/dcsctp/socket/dcsctp_socket.h
index 60359bd..508a8a6 100644
--- a/net/dcsctp/socket/dcsctp_socket.h
+++ b/net/dcsctp/socket/dcsctp_socket.h
@@ -85,6 +85,7 @@
void ReceivePacket(rtc::ArrayView<const uint8_t> data) override;
void HandleTimeout(TimeoutID timeout_id) override;
void Connect() override;
+ void RestoreFromState(const DcSctpSocketHandoverState& state) override;
void Shutdown() override;
void Close() override;
SendStatus Send(DcSctpMessage message,
@@ -98,6 +99,8 @@
size_t buffered_amount_low_threshold(StreamID stream_id) const override;
void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
Metrics GetMetrics() const override;
+ HandoverReadinessStatus GetHandoverReadiness() const override;
+ absl::optional<DcSctpSocketHandoverState> GetHandoverStateAndClose() override;
// Returns this socket's verification tag, or zero if not yet connected.
VerificationTag verification_tag() const {
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index 5f99cc9..2fadde8 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -315,6 +315,24 @@
EXPECT_EQ(sock_z_->state(), SocketState::kConnected);
}
+ void HandoverSocketZ() {
+ ASSERT_EQ(sock_z_->GetHandoverReadiness(), HandoverReadinessStatus());
+ bool is_closed = sock_z_->state() == SocketState::kClosed;
+ if (!is_closed) {
+ EXPECT_CALL(cb_z_, OnClosed).Times(1);
+ }
+ absl::optional<DcSctpSocketHandoverState> handover_state =
+ sock_z_->GetHandoverStateAndClose();
+ EXPECT_TRUE(handover_state.has_value());
+ cb_z_.Reset();
+ sock_z_ = std::make_unique<DcSctpSocket>("Z", cb_z_, GetPacketObserver("Z"),
+ options_);
+ if (!is_closed) {
+ EXPECT_CALL(cb_z_, OnConnected).Times(1);
+ }
+ sock_z_->RestoreFromState(*handover_state);
+ }
+
const DcSctpOptions options_;
testing::NiceMock<MockDcSctpSocketCallbacks> cb_a_;
testing::NiceMock<MockDcSctpSocketCallbacks> cb_z_;
@@ -322,6 +340,52 @@
std::unique_ptr<DcSctpSocket> sock_z_;
};
+// Test parameter that controls whether to perform handovers during the test. A
+// test can have multiple points where it conditionally hands over socket Z.
+// Either socket Z will be handed over at all those points or handed over never.
+enum class HandoverMode {
+ kNoHandover,
+ kPerformHandovers,
+};
+
+class DcSctpSocketParametrizedTest
+ : public DcSctpSocketTest,
+ public ::testing::WithParamInterface<HandoverMode> {
+ protected:
+ // Trigger handover for socket Z depending on the current test param.
+ void MaybeHandoverSocketZ() {
+ if (GetParam() == HandoverMode::kPerformHandovers) {
+ HandoverSocketZ();
+ }
+ }
+ // Trigger handover for socket Z depending on the current test param.
+ // Then checks message passing to verify the handed over socket is functional.
+ void MaybeHandoverSocketZAndSendMessage() {
+ if (GetParam() == HandoverMode::kPerformHandovers) {
+ HandoverSocketZ();
+ }
+
+ ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+ sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ absl::optional<DcSctpMessage> msg = cb_z_.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ }
+};
+
+INSTANTIATE_TEST_SUITE_P(Handovers,
+ DcSctpSocketParametrizedTest,
+ testing::Values(HandoverMode::kNoHandover,
+ HandoverMode::kPerformHandovers),
+ [](const auto& test_info) {
+ return test_info.param ==
+ HandoverMode::kPerformHandovers
+ ? "WithHandovers"
+ : "NoHandover";
+ });
+
TEST_F(DcSctpSocketTest, EstablishConnection) {
EXPECT_CALL(cb_a_, OnConnected).Times(1);
EXPECT_CALL(cb_z_, OnConnected).Times(1);
@@ -566,8 +630,8 @@
TEST_F(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kLargeMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
sock_a_->Connect();
// Z reads INIT, produces INIT_ACK
@@ -623,11 +687,13 @@
SizeIs(kLargeMessageSize));
}
-TEST_F(DcSctpSocketTest, ShutdownConnection) {
+TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
ConnectSockets();
+ MaybeHandoverSocketZ();
RTC_LOG(LS_INFO) << "Shutting down";
+ EXPECT_CALL(cb_z_, OnClosed).Times(1);
sock_a_->Shutdown();
// Z reads SHUTDOWN, produces SHUTDOWN_ACK
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
@@ -638,6 +704,9 @@
EXPECT_EQ(sock_a_->state(), SocketState::kClosed);
EXPECT_EQ(sock_z_->state(), SocketState::kClosed);
+
+ MaybeHandoverSocketZ();
+ EXPECT_EQ(sock_z_->state(), SocketState::kClosed);
}
TEST_F(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
@@ -704,8 +773,9 @@
EXPECT_EQ(msg->stream_id(), StreamID(1));
}
-TEST_F(DcSctpSocketTest, TimeoutResendsPacket) {
+TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
ConnectSockets();
+ MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
cb_a_.ConsumeSentPacket();
@@ -719,10 +789,13 @@
absl::optional<DcSctpMessage> msg = cb_z_.ConsumeReceivedMessage();
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, SendALotOfBytesMissedSecondPacket) {
+TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
ConnectSockets();
+ MaybeHandoverSocketZ();
std::vector<uint8_t> payload(kLargeMessageSize);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
@@ -739,10 +812,13 @@
ASSERT_TRUE(msg.has_value());
EXPECT_EQ(msg->stream_id(), StreamID(1));
EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, SendingHeartbeatAnswersWithAck) {
+TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
ConnectSockets();
+ MaybeHandoverSocketZ();
// Inject a HEARTBEAT chunk
SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@@ -761,10 +837,13 @@
HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data));
ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info());
EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4));
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, ExpectHeartbeatToBeSent) {
+TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty());
@@ -786,11 +865,16 @@
// Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
sock_z_->ReceivePacket(hb_packet_raw);
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, CloseConnectionAfterTooManyLostHeartbeats) {
+TEST_P(DcSctpSocketParametrizedTest,
+ CloseConnectionAfterTooManyLostHeartbeats) {
ConnectSockets();
+ MaybeHandoverSocketZ();
+ EXPECT_CALL(cb_z_, OnClosed).Times(1);
EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty());
// Force-close socket Z so that it doesn't interfere from now on.
sock_z_->Close();
@@ -825,12 +909,16 @@
// Should suffice as exceeding RTO
AdvanceTime(DurationMs(1000));
RunTimers();
+
+ MaybeHandoverSocketZ();
}
-TEST_F(DcSctpSocketTest, RecoversAfterASuccessfulAck) {
+TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_THAT(cb_a_.ConsumeSentPacket(), testing::IsEmpty());
+ EXPECT_CALL(cb_z_, OnClosed).Times(1);
// Force-close socket Z so that it doesn't interfere from now on.
sock_z_->Close();
@@ -882,8 +970,9 @@
EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
}
-TEST_F(DcSctpSocketTest, ResetStream) {
+TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
ConnectSockets();
+ MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
@@ -906,10 +995,13 @@
// Receiving a response will trigger a callback. Streams are now reset.
EXPECT_CALL(cb_a_, OnStreamsResetPerformed).Times(1);
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
+TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
ConnectSockets();
+ MaybeHandoverSocketZ();
std::vector<uint8_t> payload(options_.mtu - 100);
@@ -956,10 +1048,14 @@
// Handle SACK
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, ResetStreamWillOnlyResetTheRequestedStreams) {
+TEST_P(DcSctpSocketParametrizedTest,
+ ResetStreamWillOnlyResetTheRequestedStreams) {
ConnectSockets();
+ MaybeHandoverSocketZ();
std::vector<uint8_t> payload(options_.mtu - 100);
@@ -1034,10 +1130,13 @@
// Handle SACK
sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket());
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, OnePeerReconnects) {
+TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnConnectionRestarted).Times(1);
// Let's be evil here - reconnect while a fragmented packet was about to be
@@ -1064,8 +1163,9 @@
EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
}
-TEST_F(DcSctpSocketTest, SendMessageWithLimitedRtx) {
+TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
ConnectSockets();
+ MaybeHandoverSocketZ();
SendOptions send_options;
send_options.max_retransmissions = 0;
@@ -1117,10 +1217,13 @@
absl::optional<DcSctpMessage> msg3 = cb_z_.ConsumeReceivedMessage();
EXPECT_FALSE(msg3.has_value());
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, SendManyFragmentedMessagesWithLimitedRtx) {
+TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
ConnectSockets();
+ MaybeHandoverSocketZ();
SendOptions send_options;
send_options.unordered = IsUnordered(true);
@@ -1210,8 +1313,9 @@
std::string ToString() const override { return "FAKE"; }
};
-TEST_F(DcSctpSocketTest, ReceivingUnknownChunkRespondsWithError) {
+TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
ConnectSockets();
+ MaybeHandoverSocketZ();
// Inject a FAKE chunk
SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@@ -1228,10 +1332,13 @@
UnrecognizedChunkTypeCause cause,
error.error_causes().get<UnrecognizedChunkTypeCause>());
EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04));
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, ReceivingErrorChunkReportsAsCallback) {
+TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
ConnectSockets();
+ MaybeHandoverSocketZ();
// Inject a ERROR chunk
SctpPacket::Builder b(sock_a_->verification_tag(), DcSctpOptions());
@@ -1243,6 +1350,8 @@
EXPECT_CALL(cb_a_, OnError(ErrorKind::kPeerReported,
HasSubstr("Unrecognized Chunk Type")));
sock_a_->ReceivePacket(b.Build());
+
+ MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
@@ -1359,8 +1468,9 @@
EXPECT_EQ(sock_a_->options().max_message_size, 42u);
}
-TEST_F(DcSctpSocketTest, SendsMessagesWithLowLifetime) {
+TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
ConnectSockets();
+ MaybeHandoverSocketZ();
// Mock that the time always goes forward.
TimeMs now(0);
@@ -1394,10 +1504,14 @@
// Validate that the sockets really make the time move forward.
EXPECT_GE(*now, kIterations * 2);
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, DiscardsMessagesWithLowLifetimeIfMustBuffer) {
+TEST_P(DcSctpSocketParametrizedTest,
+ DiscardsMessagesWithLowLifetimeIfMustBuffer) {
ConnectSockets();
+ MaybeHandoverSocketZ();
SendOptions lifetime_0;
lifetime_0.unordered = IsUnordered(true);
@@ -1449,53 +1563,65 @@
// But none of the smaller messages.
EXPECT_FALSE(cb_z_.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, HasReasonableBufferedAmountValues) {
+TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kSmallMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
// Sending a small message will directly send it as a single packet, so
// nothing is left in the queue.
EXPECT_EQ(sock_a_->buffered_amount(StreamID(1)), 0u);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kLargeMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
// Sending a message will directly start sending a few packets, so the
// buffered amount is not the full message size.
EXPECT_GT(sock_a_->buffered_amount(StreamID(1)), 0u);
EXPECT_LT(sock_a_->buffered_amount(StreamID(1)), kLargeMessageSize);
+
+ MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
EXPECT_EQ(sock_a_->buffered_amount_low_threshold(StreamID(1)), 0u);
}
-TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowWithDefaultValueZero) {
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnBufferedAmountLowWithDefaultValueZero) {
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1)));
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kSmallMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ EXPECT_CALL(cb_a_, OnBufferedAmountLow).WillRepeatedly(testing::Return());
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
+TEST_P(DcSctpSocketParametrizedTest,
+ DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
static constexpr size_t kMessageSize = 1000;
static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;
sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
- kBufferedAmountLowThreshold);
+ kBufferedAmountLowThreshold);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(0);
sock_a_->Send(
@@ -1507,16 +1633,19 @@
DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountMultipleTimes) {
+TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
static constexpr size_t kMessageSize = 1000;
static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;
sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
- kBufferedAmountLowThreshold);
+ kBufferedAmountLowThreshold);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(3);
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(2))).Times(2);
@@ -1544,16 +1673,20 @@
DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
static constexpr size_t kMessageSize = 1000;
static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5;
sock_a_->SetBufferedAmountLowThreshold(StreamID(1),
- kBufferedAmountLowThreshold);
+ kBufferedAmountLowThreshold);
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnBufferedAmountLow).Times(0);
@@ -1561,8 +1694,8 @@
// messages will start to be fully buffered.
while (sock_a_->buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
}
size_t initial_buffered = sock_a_->buffered_amount(StreamID(1));
ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);
@@ -1571,36 +1704,46 @@
// callback.
EXPECT_CALL(cb_a_, OnBufferedAmountLow(StreamID(1))).Times(1);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
+TEST_P(DcSctpSocketParametrizedTest,
+ DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0);
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kLargeMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
ConnectSockets();
+ MaybeHandoverSocketZ();
EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(0);
// Fill up the send queue completely.
for (;;) {
if (sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kLargeMessageSize)),
- kSendOptions) == SendStatus::kErrorResourceExhaustion) {
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions) == SendStatus::kErrorResourceExhaustion) {
break;
}
}
EXPECT_CALL(cb_a_, OnTotalBufferedAmountLow).Times(1);
ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+
+ MaybeHandoverSocketZAndSendMessage();
}
TEST_F(DcSctpSocketTest, InitialMetricsAreZeroed) {
@@ -1650,8 +1793,8 @@
// Send one more (large - fragmented), and receive the delayed SACK.
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(options_.mtu * 2 + 1)),
- kSendOptions);
+ std::vector<uint8_t>(options_.mtu * 2 + 1)),
+ kSendOptions);
EXPECT_EQ(sock_a_->GetMetrics().unack_data_count, 3u);
sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket()); // DATA
@@ -1683,12 +1826,13 @@
EXPECT_EQ(*sock_a_->GetMetrics().peer_rwnd_bytes, initial_a_rwnd);
}
-TEST_F(DcSctpSocketTest, UnackDataAlsoIncludesSendQueue) {
+TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
ConnectSockets();
+ MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kLargeMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
size_t payload_bytes =
options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
@@ -1706,14 +1850,17 @@
EXPECT_LE(sock_a_->GetMetrics().unack_data_count,
expected_sent_packets + expected_queued_packets + 2);
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, DoesntSendMoreThanMaxBurstPackets) {
+TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
ConnectSockets();
+ MaybeHandoverSocketZ();
sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53),
- std::vector<uint8_t>(kLargeMessageSize)),
- kSendOptions);
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
for (int i = 0; i < kMaxBurstPackets; ++i) {
std::vector<uint8_t> packet = cb_a_.ConsumeSentPacket();
@@ -1722,10 +1869,14 @@
}
EXPECT_THAT(cb_a_.ConsumeSentPacket(), IsEmpty());
+
+ ExchangeMessages(*sock_a_, cb_a_, *sock_z_, cb_z_);
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, SendsOnlyLargePackets) {
+TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
ConnectSockets();
+ MaybeHandoverSocketZ();
// A really large message, to ensure that the congestion window is often full.
constexpr size_t kMessageSize = 100000;
@@ -1765,10 +1916,13 @@
// The 4 is for padding/alignment.
EXPECT_GE(size, options_.mtu - 4);
}
+
+ MaybeHandoverSocketZAndSendMessage();
}
-TEST_F(DcSctpSocketTest, DoesntBundleForwardTsnWithData) {
+TEST_P(DcSctpSocketParametrizedTest, DoesntBundleForwardTsnWithData) {
ConnectSockets();
+ MaybeHandoverSocketZ();
// Force an RTT measurement using heartbeats.
AdvanceTime(options_.heartbeat_interval);
@@ -1848,5 +2002,49 @@
EXPECT_EQ(packet4.descriptors()[0].type, ForwardTsnChunk::kType);
}
+TEST_F(DcSctpSocketTest, SendMessagesAfterHandover) {
+ ConnectSockets();
+
+ // Send message before handover to move socket to a not initial state
+ sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
+ cb_z_.ConsumeReceivedMessage();
+
+ HandoverSocketZ();
+
+ absl::optional<DcSctpMessage> msg;
+
+ RTC_LOG(LS_INFO) << "Sending A #1";
+
+ sock_a_->Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
+ sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
+
+ msg = cb_z_.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4));
+
+ RTC_LOG(LS_INFO) << "Sending A #2";
+
+ sock_a_->Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
+ sock_z_->ReceivePacket(cb_a_.ConsumeSentPacket());
+
+ msg = cb_z_.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(2));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6));
+
+ RTC_LOG(LS_INFO) << "Sending Z #1";
+
+ sock_z_->Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
+ sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // ack
+ sock_a_->ReceivePacket(cb_z_.ConsumeSentPacket()); // data
+
+ msg = cb_a_.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3));
+}
+
} // namespace
} // namespace dcsctp
diff --git a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
index 894dd9a..a49a0b3 100644
--- a/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
+++ b/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
@@ -150,6 +150,12 @@
return timeout_manager_.GetNextExpiredTimeout();
}
+ void Reset() {
+ sent_packets_.clear();
+ received_messages_.clear();
+ timeout_manager_.Reset();
+ }
+
private:
const std::string log_prefix_;
TimeMs now_ = TimeMs(0);
diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc
index f0f1ab9..2e4e968 100644
--- a/net/dcsctp/socket/transmission_control_block.cc
+++ b/net/dcsctp/socket/transmission_control_block.cc
@@ -183,4 +183,30 @@
return sb.Release();
}
+HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const {
+ HandoverReadinessStatus status;
+ status.Add(data_tracker_.GetHandoverReadiness());
+ status.Add(stream_reset_handler_.GetHandoverReadiness());
+ status.Add(reassembly_queue_.GetHandoverReadiness());
+ status.Add(retransmission_queue_.GetHandoverReadiness());
+ return status;
+}
+
+void TransmissionControlBlock::AddHandoverState(
+ DcSctpSocketHandoverState& state) {
+ state.capabilities.partial_reliability = capabilities_.partial_reliability;
+ state.capabilities.message_interleaving = capabilities_.message_interleaving;
+ state.capabilities.reconfig = capabilities_.reconfig;
+
+ state.my_verification_tag = my_verification_tag().value();
+ state.peer_verification_tag = peer_verification_tag().value();
+ state.my_initial_tsn = my_initial_tsn().value();
+ state.peer_initial_tsn = peer_initial_tsn().value();
+ state.tie_tag = tie_tag().value();
+
+ data_tracker_.AddHandoverState(state);
+ stream_reset_handler_.AddHandoverState(state);
+ reassembly_queue_.AddHandoverState(state);
+ retransmission_queue_.AddHandoverState(state);
+}
} // namespace dcsctp
diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h
index c3766d1..6d9dfc5 100644
--- a/net/dcsctp/socket/transmission_control_block.h
+++ b/net/dcsctp/socket/transmission_control_block.h
@@ -44,20 +44,22 @@
// closed or restarted, this object will be deleted and/or replaced.
class TransmissionControlBlock : public Context {
public:
- TransmissionControlBlock(TimerManager& timer_manager,
- absl::string_view log_prefix,
- const DcSctpOptions& options,
- const Capabilities& capabilities,
- DcSctpSocketCallbacks& callbacks,
- SendQueue& send_queue,
- VerificationTag my_verification_tag,
- TSN my_initial_tsn,
- VerificationTag peer_verification_tag,
- TSN peer_initial_tsn,
- size_t a_rwnd,
- TieTag tie_tag,
- PacketSender& packet_sender,
- std::function<bool()> is_connection_established)
+ TransmissionControlBlock(
+ TimerManager& timer_manager,
+ absl::string_view log_prefix,
+ const DcSctpOptions& options,
+ const Capabilities& capabilities,
+ DcSctpSocketCallbacks& callbacks,
+ SendQueue& send_queue,
+ VerificationTag my_verification_tag,
+ TSN my_initial_tsn,
+ VerificationTag peer_verification_tag,
+ TSN peer_initial_tsn,
+ size_t a_rwnd,
+ TieTag tie_tag,
+ PacketSender& packet_sender,
+ std::function<bool()> is_connection_established,
+ const DcSctpSocketHandoverState* handover_state = nullptr)
: log_prefix_(log_prefix),
options_(options),
timer_manager_(timer_manager),
@@ -86,10 +88,14 @@
packet_sender_(packet_sender),
rto_(options),
tx_error_counter_(log_prefix, options),
- data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn),
+ data_tracker_(log_prefix,
+ delayed_ack_timer_.get(),
+ peer_initial_tsn,
+ handover_state),
reassembly_queue_(log_prefix,
peer_initial_tsn,
- options.max_receiver_window_buffer_size),
+ options.max_receiver_window_buffer_size,
+ handover_state),
retransmission_queue_(
log_prefix,
my_initial_tsn,
@@ -100,13 +106,15 @@
*t3_rtx_,
options,
capabilities.partial_reliability,
- capabilities.message_interleaving),
+ capabilities.message_interleaving,
+ handover_state),
stream_reset_handler_(log_prefix,
this,
&timer_manager,
&data_tracker_,
&reassembly_queue_,
- &retransmission_queue_),
+ &retransmission_queue_,
+ handover_state),
heartbeat_handler_(log_prefix, options, this, &timer_manager_) {}
// Implementation of `Context`.
@@ -188,6 +196,10 @@
// Returns a textual representation of this object, for logging.
std::string ToString() const;
+ HandoverReadinessStatus GetHandoverReadiness() const;
+
+ void AddHandoverState(DcSctpSocketHandoverState& state);
+
private:
// Will be called when the retransmission timer (t3-rtx) expires.
absl::optional<DurationMs> OnRtxTimerExpiry();
diff --git a/net/dcsctp/timer/fake_timeout.h b/net/dcsctp/timer/fake_timeout.h
index f2bf103..e8f50d9 100644
--- a/net/dcsctp/timer/fake_timeout.h
+++ b/net/dcsctp/timer/fake_timeout.h
@@ -91,6 +91,8 @@
return absl::nullopt;
}
+ void Reset() { timers_.clear(); }
+
private:
const std::function<TimeMs()> get_time_;
webrtc::flat_set<FakeTimeout*> timers_;