| /* |
| * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. |
| * |
| * Use of this source code is governed by a BSD-style license |
| * that can be found in the LICENSE file in the root of the source |
| * tree. An additional intellectual property rights grant can be found |
| * in the file PATENTS. All contributing project authors may |
| * be found in the AUTHORS file in the root of the source tree. |
| */ |
| #include "net/dcsctp/rx/traditional_reassembly_streams.h" |
| |
| #include <stddef.h> |
| |
| #include <cstdint> |
| #include <functional> |
| #include <iterator> |
| #include <map> |
| #include <numeric> |
| #include <optional> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "api/array_view.h" |
| #include "net/dcsctp/common/sequence_numbers.h" |
| #include "net/dcsctp/packet/chunk/forward_tsn_common.h" |
| #include "net/dcsctp/packet/data.h" |
| #include "net/dcsctp/public/dcsctp_message.h" |
| #include "rtc_base/logging.h" |
| |
| namespace dcsctp { |
| namespace { |
| |
| // Given a map (`chunks`) and an iterator to within that map (`iter`), this |
| // function will return an iterator to the first chunk in that message, which |
| // has the `is_beginning` flag set. If there are any gaps, or if the beginning |
| // can't be found, `std::nullopt` is returned. |
| std::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning( |
| const std::map<UnwrappedTSN, Data>& chunks, |
| std::map<UnwrappedTSN, Data>::iterator iter) { |
| UnwrappedTSN prev_tsn = iter->first; |
| for (;;) { |
| if (iter->second.is_beginning) { |
| return iter; |
| } |
| if (iter == chunks.begin()) { |
| return std::nullopt; |
| } |
| --iter; |
| if (iter->first.next_value() != prev_tsn) { |
| return std::nullopt; |
| } |
| prev_tsn = iter->first; |
| } |
| } |
| |
| // Given a map (`chunks`) and an iterator to within that map (`iter`), this |
| // function will return an iterator to the chunk after the last chunk in that |
| // message, which has the `is_end` flag set. If there are any gaps, or if the |
| // end can't be found, `std::nullopt` is returned. |
| std::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd( |
| std::map<UnwrappedTSN, Data>& chunks, |
| std::map<UnwrappedTSN, Data>::iterator iter) { |
| UnwrappedTSN prev_tsn = iter->first; |
| for (;;) { |
| if (iter->second.is_end) { |
| return ++iter; |
| } |
| ++iter; |
| if (iter == chunks.end()) { |
| return std::nullopt; |
| } |
| if (iter->first != prev_tsn.next_value()) { |
| return std::nullopt; |
| } |
| prev_tsn = iter->first; |
| } |
| } |
| } // namespace |
| |
| TraditionalReassemblyStreams::TraditionalReassemblyStreams( |
| absl::string_view log_prefix, |
| OnAssembledMessage on_assembled_message) |
| : log_prefix_(log_prefix), |
| on_assembled_message_(std::move(on_assembled_message)) {} |
| |
| int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn, |
| Data data) { |
| if (data.is_beginning && data.is_end) { |
| // Fastpath for already assembled chunks. |
| AssembleMessage(tsn, std::move(data)); |
| return 0; |
| } |
| int queued_bytes = data.size(); |
| auto [it, inserted] = chunks_.emplace(tsn, std::move(data)); |
| if (!inserted) { |
| return 0; |
| } |
| |
| queued_bytes -= TryToAssembleMessage(it); |
| |
| return queued_bytes; |
| } |
| |
| size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage( |
| ChunkMap::iterator iter) { |
| // TODO(boivie): This method is O(N) with the number of fragments in a |
| // message, which can be inefficient for very large values of N. This could be |
| // optimized by e.g. only trying to assemble a message once _any_ beginning |
| // and _any_ end has been found. |
| std::optional<ChunkMap::iterator> start = FindBeginning(chunks_, iter); |
| if (!start.has_value()) { |
| return 0; |
| } |
| std::optional<ChunkMap::iterator> end = FindEnd(chunks_, iter); |
| if (!end.has_value()) { |
| return 0; |
| } |
| |
| size_t bytes_assembled = AssembleMessage(*start, *end); |
| chunks_.erase(*start, *end); |
| return bytes_assembled; |
| } |
| |
| size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage( |
| const ChunkMap::iterator start, |
| const ChunkMap::iterator end) { |
| size_t count = std::distance(start, end); |
| |
| if (count == 1) { |
| // Fast path - zero-copy |
| return AssembleMessage(start->first, std::move(start->second)); |
| } |
| |
| // Slow path - will need to concatenate the payload. |
| std::vector<UnwrappedTSN> tsns; |
| std::vector<uint8_t> payload; |
| |
| size_t payload_size = std::accumulate( |
| start, end, 0, |
| [](size_t v, const auto& p) { return v + p.second.size(); }); |
| |
| tsns.reserve(count); |
| payload.reserve(payload_size); |
| for (auto it = start; it != end; ++it) { |
| const Data& data = it->second; |
| tsns.push_back(it->first); |
| payload.insert(payload.end(), data.payload.begin(), data.payload.end()); |
| } |
| |
| DcSctpMessage message(start->second.stream_id, start->second.ppid, |
| std::move(payload)); |
| parent_.on_assembled_message_(tsns, std::move(message)); |
| |
| return payload_size; |
| } |
| |
| size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage( |
| UnwrappedTSN tsn, |
| Data data) { |
| // Fast path - zero-copy |
| size_t payload_size = data.size(); |
| UnwrappedTSN tsns[1] = {tsn}; |
| DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); |
| parent_.on_assembled_message_(tsns, std::move(message)); |
| return payload_size; |
| } |
| |
| size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo( |
| UnwrappedTSN tsn) { |
| auto end_iter = chunks_.upper_bound(tsn); |
| size_t removed_bytes = std::accumulate( |
| chunks_.begin(), end_iter, 0, |
| [](size_t r, const auto& p) { return r + p.second.size(); }); |
| |
| chunks_.erase(chunks_.begin(), end_iter); |
| return removed_bytes; |
| } |
| |
| size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() { |
| if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) { |
| return 0; |
| } |
| |
| ChunkMap& chunks = chunks_by_ssn_.begin()->second; |
| |
| if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) { |
| return 0; |
| } |
| |
| uint32_t tsn_diff = |
| UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first); |
| if (tsn_diff != chunks.size() - 1) { |
| return 0; |
| } |
| |
| size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end()); |
| chunks_by_ssn_.erase(chunks_by_ssn_.begin()); |
| next_ssn_.Increment(); |
| return assembled_bytes; |
| } |
| |
| size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() { |
| size_t assembled_bytes = 0; |
| |
| for (;;) { |
| size_t assembled_bytes_this_iter = TryToAssembleMessage(); |
| if (assembled_bytes_this_iter == 0) { |
| break; |
| } |
| assembled_bytes += assembled_bytes_this_iter; |
| } |
| return assembled_bytes; |
| } |
| |
| size_t |
| TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessagesFastpath( |
| UnwrappedSSN ssn, |
| UnwrappedTSN tsn, |
| Data data) { |
| RTC_DCHECK(ssn == next_ssn_); |
| size_t assembled_bytes = 0; |
| if (data.is_beginning && data.is_end) { |
| assembled_bytes += AssembleMessage(tsn, std::move(data)); |
| next_ssn_.Increment(); |
| } else { |
| size_t queued_bytes = data.size(); |
| auto [iter, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data)); |
| if (!inserted) { |
| // Not actually assembled, but deduplicated meaning queued size doesn't |
| // include this message. |
| return queued_bytes; |
| } |
| } |
| return assembled_bytes + TryToAssembleMessages(); |
| } |
| |
| int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn, |
| Data data) { |
| int queued_bytes = data.size(); |
| UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn); |
| if (ssn == next_ssn_) { |
| return queued_bytes - |
| TryToAssembleMessagesFastpath(ssn, tsn, std::move(data)); |
| } |
| auto [iter, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data)); |
| if (!inserted) { |
| return 0; |
| } |
| return queued_bytes; |
| } |
| |
| size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) { |
| UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn); |
| |
| auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn); |
| size_t removed_bytes = std::accumulate( |
| chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) { |
| return r1 + |
| absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) { |
| return r2 + q.second.size(); |
| }); |
| }); |
| chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter); |
| |
| if (unwrapped_ssn >= next_ssn_) { |
| unwrapped_ssn.Increment(); |
| next_ssn_ = unwrapped_ssn; |
| } |
| |
| removed_bytes += TryToAssembleMessages(); |
| return removed_bytes; |
| } |
| |
| int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { |
| if (data.is_unordered) { |
| auto it = unordered_streams_.try_emplace(data.stream_id, this).first; |
| return it->second.Add(tsn, std::move(data)); |
| } |
| |
| auto it = ordered_streams_.try_emplace(data.stream_id, this).first; |
| return it->second.Add(tsn, std::move(data)); |
| } |
| |
| size_t TraditionalReassemblyStreams::HandleForwardTsn( |
| UnwrappedTSN new_cumulative_ack_tsn, |
| rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) { |
| size_t bytes_removed = 0; |
| // The `skipped_streams` only cover ordered messages - need to |
| // iterate all unordered streams manually to remove those chunks. |
| for (auto& [unused, stream] : unordered_streams_) { |
| bytes_removed += stream.EraseTo(new_cumulative_ack_tsn); |
| } |
| |
| for (const auto& skipped_stream : skipped_streams) { |
| auto it = |
| ordered_streams_.try_emplace(skipped_stream.stream_id, this).first; |
| bytes_removed += it->second.EraseTo(skipped_stream.ssn); |
| } |
| |
| return bytes_removed; |
| } |
| |
| void TraditionalReassemblyStreams::ResetStreams( |
| rtc::ArrayView<const StreamID> stream_ids) { |
| if (stream_ids.empty()) { |
| for (auto& [stream_id, stream] : ordered_streams_) { |
| RTC_DLOG(LS_VERBOSE) << log_prefix_ |
| << "Resetting implicit stream_id=" << *stream_id; |
| stream.Reset(); |
| } |
| } else { |
| for (StreamID stream_id : stream_ids) { |
| auto it = ordered_streams_.find(stream_id); |
| if (it != ordered_streams_.end()) { |
| RTC_DLOG(LS_VERBOSE) |
| << log_prefix_ << "Resetting explicit stream_id=" << *stream_id; |
| it->second.Reset(); |
| } |
| } |
| } |
| } |
| |
| HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness() |
| const { |
| HandoverReadinessStatus status; |
| for (const auto& [unused, stream] : ordered_streams_) { |
| if (stream.has_unassembled_chunks()) { |
| status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks); |
| break; |
| } |
| } |
| for (const auto& [unused, stream] : unordered_streams_) { |
| if (stream.has_unassembled_chunks()) { |
| status.Add( |
| HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks); |
| break; |
| } |
| } |
| return status; |
| } |
| |
| void TraditionalReassemblyStreams::AddHandoverState( |
| DcSctpSocketHandoverState& state) { |
| for (const auto& [stream_id, stream] : ordered_streams_) { |
| DcSctpSocketHandoverState::OrderedStream state_stream; |
| state_stream.id = stream_id.value(); |
| state_stream.next_ssn = stream.next_ssn().value(); |
| state.rx.ordered_streams.push_back(std::move(state_stream)); |
| } |
| for (const auto& [stream_id, unused] : unordered_streams_) { |
| DcSctpSocketHandoverState::UnorderedStream state_stream; |
| state_stream.id = stream_id.value(); |
| state.rx.unordered_streams.push_back(std::move(state_stream)); |
| } |
| } |
| |
| void TraditionalReassemblyStreams::RestoreFromState( |
| const DcSctpSocketHandoverState& state) { |
| // Validate that the component is in pristine state. |
| RTC_DCHECK(ordered_streams_.empty()); |
| RTC_DCHECK(unordered_streams_.empty()); |
| |
| for (const DcSctpSocketHandoverState::OrderedStream& state_stream : |
| state.rx.ordered_streams) { |
| ordered_streams_.emplace( |
| std::piecewise_construct, |
| std::forward_as_tuple(StreamID(state_stream.id)), |
| std::forward_as_tuple(this, SSN(state_stream.next_ssn))); |
| } |
| for (const DcSctpSocketHandoverState::UnorderedStream& state_stream : |
| state.rx.unordered_streams) { |
| unordered_streams_.emplace(std::piecewise_construct, |
| std::forward_as_tuple(StreamID(state_stream.id)), |
| std::forward_as_tuple(this)); |
| } |
| } |
| |
| } // namespace dcsctp |