| /* |
| * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. |
| * |
| * Use of this source code is governed by a BSD-style license |
| * that can be found in the LICENSE file in the root of the source |
| * tree. An additional intellectual property rights grant can be found |
| * in the file PATENTS. All contributing project authors may |
| * be found in the AUTHORS file in the root of the source tree. |
| */ |
| #ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ |
| #define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ |
| |
| #include <deque> |
| #include <memory> |
| #include <set> |
| #include <vector> |
| |
| #include "api/array_view.h" |
| #include "api/task_queue/task_queue_base.h" |
| #include "net/dcsctp/public/dcsctp_socket.h" |
| |
| namespace dcsctp { |
| namespace dcsctp_fuzzers { |
| |
| // A fake timeout used during fuzzing. |
| class FuzzerTimeout : public Timeout { |
| public: |
| explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts) |
| : active_timeouts_(active_timeouts) {} |
| |
| void Start(DurationMs /* duration_ms */, TimeoutID timeout_id) override { |
| // Start is only allowed to be called on stopped or expired timeouts. |
| if (timeout_id_.has_value()) { |
| // It has been started before, but maybe it expired. Ensure that it's not |
| // running at least. |
| RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end()); |
| } |
| timeout_id_ = timeout_id; |
| RTC_DCHECK(active_timeouts_.insert(timeout_id).second); |
| } |
| |
| void Stop() override { |
| // Stop is only allowed to be called on active timeouts. Not stopped or |
| // expired. |
| RTC_DCHECK(timeout_id_.has_value()); |
| RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1); |
| timeout_id_ = std::nullopt; |
| } |
| |
| // A set of all active timeouts, managed by `FuzzerCallbacks`. |
| std::set<TimeoutID>& active_timeouts_; |
| // If present, the timout is active and will expire reported as `timeout_id`. |
| std::optional<TimeoutID> timeout_id_; |
| }; |
| |
| class FuzzerCallbacks : public DcSctpSocketCallbacks { |
| public: |
| static constexpr int kRandomValue = 42; |
| void SendPacket(rtc::ArrayView<const uint8_t> data) override { |
| sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end())); |
| } |
| std::unique_ptr<Timeout> CreateTimeout( |
| webrtc::TaskQueueBase::DelayPrecision /* precision */) override { |
| // The fuzzer timeouts don't implement |precision|. |
| return std::make_unique<FuzzerTimeout>(active_timeouts_); |
| } |
| webrtc::Timestamp Now() override { return webrtc::Timestamp::Millis(42); } |
| uint32_t GetRandomInt(uint32_t /* low */, uint32_t /* high */) override { |
| return kRandomValue; |
| } |
| void OnMessageReceived(DcSctpMessage /* message */) override {} |
| void OnError(ErrorKind /* error */, |
| absl::string_view /* message */) override {} |
| void OnAborted(ErrorKind /* error */, |
| absl::string_view /* message */) override {} |
| void OnConnected() override {} |
| void OnClosed() override {} |
| void OnConnectionRestarted() override {} |
| void OnStreamsResetFailed( |
| rtc::ArrayView<const StreamID> /* outgoing_streams */, |
| absl::string_view /* reason */) override {} |
| void OnStreamsResetPerformed( |
| rtc::ArrayView<const StreamID> outgoing_streams) override {} |
| void OnIncomingStreamsReset( |
| rtc::ArrayView<const StreamID> incoming_streams) override {} |
| |
| std::vector<uint8_t> ConsumeSentPacket() { |
| if (sent_packets_.empty()) { |
| return {}; |
| } |
| std::vector<uint8_t> ret = sent_packets_.front(); |
| sent_packets_.pop_front(); |
| return ret; |
| } |
| |
| // Given an index among the active timeouts, will expire that one. |
| std::optional<TimeoutID> ExpireTimeout(size_t index) { |
| if (index < active_timeouts_.size()) { |
| auto it = active_timeouts_.begin(); |
| std::advance(it, index); |
| TimeoutID timeout_id = *it; |
| active_timeouts_.erase(it); |
| return timeout_id; |
| } |
| return std::nullopt; |
| } |
| |
| private: |
| // Needs to be ordered, to allow fuzzers to expire timers. |
| std::set<TimeoutID> active_timeouts_; |
| std::deque<std::vector<uint8_t>> sent_packets_; |
| }; |
| |
| // Given some fuzzing `data` will send packets to the socket as well as calling |
| // API methods. |
| void FuzzSocket(DcSctpSocketInterface& socket, |
| FuzzerCallbacks& cb, |
| rtc::ArrayView<const uint8_t> data); |
| |
| } // namespace dcsctp_fuzzers |
| } // namespace dcsctp |
| #endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ |