| /* |
| * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. |
| * |
| * Use of this source code is governed by a BSD-style license |
| * that can be found in the LICENSE file in the root of the source |
| * tree. An additional intellectual property rights grant can be found |
| * in the file PATENTS. All contributing project authors may |
| * be found in the AUTHORS file in the root of the source tree. |
| */ |
#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h"

#include <string>
#include <utility>
#include <vector>

#include "net/dcsctp/common/math.h"
#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
#include "net/dcsctp/packet/chunk/data_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h"
#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
#include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/socket/dcsctp_socket.h"
#include "net/dcsctp/socket/state_cookie.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/logging.h"
| |
| namespace dcsctp { |
| namespace dcsctp_fuzzers { |
| namespace { |
| static constexpr int kRandomValue = FuzzerCallbacks::kRandomValue; |
| static constexpr size_t kMinInputLength = 5; |
| static constexpr size_t kMaxInputLength = 1024; |
| |
// The state the fuzzed socket is put into before any fuzz-driven operation is
// applied. FuzzSocket picks the state from the first input byte and
// SetSocketState performs the packet exchange that leads to it.
enum class StartingState : int {
  // Nothing has been done with the socket.
  kConnectNotCalled = 0,

  // The fuzzed socket is the one initiating Connect.
  kConnectCalled,
  kReceivedInitAck,
  kReceivedCookieAck,

  // The fuzzed socket is the one initiating Shutdown.
  kShutdownCalled,
  kReceivedShutdownAck,

  // The peer socket initiated Connect.
  kReceivedInit,
  kReceivedCookieEcho,

  // The peer socket initiated Shutdown.
  kReceivedShutdown,
  kReceivedShutdownComplete,

  // Sentinel holding the number of states above; never a valid state.
  kNumberOfStates,
};
| |
| // State about the current fuzzing iteration |
| class FuzzState { |
| public: |
| explicit FuzzState(rtc::ArrayView<const uint8_t> data) : data_(data) {} |
| |
| uint8_t GetByte() { |
| uint8_t value = 0; |
| if (offset_ < data_.size()) { |
| value = data_[offset_]; |
| ++offset_; |
| } |
| return value; |
| } |
| |
| TSN GetNextTSN() { return TSN(tsn_++); } |
| MID GetNextMID() { return MID(mid_++); } |
| |
| bool empty() const { return offset_ >= data_.size(); } |
| |
| private: |
| uint32_t tsn_ = kRandomValue; |
| uint32_t mid_ = 0; |
| rtc::ArrayView<const uint8_t> data_; |
| size_t offset_ = 0; |
| }; |
| |
| void SetSocketState(DcSctpSocketInterface& socket, |
| FuzzerCallbacks& socket_cb, |
| StartingState state) { |
| // We'll use another temporary peer socket for the establishment. |
| FuzzerCallbacks peer_cb; |
| DcSctpSocket peer("peer", peer_cb, nullptr, {}); |
| |
| switch (state) { |
| case StartingState::kConnectNotCalled: |
| return; |
| case StartingState::kConnectCalled: |
| socket.Connect(); |
| return; |
| case StartingState::kReceivedInitAck: |
| socket.Connect(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK |
| return; |
| case StartingState::kReceivedCookieAck: |
| socket.Connect(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK |
| return; |
| case StartingState::kShutdownCalled: |
| socket.Connect(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK |
| socket.Shutdown(); |
| return; |
| case StartingState::kReceivedShutdownAck: |
| socket.Connect(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK |
| socket.Shutdown(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_ACK |
| return; |
| case StartingState::kReceivedInit: |
| peer.Connect(); |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT |
| return; |
| case StartingState::kReceivedCookieEcho: |
| peer.Connect(); |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT_ACK |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ECHO |
| return; |
| case StartingState::kReceivedShutdown: |
| socket.Connect(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK |
| peer.Shutdown(); |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN |
| return; |
| case StartingState::kReceivedShutdownComplete: |
| socket.Connect(); |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK |
| peer.Shutdown(); |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN |
| peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN_ACK |
| socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_COMPLETE |
| return; |
| case StartingState::kNumberOfStates: |
| RTC_CHECK(false); |
| return; |
| } |
| } |
| |
| void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) { |
| DataChunk::Options options; |
| options.is_unordered = IsUnordered(state.GetByte() != 0); |
| options.is_beginning = Data::IsBeginning(state.GetByte() != 0); |
| options.is_end = Data::IsEnd(state.GetByte() != 0); |
| rtc::CopyOnWriteBuffer payload(10); |
| b.Add(DataChunk(state.GetNextTSN(), StreamID(state.GetByte()), |
| SSN(state.GetByte()), PPID(53), payload, options)); |
| } |
| |
| void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) { |
| Parameters::Builder builder; |
| builder.Add(ForwardTsnSupportedParameter()); |
| |
| b.Add(InitChunk(VerificationTag(kRandomValue), 10000, 1000, 1000, |
| TSN(kRandomValue), builder.Build())); |
| } |
| |
| void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) { |
| Parameters::Builder builder; |
| builder.Add(ForwardTsnSupportedParameter()); |
| |
| uint8_t state_cookie[] = {1, 2, 3, 4, 5}; |
| Parameters::Builder params_builder = |
| Parameters::Builder().Add(StateCookieParameter(state_cookie)); |
| |
| b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000, |
| TSN(kRandomValue), builder.Build())); |
| } |
| |
| void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) { |
| std::vector<SackChunk::GapAckBlock> gap_ack_blocks; |
| uint16_t last_end = 0; |
| while (gap_ack_blocks.size() < 20) { |
| uint8_t delta_start = state.GetByte(); |
| if (delta_start < 0x80) { |
| break; |
| } |
| uint8_t delta_end = state.GetByte(); |
| |
| uint16_t start = last_end + delta_start; |
| uint16_t end = start + delta_end; |
| last_end = end; |
| gap_ack_blocks.emplace_back(start, end); |
| } |
| |
| TSN cum_ack_tsn(kRandomValue + state.GetByte()); |
| b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gap_ack_blocks), {})); |
| } |
| |
| void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) { |
| uint8_t info[] = {1, 2, 3, 4, 5}; |
| b.Add(HeartbeatRequestChunk( |
| Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build())); |
| } |
| |
| void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) { |
| std::vector<uint8_t> info(8); |
| b.Add(HeartbeatRequestChunk( |
| Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build())); |
| } |
| |
| void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(AbortChunk( |
| /*filled_in_verification_tag=*/true, |
| Parameters::Builder().Add(UserInitiatedAbortCause("Fuzzing")).Build())); |
| } |
| |
| void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(ErrorChunk( |
| Parameters::Builder().Add(ProtocolViolationCause("Fuzzing")).Build())); |
| } |
| |
| void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) { |
| std::vector<uint8_t> cookie(StateCookie::kCookieSize); |
| b.Add(CookieEchoChunk(cookie)); |
| } |
| |
| void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(CookieAckChunk()); |
| } |
| |
| void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(ShutdownChunk(state.GetNextTSN())); |
| } |
| |
| void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(ShutdownAckChunk()); |
| } |
| |
| void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(ShutdownCompleteChunk(false)); |
| } |
| |
| void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) { |
| std::vector<StreamID> streams = {StreamID(state.GetByte())}; |
| Parameters::Builder params_builder = |
| Parameters::Builder().Add(OutgoingSSNResetRequestParameter( |
| ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue), |
| state.GetNextTSN(), streams)); |
| b.Add(ReConfigChunk(params_builder.Build())); |
| } |
| |
| void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) { |
| std::vector<ForwardTsnChunk::SkippedStream> skipped_streams; |
| for (;;) { |
| uint8_t stream = state.GetByte(); |
| if (skipped_streams.size() > 20 || stream < 0x80) { |
| break; |
| } |
| skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte())); |
| } |
| b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams))); |
| } |
| |
| void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) { |
| DataChunk::Options options; |
| options.is_unordered = IsUnordered(state.GetByte() != 0); |
| options.is_beginning = Data::IsBeginning(state.GetByte() != 0); |
| options.is_end = Data::IsEnd(state.GetByte() != 0); |
| b.Add(IDataChunk(state.GetNextTSN(), StreamID(state.GetByte()), |
| state.GetNextMID(), PPID(53), FSN(0), |
| rtc::CopyOnWriteBuffer(10), options)); |
| } |
| |
| void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) { |
| std::vector<ForwardTsnChunk::SkippedStream> skipped_streams; |
| for (;;) { |
| uint8_t stream = state.GetByte(); |
| if (skipped_streams.size() > 20 || stream < 0x80) { |
| break; |
| } |
| skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte())); |
| } |
| b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams))); |
| } |
| |
| class RandomFuzzedChunk : public Chunk { |
| public: |
| explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {} |
| |
| void SerializeTo(std::vector<uint8_t>& out) const override { |
| size_t bytes = state_.GetByte(); |
| for (size_t i = 0; i < bytes; ++i) { |
| out.push_back(state_.GetByte()); |
| } |
| } |
| |
| std::string ToString() const override { return std::string("RANDOM_FUZZED"); } |
| |
| private: |
| FuzzState& state_; |
| }; |
| |
| void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) { |
| b.Add(RandomFuzzedChunk(state)); |
| } |
| |
| std::vector<uint8_t> GeneratePacket(FuzzState& state) { |
| DcSctpOptions options; |
| // Setting a fixed limit to not be dependent on the defaults, which may |
| // change. |
| options.mtu = 2048; |
| SctpPacket::Builder builder(VerificationTag(kRandomValue), options); |
| |
| // The largest expected serialized chunk, as created by fuzzers. |
| static constexpr size_t kMaxChunkSize = 256; |
| |
| for (int i = 0; i < 5 && builder.bytes_remaining() > kMaxChunkSize; ++i) { |
| switch (state.GetByte()) { |
| case 1: |
| MakeDataChunk(state, builder); |
| break; |
| case 2: |
| MakeInitChunk(state, builder); |
| break; |
| case 3: |
| MakeInitAckChunk(state, builder); |
| break; |
| case 4: |
| MakeSackChunk(state, builder); |
| break; |
| case 5: |
| MakeHeartbeatRequestChunk(state, builder); |
| break; |
| case 6: |
| MakeHeartbeatAckChunk(state, builder); |
| break; |
| case 7: |
| MakeAbortChunk(state, builder); |
| break; |
| case 8: |
| MakeErrorChunk(state, builder); |
| break; |
| case 9: |
| MakeCookieEchoChunk(state, builder); |
| break; |
| case 10: |
| MakeCookieAckChunk(state, builder); |
| break; |
| case 11: |
| MakeShutdownChunk(state, builder); |
| break; |
| case 12: |
| MakeShutdownAckChunk(state, builder); |
| break; |
| case 13: |
| MakeShutdownCompleteChunk(state, builder); |
| break; |
| case 14: |
| MakeReConfigChunk(state, builder); |
| break; |
| case 15: |
| MakeForwardTsnChunk(state, builder); |
| break; |
| case 16: |
| MakeIDataChunk(state, builder); |
| break; |
| case 17: |
| MakeIForwardTsnChunk(state, builder); |
| break; |
| case 18: |
| MakeChunkWithRandomContent(state, builder); |
| break; |
| default: |
| break; |
| } |
| } |
| std::vector<uint8_t> packet = builder.Build(); |
| return packet; |
| } |
| } // namespace |
| |
| void FuzzSocket(DcSctpSocketInterface& socket, |
| FuzzerCallbacks& cb, |
| rtc::ArrayView<const uint8_t> data) { |
| if (data.size() < kMinInputLength || data.size() > kMaxInputLength) { |
| return; |
| } |
| if (data[0] >= static_cast<int>(StartingState::kNumberOfStates)) { |
| return; |
| } |
| |
| // Set the socket in a specified valid starting state |
| SetSocketState(socket, cb, static_cast<StartingState>(data[0])); |
| |
| FuzzState state(data.subview(1)); |
| |
| while (!state.empty()) { |
| switch (state.GetByte()) { |
| case 1: |
| // Generate a valid SCTP packet (based on fuzz data) and "receive it". |
| socket.ReceivePacket(GeneratePacket(state)); |
| break; |
| case 2: |
| socket.Connect(); |
| break; |
| case 3: |
| socket.Shutdown(); |
| break; |
| case 4: |
| socket.Close(); |
| break; |
| case 5: { |
| StreamID streams[] = {StreamID(state.GetByte())}; |
| socket.ResetStreams(streams); |
| } break; |
| case 6: { |
| uint8_t flags = state.GetByte(); |
| SendOptions options; |
| options.unordered = IsUnordered(flags & 0x01); |
| options.max_retransmissions = |
| (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt; |
| size_t payload_exponent = (flags >> 2) % 16; |
| size_t payload_size = static_cast<size_t>(1) << payload_exponent; |
| socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53), |
| std::vector<uint8_t>(payload_size)), |
| options); |
| break; |
| } |
| case 7: { |
| // Expire an active timeout/timer. |
| uint8_t timeout_idx = state.GetByte(); |
| absl::optional<TimeoutID> timeout_id = cb.ExpireTimeout(timeout_idx); |
| if (timeout_id.has_value()) { |
| socket.HandleTimeout(*timeout_id); |
| } |
| break; |
| } |
| default: |
| break; |
| } |
| } |
| } |
| } // namespace dcsctp_fuzzers |
| } // namespace dcsctp |