blob: dbb9f7b5ca12a429ebcd58ca774864d91ad6f0bc [file] [log] [blame]
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h"
#include <string>
#include <utility>
#include <vector>
#include "net/dcsctp/common/math.h"
#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
#include "net/dcsctp/packet/chunk/data_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h"
#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
#include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/socket/dcsctp_socket.h"
#include "net/dcsctp/socket/state_cookie.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/logging.h"
namespace dcsctp {
namespace dcsctp_fuzzers {
namespace {
// Fixed "random" seed value shared with the fuzzer callbacks; used for
// verification tags, initial TSNs and request sequence numbers below.
constexpr int kRandomValue = FuzzerCallbacks::kRandomValue;
// FuzzSocket ignores inputs whose total length is outside
// [kMinInputLength, kMaxInputLength].
constexpr size_t kMinInputLength = 5;
constexpr size_t kMaxInputLength = 1024;
// The state the socket under test is driven into before the remaining fuzz
// input is replayed. FuzzSocket picks it from the first input byte, so the
// numeric values are part of the corpus format and are spelled out here.
enum class StartingState : int {
  kConnectNotCalled = 0,
  // The socket under test initiates the association.
  kConnectCalled = 1,
  kReceivedInitAck = 2,
  kReceivedCookieAck = 3,
  // The socket under test initiates the shutdown.
  kShutdownCalled = 4,
  kReceivedShutdownAck = 5,
  // The peer initiates the association.
  kReceivedInit = 6,
  kReceivedCookieEcho = 7,
  // The peer initiates the shutdown.
  kReceivedShutdown = 8,
  kReceivedShutdownComplete = 9,
  // Count of valid states; not itself a valid starting state.
  kNumberOfStates = 10,
};
// Cursor over the fuzz input for one iteration, plus monotonically increasing
// TSN/MID counters used when synthesizing chunks.
class FuzzState {
 public:
  explicit FuzzState(rtc::ArrayView<const uint8_t> data) : data_(data) {}

  // Returns the next unread input byte and advances past it. Once the input
  // is exhausted this keeps returning 0 without advancing.
  uint8_t GetByte() {
    if (empty()) {
      return 0;
    }
    const uint8_t next = data_[offset_];
    offset_ += 1;
    return next;
  }

  // Each call returns the current counter value, then increments it.
  TSN GetNextTSN() { return TSN(tsn_++); }
  MID GetNextMID() { return MID(mid_++); }

  // True when every input byte has been consumed.
  bool empty() const { return !(offset_ < data_.size()); }

 private:
  uint32_t tsn_ = kRandomValue;
  uint32_t mid_ = 0;
  rtc::ArrayView<const uint8_t> data_;
  size_t offset_ = 0;
};
// Drives `socket` into `state` by exchanging packets with a temporary peer
// socket that is created here and destroyed when this function returns.
// Packets are moved by consuming them from the respective callbacks.
void SetSocketState(DcSctpSocketInterface& socket,
                    FuzzerCallbacks& socket_cb,
                    StartingState state) {
  FuzzerCallbacks peer_cb;
  DcSctpSocket peer("peer", peer_cb, nullptr, {});

  // Hands the next packet sent by `socket` to the peer.
  auto socket_to_peer = [&] {
    peer.ReceivePacket(socket_cb.ConsumeSentPacket());
  };
  // Hands the next packet sent by the peer to `socket`.
  auto peer_to_socket = [&] {
    socket.ReceivePacket(peer_cb.ConsumeSentPacket());
  };
  // Full association setup initiated by `socket`.
  auto establish_from_socket = [&] {
    socket.Connect();
    socket_to_peer();  // INIT
    peer_to_socket();  // INIT_ACK
    socket_to_peer();  // COOKIE_ECHO
    peer_to_socket();  // COOKIE_ACK
  };

  switch (state) {
    case StartingState::kConnectNotCalled:
      return;

    case StartingState::kConnectCalled:
      socket.Connect();
      return;

    case StartingState::kReceivedInitAck:
      socket.Connect();
      socket_to_peer();  // INIT
      peer_to_socket();  // INIT_ACK
      return;

    case StartingState::kReceivedCookieAck:
      establish_from_socket();
      return;

    case StartingState::kShutdownCalled:
      establish_from_socket();
      socket.Shutdown();
      return;

    case StartingState::kReceivedShutdownAck:
      establish_from_socket();
      socket.Shutdown();
      socket_to_peer();  // SHUTDOWN
      peer_to_socket();  // SHUTDOWN_ACK
      return;

    case StartingState::kReceivedInit:
      peer.Connect();
      peer_to_socket();  // INIT
      return;

    case StartingState::kReceivedCookieEcho:
      peer.Connect();
      peer_to_socket();  // INIT
      socket_to_peer();  // INIT_ACK
      peer_to_socket();  // COOKIE_ECHO
      return;

    case StartingState::kReceivedShutdown:
      establish_from_socket();
      peer.Shutdown();
      peer_to_socket();  // SHUTDOWN
      return;

    case StartingState::kReceivedShutdownComplete:
      establish_from_socket();
      peer.Shutdown();
      peer_to_socket();  // SHUTDOWN
      socket_to_peer();  // SHUTDOWN_ACK
      peer_to_socket();  // SHUTDOWN_COMPLETE
      return;

    case StartingState::kNumberOfStates:
      RTC_CHECK(false);
      return;
  }
}
// Appends a DATA chunk whose flags, stream id and SSN come from fuzz bytes.
// The chunk uses the next TSN from `state`, PPID 53 and a 10-byte payload.
void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) {
  DataChunk::Options options;
  options.is_unordered = IsUnordered(state.GetByte() != 0);
  options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
  options.is_end = Data::IsEnd(state.GetByte() != 0);
  // Read the stream id and SSN in separate statements. Function arguments
  // are evaluated in an unspecified order, so calling GetByte() twice inside
  // one constructor call made the byte->field mapping compiler-dependent,
  // which breaks corpus reproducibility across toolchains.
  const StreamID stream_id(state.GetByte());
  const SSN ssn(state.GetByte());
  rtc::CopyOnWriteBuffer payload(10);
  b.Add(DataChunk(state.GetNextTSN(), stream_id, ssn, PPID(53), payload,
                  options));
}
// Appends an INIT chunk built entirely from fixed values (no fuzz bytes are
// consumed), advertising FORWARD-TSN support.
void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) {
  Parameters::Builder init_params;
  init_params.Add(ForwardTsnSupportedParameter());
  const InitChunk init(VerificationTag(kRandomValue), 10000, 1000, 1000,
                       TSN(kRandomValue), init_params.Build());
  b.Add(init);
}
// Appends an INIT_ACK chunk built from fixed values (no fuzz bytes are
// consumed), advertising FORWARD-TSN support and carrying a small dummy
// State Cookie.
void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) {
  Parameters::Builder builder;
  builder.Add(ForwardTsnSupportedParameter());
  // The State Cookie is a mandatory INIT ACK parameter. It was previously
  // put into a separate builder that was never used, so every generated
  // INIT_ACK lacked a cookie.
  uint8_t state_cookie[] = {1, 2, 3, 4, 5};
  builder.Add(StateCookieParameter(state_cookie));
  b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
                     TSN(kRandomValue), builder.Build()));
}
// Appends a SACK chunk with up to 20 fuzz-driven gap ack blocks.
// Each block is introduced by a start-offset byte >= 0x80 (a smaller byte
// terminates the list), followed by a length byte. Offsets are relative to
// the previous block's end and wrap as uint16_t. A final byte offsets the
// cumulative TSN ack from kRandomValue.
void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) {
  std::vector<SackChunk::GapAckBlock> gaps;
  uint16_t previous_end = 0;
  while (gaps.size() < 20) {
    const uint8_t start_offset = state.GetByte();
    if (start_offset < 0x80) {
      break;
    }
    const uint8_t length = state.GetByte();
    const uint16_t block_start = static_cast<uint16_t>(previous_end + start_offset);
    previous_end = static_cast<uint16_t>(block_start + length);
    gaps.emplace_back(block_start, previous_end);
  }
  const uint8_t cum_ack_offset = state.GetByte();
  TSN cum_ack_tsn(kRandomValue + cum_ack_offset);
  b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gaps), {}));
}
// Appends a HEARTBEAT chunk carrying a fixed 5-byte heartbeat info blob.
void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) {
  const uint8_t heartbeat_info[] = {1, 2, 3, 4, 5};
  Parameters::Builder params;
  params.Add(HeartbeatInfoParameter(heartbeat_info));
  b.Add(HeartbeatRequestChunk(params.Build()));
}
// Appends a HEARTBEAT-ACK chunk carrying an 8-byte zero-filled heartbeat info
// blob. No fuzz bytes are consumed.
void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) {
  std::vector<uint8_t> info(8);
  // This must be a HEARTBEAT-ACK. Emitting a HEARTBEAT request here (as
  // before) meant the ack-handling path was never reached by this selector.
  b.Add(HeartbeatAckChunk(
      Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
}
// Appends an ABORT chunk with a "User-Initiated Abort" error cause.
void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) {
  Parameters::Builder causes;
  causes.Add(UserInitiatedAbortCause("Fuzzing"));
  b.Add(AbortChunk(/*filled_in_verification_tag=*/true, causes.Build()));
}
// Appends an ERROR chunk with a "Protocol Violation" error cause.
void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) {
  Parameters::Builder causes;
  causes.Add(ProtocolViolationCause("Fuzzing"));
  b.Add(ErrorChunk(causes.Build()));
}
// Appends a COOKIE_ECHO chunk whose cookie is StateCookie::kCookieSize
// zero bytes.
void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) {
  const std::vector<uint8_t> zeroed_cookie(StateCookie::kCookieSize, 0);
  b.Add(CookieEchoChunk(zeroed_cookie));
}
// Appends an (empty) COOKIE_ACK chunk.
void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) {
  CookieAckChunk cookie_ack;
  b.Add(cookie_ack);
}
// Appends a SHUTDOWN chunk using the next TSN from `state`.
void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) {
  const TSN next_tsn = state.GetNextTSN();
  b.Add(ShutdownChunk(next_tsn));
}
// Appends an (empty) SHUTDOWN_ACK chunk.
void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) {
  ShutdownAckChunk shutdown_ack;
  b.Add(shutdown_ack);
}
// Appends a SHUTDOWN_COMPLETE chunk constructed with `false`.
void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) {
  const bool flag = false;
  ShutdownCompleteChunk shutdown_complete(flag);
  b.Add(shutdown_complete);
}
// Appends a RE-CONFIG chunk holding one Outgoing SSN Reset Request for a
// single fuzz-chosen stream. It uses fixed request sequence numbers and the
// next TSN from `state`.
void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) {
  const StreamID stream_to_reset(state.GetByte());
  const std::vector<StreamID> streams = {stream_to_reset};
  const TSN last_tsn = state.GetNextTSN();

  Parameters::Builder reset_params;
  reset_params.Add(OutgoingSSNResetRequestParameter(
      ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue),
      last_tsn, streams));
  b.Add(ReConfigChunk(reset_params.Build()));
}
// Appends a FORWARD-TSN chunk. Skipped streams are read as (stream byte, SSN
// byte) pairs for as long as the stream byte is >= 0x80 and at most 20
// entries exist. The terminating stream byte is still consumed. The new
// cumulative TSN is the next TSN from `state`.
void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
  std::vector<ForwardTsnChunk::SkippedStream> skipped;
  uint8_t stream_byte = state.GetByte();
  while (stream_byte >= 0x80 && skipped.size() <= 20) {
    const uint8_t ssn_byte = state.GetByte();
    skipped.emplace_back(StreamID(stream_byte), SSN(ssn_byte));
    stream_byte = state.GetByte();
  }
  b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped)));
}
// Appends an I-DATA chunk. Its flags and stream id come from fuzz bytes, its
// TSN/MID from the counters in `state`, and it uses PPID 53, FSN 0 and a
// 10-byte payload.
void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) {
  DataChunk::Options flags;
  flags.is_unordered = IsUnordered(state.GetByte() != 0);
  flags.is_beginning = Data::IsBeginning(state.GetByte() != 0);
  flags.is_end = Data::IsEnd(state.GetByte() != 0);

  const TSN tsn = state.GetNextTSN();
  const StreamID stream_id(state.GetByte());
  const MID message_id = state.GetNextMID();
  b.Add(IDataChunk(tsn, stream_id, message_id, PPID(53), FSN(0),
                   rtc::CopyOnWriteBuffer(10), flags));
}
// Appends an I-FORWARD-TSN chunk. Each skipped stream is introduced by a
// stream byte >= 0x80 (a smaller byte ends the list; at most 21 entries),
// followed by an unordered-flag byte and a MID byte. The new cumulative TSN
// is the next TSN from `state`.
void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
  std::vector<IForwardTsnChunk::SkippedStream> skipped_streams;
  for (;;) {
    uint8_t stream = state.GetByte();
    if (skipped_streams.size() > 20 || stream < 0x80) {
      break;
    }
    // I-FORWARD-TSN entries carry (stream, U flag, MID) and no SSN. Building
    // them with the SSN constructor silently dropped the fuzz byte and
    // pinned MID/U to 0/false, so use the MID constructor instead. The reads
    // are separate statements to keep their order well-defined.
    const bool unordered = state.GetByte() != 0;
    const uint8_t mid = state.GetByte();
    skipped_streams.emplace_back(IsUnordered(unordered), StreamID(stream),
                                 MID(mid));
  }
  b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
}
// A pseudo-chunk whose serialized bytes are raw fuzz input: one length byte
// followed by that many bytes copied verbatim. The input is read lazily from
// the referenced FuzzState at serialization time, so `state` must outlive
// the call to SerializeTo.
class RandomFuzzedChunk : public Chunk {
 public:
  explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {}

  void SerializeTo(std::vector<uint8_t>& out) const override {
    for (size_t remaining = state_.GetByte(); remaining > 0; --remaining) {
      out.push_back(state_.GetByte());
    }
  }

  std::string ToString() const override { return "RANDOM_FUZZED"; }

 private:
  FuzzState& state_;
};
// Appends a chunk made of raw fuzz bytes (see RandomFuzzedChunk).
void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) {
  RandomFuzzedChunk raw_chunk(state);
  b.Add(raw_chunk);
}
// Builds one serialized SCTP packet from fuzz input. Up to five chunks are
// added. Each is chosen by a selector byte (1..18); any other selector value
// adds nothing but still uses up one of the five slots. Generation stops
// early once the packet no longer has room for the largest expected chunk.
std::vector<uint8_t> GeneratePacket(FuzzState& state) {
  using ChunkGenerator = void (*)(FuzzState&, SctpPacket::Builder&);
  // Indexed by selector byte; slot 0 is intentionally empty.
  static constexpr ChunkGenerator kGenerators[] = {
      nullptr,
      &MakeDataChunk,
      &MakeInitChunk,
      &MakeInitAckChunk,
      &MakeSackChunk,
      &MakeHeartbeatRequestChunk,
      &MakeHeartbeatAckChunk,
      &MakeAbortChunk,
      &MakeErrorChunk,
      &MakeCookieEchoChunk,
      &MakeCookieAckChunk,
      &MakeShutdownChunk,
      &MakeShutdownAckChunk,
      &MakeShutdownCompleteChunk,
      &MakeReConfigChunk,
      &MakeForwardTsnChunk,
      &MakeIDataChunk,
      &MakeIForwardTsnChunk,
      &MakeChunkWithRandomContent,
  };
  static constexpr size_t kGeneratorCount =
      sizeof(kGenerators) / sizeof(kGenerators[0]);

  DcSctpOptions options;
  // Fixed MTU so packet capacity doesn't depend on library defaults.
  options.mtu = 2048;
  SctpPacket::Builder builder(VerificationTag(kRandomValue), options);

  // The largest expected serialized chunk, as created by fuzzers.
  static constexpr size_t kMaxChunkSize = 256;
  for (int slot = 0; slot < 5; ++slot) {
    if (builder.bytes_remaining() <= kMaxChunkSize) {
      break;
    }
    const uint8_t selector = state.GetByte();
    if (selector < kGeneratorCount && kGenerators[selector] != nullptr) {
      kGenerators[selector](state, builder);
    }
  }
  return builder.Build();
}
} // namespace
// Runs one fuzzing iteration against `socket`.
// The first input byte selects a StartingState that the socket is driven into
// first. Each remaining byte is an opcode, possibly followed by operand bytes,
// that makes the socket receive a generated packet, call an API method, or
// fire a timeout. Inputs of unsupported length, or with an out-of-range first
// byte, are ignored.
void FuzzSocket(DcSctpSocketInterface& socket,
                FuzzerCallbacks& cb,
                rtc::ArrayView<const uint8_t> data) {
  const bool length_ok =
      data.size() >= kMinInputLength && data.size() <= kMaxInputLength;
  if (!length_ok ||
      data[0] >= static_cast<int>(StartingState::kNumberOfStates)) {
    return;
  }

  SetSocketState(socket, cb, static_cast<StartingState>(data[0]));

  FuzzState state(data.subview(1));
  while (!state.empty()) {
    const uint8_t opcode = state.GetByte();
    switch (opcode) {
      case 1:
        // Synthesize a well-formed SCTP packet from fuzz data and deliver it.
        socket.ReceivePacket(GeneratePacket(state));
        break;
      case 2:
        socket.Connect();
        break;
      case 3:
        socket.Shutdown();
        break;
      case 4:
        socket.Close();
        break;
      case 5: {
        StreamID streams_to_reset[] = {StreamID(state.GetByte())};
        socket.ResetStreams(streams_to_reset);
        break;
      }
      case 6: {
        // Flags: bit 0 = unordered, bit 1 = no retransmissions,
        // bits 2..5 = log2 of the payload size (1 .. 32768 bytes).
        const uint8_t flags = state.GetByte();
        SendOptions send_options;
        send_options.unordered = IsUnordered(flags & 0x01);
        if ((flags & 0x02) != 0) {
          send_options.max_retransmissions = absl::make_optional(0);
        } else {
          send_options.max_retransmissions = absl::nullopt;
        }
        const size_t payload_size = static_cast<size_t>(1)
                                    << ((flags >> 2) % 16);
        const StreamID stream_id(state.GetByte());
        socket.Send(DcSctpMessage(stream_id, PPID(53),
                                  std::vector<uint8_t>(payload_size)),
                    send_options);
        break;
      }
      case 7: {
        // Expire an active timeout/timer picked by the next byte.
        const uint8_t which = state.GetByte();
        const absl::optional<TimeoutID> expired = cb.ExpireTimeout(which);
        if (expired.has_value()) {
          socket.HandleTimeout(*expired);
        }
        break;
      }
      default:
        break;
    }
  }
}
} // namespace dcsctp_fuzzers
} // namespace dcsctp