/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#include "net/dcsctp/rx/traditional_reassembly_streams.h"

#include <stddef.h>

#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "api/array_view.h"
#include "net/dcsctp/common/sequence_numbers.h"
#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
#include "net/dcsctp/packet/data.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "rtc_base/logging.h"

namespace dcsctp {
namespace {

// Given a map (`chunks`) and an iterator to within that map (`iter`), this
// function will return an iterator to the first chunk in that message, which
// has the `is_beginning` flag set. If there are any gaps, or if the beginning
// can't be found, `std::nullopt` is returned.
std::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning(
    const std::map<UnwrappedTSN, Data>& chunks,
    std::map<UnwrappedTSN, Data>::iterator iter) {
  UnwrappedTSN prev_tsn = iter->first;
  for (;;) {
    if (iter->second.is_beginning) {
      return iter;
    }
    if (iter == chunks.begin()) {
      return std::nullopt;
    }
    --iter;
    if (iter->first.next_value() != prev_tsn) {
      return std::nullopt;
    }
    prev_tsn = iter->first;
  }
}

// Given a map (`chunks`) and an iterator to within that map (`iter`), this
// function will return an iterator to the chunk after the last chunk in that
// message, which has the `is_end` flag set. If there are any gaps, or if the
// end can't be found, `std::nullopt` is returned.
std::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd(
    std::map<UnwrappedTSN, Data>& chunks,
    std::map<UnwrappedTSN, Data>::iterator iter) {
  UnwrappedTSN prev_tsn = iter->first;
  for (;;) {
    if (iter->second.is_end) {
      return ++iter;
    }
    ++iter;
    if (iter == chunks.end()) {
      return std::nullopt;
    }
    if (iter->first != prev_tsn.next_value()) {
      return std::nullopt;
    }
    prev_tsn = iter->first;
  }
}
}  // namespace

// Stores the prefix used for this component's log lines and the callback that
// is invoked with the TSNs and payload of every fully reassembled message.
TraditionalReassemblyStreams::TraditionalReassemblyStreams(
    absl::string_view log_prefix,
    OnAssembledMessage on_assembled_message)
    : log_prefix_(log_prefix),
      on_assembled_message_(std::move(on_assembled_message)) {}

// Adds one unordered chunk. Returns the change in the number of queued payload
// bytes: the chunk's own size, minus whatever was delivered because this chunk
// completed a message. Unfragmented chunks and duplicate TSNs queue nothing.
int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn,
                                                       Data data) {
  const bool is_complete_message = data.is_beginning && data.is_end;
  if (is_complete_message) {
    // The whole message is in this chunk; deliver it without queueing.
    AssembleMessage(tsn, std::move(data));
    return 0;
  }

  int queued_bytes = data.size();
  auto emplaced = chunks_.emplace(tsn, std::move(data));
  if (!emplaced.second) {
    // A chunk with this TSN is already stored; keep the existing one.
    return 0;
  }

  queued_bytes -= TryToAssembleMessage(emplaced.first);
  return queued_bytes;
}

// Tries to complete the message containing the chunk at `iter`. When both the
// beginning and the end are present with no gaps, the message is delivered,
// its chunks are erased, and its payload size is returned; otherwise 0.
size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage(
    ChunkMap::iterator iter) {
  // TODO(boivie): This method is O(N) with the number of fragments in a
  // message, which can be inefficient for very large values of N. This could be
  // optimized by e.g. only trying to assemble a message once _any_ beginning
  // and _any_ end has been found.
  const std::optional<ChunkMap::iterator> first = FindBeginning(chunks_, iter);
  if (!first) {
    return 0;
  }
  const std::optional<ChunkMap::iterator> past_last = FindEnd(chunks_, iter);
  if (!past_last) {
    return 0;
  }

  const size_t delivered_bytes = AssembleMessage(*first, *past_last);
  chunks_.erase(*first, *past_last);
  return delivered_bytes;
}

// Delivers the message formed by the chunks in [`start`, `end`) to the
// parent's `on_assembled_message_` callback and returns its payload size in
// bytes. A single chunk has its payload moved out; several chunks have their
// payloads copied into one buffer. The chunks stay in the map; the callers in
// this file erase them afterwards.
size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
    const ChunkMap::iterator start,
    const ChunkMap::iterator end) {
  size_t count = std::distance(start, end);

  if (count == 1) {
    // Fast path - zero-copy
    return AssembleMessage(start->first, std::move(start->second));
  }

  // Slow path - will need to concatenate the payload.
  // Seed with `size_t{0}`, not `0`: std::accumulate uses the type of the
  // initial value as its accumulator. A plain `0` would narrow every partial
  // sum back to `int`, which breaks for payloads larger than INT_MAX.
  size_t payload_size = std::accumulate(
      start, end, size_t{0},
      [](size_t v, const auto& p) { return v + p.second.size(); });

  std::vector<UnwrappedTSN> tsns;
  std::vector<uint8_t> payload;
  tsns.reserve(count);
  payload.reserve(payload_size);
  for (auto it = start; it != end; ++it) {
    const Data& data = it->second;
    tsns.push_back(it->first);
    payload.insert(payload.end(), data.payload.begin(), data.payload.end());
  }

  DcSctpMessage message(start->second.stream_id, start->second.ppid,
                        std::move(payload));
  parent_.on_assembled_message_(tsns, std::move(message));

  return payload_size;
}

// Delivers a message that fits in a single chunk. The payload is moved into
// the message rather than copied. Returns the payload size in bytes.
size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
    UnwrappedTSN tsn,
    Data data) {
  // Capture the size before the payload is moved out of `data`.
  const size_t delivered_bytes = data.size();
  UnwrappedTSN single_tsn[] = {tsn};
  parent_.on_assembled_message_(
      single_tsn,
      DcSctpMessage(data.stream_id, data.ppid, std::move(data.payload)));
  return delivered_bytes;
}

// Drops every queued unordered chunk whose TSN is less than or equal to `tsn`
// (as done when a FORWARD-TSN advances the cumulative ack). Returns the number
// of payload bytes dropped.
size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo(
    UnwrappedTSN tsn) {
  auto end_iter = chunks_.upper_bound(tsn);
  // `size_t{0}` keeps std::accumulate's accumulator in size_t; a plain `0`
  // would sum in `int` and narrow each partial result.
  size_t removed_bytes = std::accumulate(
      chunks_.begin(), end_iter, size_t{0},
      [](size_t r, const auto& p) { return r + p.second.size(); });

  chunks_.erase(chunks_.begin(), end_iter);
  return removed_bytes;
}

size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() {
  if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) {
    return 0;
  }

  ChunkMap& chunks = chunks_by_ssn_.begin()->second;

  if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) {
    return 0;
  }

  uint32_t tsn_diff =
      UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first);
  if (tsn_diff != chunks.size() - 1) {
    return 0;
  }

  size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end());
  chunks_by_ssn_.erase(chunks_by_ssn_.begin());
  next_ssn_.Increment();
  return assembled_bytes;
}

// Delivers in-order messages for as long as the next expected one is
// complete. Returns the total payload bytes delivered across all of them.
size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() {
  size_t total_bytes = 0;
  for (size_t delivered = TryToAssembleMessage(); delivered != 0;
       delivered = TryToAssembleMessage()) {
    total_bytes += delivered;
  }
  return total_bytes;
}

// Handles a chunk whose SSN is the next expected one (`ssn == next_ssn_`).
// Returns the number of bytes to subtract from the chunk's size when computing
// the change in queued bytes: bytes delivered, or the chunk's own size if it
// was a duplicate and therefore not queued.
size_t
TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessagesFastpath(
    UnwrappedSSN ssn,
    UnwrappedTSN tsn,
    Data data) {
  RTC_DCHECK(ssn == next_ssn_);
  const bool is_complete_message = data.is_beginning && data.is_end;
  if (!is_complete_message) {
    // A fragment of the expected message: store it, then check whether the
    // message (and any that were waiting behind it) can now be delivered.
    const size_t fragment_bytes = data.size();
    const bool inserted =
        chunks_by_ssn_[ssn].emplace(tsn, std::move(data)).second;
    if (!inserted) {
      // Deduplicated, not assembled: report the fragment's size so that the
      // queued size computed by the caller doesn't include it.
      return fragment_bytes;
    }
    return TryToAssembleMessages();
  }

  // The expected message arrived in one piece: deliver it directly, advance
  // the expected SSN, and drain any messages that were queued behind it.
  const size_t delivered_bytes = AssembleMessage(tsn, std::move(data));
  next_ssn_.Increment();
  return delivered_bytes + TryToAssembleMessages();
}

// Adds one ordered chunk. Returns the change in the number of queued payload
// bytes: the chunk's size, minus anything delivered as a result of adding it.
// Duplicate TSNs queue nothing.
int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn,
                                                     Data data) {
  int queued_bytes = data.size();
  const UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn);
  if (ssn != next_ssn_) {
    // A later message than the one we're waiting for: just store it.
    const bool inserted =
        chunks_by_ssn_[ssn].emplace(tsn, std::move(data)).second;
    return inserted ? queued_bytes : 0;
  }
  return queued_bytes -
         TryToAssembleMessagesFastpath(ssn, tsn, std::move(data));
}

// Drops every queued chunk whose SSN is less than or equal to `ssn`. If that
// skips past the next expected SSN, the expected SSN becomes `ssn + 1`. Any
// messages that can then be delivered are delivered. Returns the dropped
// payload bytes plus the delivered payload bytes, i.e. everything that is no
// longer queued.
size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) {
  UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn);

  auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn);
  // Both folds are seeded with `size_t{0}` so that the accumulators are
  // size_t; a plain `0` would sum in `int` and narrow each partial result.
  size_t removed_bytes = std::accumulate(
      chunks_by_ssn_.begin(), end_iter, size_t{0},
      [](size_t r1, const auto& p) {
        return r1 + absl::c_accumulate(p.second, size_t{0},
                                       [](size_t r2, const auto& q) {
                                         return r2 + q.second.size();
                                       });
      });
  chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter);

  if (unwrapped_ssn >= next_ssn_) {
    unwrapped_ssn.Increment();
    next_ssn_ = unwrapped_ssn;
  }

  removed_bytes += TryToAssembleMessages();
  return removed_bytes;
}

// Routes a chunk to the ordered or unordered stream with its stream ID,
// creating that stream on first use. Returns the stream's change in queued
// payload bytes.
int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) {
  if (!data.is_unordered) {
    auto& stream =
        ordered_streams_.try_emplace(data.stream_id, this).first->second;
    return stream.Add(tsn, std::move(data));
  }

  auto& stream =
      unordered_streams_.try_emplace(data.stream_id, this).first->second;
  return stream.Add(tsn, std::move(data));
}

// Applies a FORWARD-TSN. Unordered chunks up to `new_cumulative_ack_tsn` are
// dropped from every unordered stream. Each listed ordered stream is advanced
// past its skipped SSN; a stream is created first if it doesn't exist yet.
// Returns the total number of payload bytes that are no longer queued.
size_t TraditionalReassemblyStreams::HandleForwardTsn(
    UnwrappedTSN new_cumulative_ack_tsn,
    rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) {
  size_t total_removed = 0;

  // `skipped_streams` only describes ordered streams, so all unordered
  // streams are visited explicitly.
  for (auto& entry : unordered_streams_) {
    total_removed += entry.second.EraseTo(new_cumulative_ack_tsn);
  }

  for (const auto& skipped : skipped_streams) {
    auto& stream =
        ordered_streams_.try_emplace(skipped.stream_id, this).first->second;
    total_removed += stream.EraseTo(skipped.ssn);
  }

  return total_removed;
}

// Resets ordered streams. An empty `stream_ids` resets every known ordered
// stream; otherwise only the listed ones that already exist are reset.
// Unordered streams are not touched.
void TraditionalReassemblyStreams::ResetStreams(
    rtc::ArrayView<const StreamID> stream_ids) {
  if (!stream_ids.empty()) {
    for (StreamID stream_id : stream_ids) {
      auto it = ordered_streams_.find(stream_id);
      if (it == ordered_streams_.end()) {
        continue;
      }
      RTC_DLOG(LS_VERBOSE)
          << log_prefix_ << "Resetting explicit stream_id=" << *stream_id;
      it->second.Reset();
    }
    return;
  }

  for (auto& [stream_id, stream] : ordered_streams_) {
    RTC_DLOG(LS_VERBOSE) << log_prefix_
                         << "Resetting implicit stream_id=" << *stream_id;
    stream.Reset();
  }
}

// Reports whether a handover is possible. Any ordered or unordered stream
// that still holds unassembled chunks adds the matching unreadiness reason
// (each reason at most once).
HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness()
    const {
  auto any_unassembled = [](const auto& streams) {
    for (const auto& entry : streams) {
      if (entry.second.has_unassembled_chunks()) {
        return true;
      }
    }
    return false;
  };

  HandoverReadinessStatus status;
  if (any_unassembled(ordered_streams_)) {
    status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks);
  }
  if (any_unassembled(unordered_streams_)) {
    status.Add(
        HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks);
  }
  return status;
}

// Appends this component's state to `state.rx`: one entry per ordered stream
// (its ID and next expected SSN) and one entry per unordered stream (its ID).
void TraditionalReassemblyStreams::AddHandoverState(
    DcSctpSocketHandoverState& state) {
  for (const auto& entry : ordered_streams_) {
    DcSctpSocketHandoverState::OrderedStream saved;
    saved.id = entry.first.value();
    saved.next_ssn = entry.second.next_ssn().value();
    state.rx.ordered_streams.push_back(std::move(saved));
  }
  for (const auto& entry : unordered_streams_) {
    DcSctpSocketHandoverState::UnorderedStream saved;
    saved.id = entry.first.value();
    state.rx.unordered_streams.push_back(std::move(saved));
  }
}

void TraditionalReassemblyStreams::RestoreFromState(
    const DcSctpSocketHandoverState& state) {
  // Validate that the component is in pristine state.
  RTC_DCHECK(ordered_streams_.empty());
  RTC_DCHECK(unordered_streams_.empty());

  for (const DcSctpSocketHandoverState::OrderedStream& state_stream :
       state.rx.ordered_streams) {
    ordered_streams_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(StreamID(state_stream.id)),
        std::forward_as_tuple(this, SSN(state_stream.next_ssn)));
  }
  for (const DcSctpSocketHandoverState::UnorderedStream& state_stream :
       state.rx.unordered_streams) {
    unordered_streams_.emplace(std::piecewise_construct,
                               std::forward_as_tuple(StreamID(state_stream.id)),
                               std::forward_as_tuple(this));
  }
}

}  // namespace dcsctp
