/*
 *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "video/encoder_bitrate_adjuster.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "api/field_trials_view.h"
#include "api/units/data_rate.h"
#include "api/units/data_size.h"
#include "api/units/time_delta.h"
#include "api/units/timestamp.h"
#include "api/video/video_bitrate_allocation.h"
#include "api/video/video_codec_constants.h"
#include "api/video/video_codec_type.h"
#include "api/video_codecs/video_codec.h"
#include "api/video_codecs/video_encoder.h"
#include "modules/video_coding/svc/scalability_mode_util.h"
#include "rtc_base/checks.h"
#include "rtc_base/experiments/rate_control_settings.h"
#include "rtc_base/logging.h"
#include "rtc_base/time_utils.h"
#include "system_wrappers/include/clock.h"
#include "video/encoder_overshoot_detector.h"
#include "video/rate_utilization_tracker.h"

namespace webrtc {
namespace {
// Per simulcast/spatial layer bookkeeping gathered while computing an adjusted
// bitrate allocation.
struct LayerRateInfo {
  // Network (link) rate utilization relative to `target_rate`, as reported by
  // the overshoot detector(s) for this layer.
  double link_utilization_factor = 0.0;
  // Media rate utilization relative to `target_rate`.
  double media_utilization_factor = 0.0;
  // Unadjusted target bitrate for this layer (sum over its temporal layers).
  DataRate target_rate = DataRate::Zero();

  // Returns the extra bitrate, on top of `target_rate`, that this layer would
  // like to be granted from spare link headroom. That is the amount by which
  // link utilization exceeds media utilization. Media utilization is floored
  // at 1.0, so that even with headroom available we never ask for long-run
  // media overshoot.
  DataRate WantedOvershoot() const {
    const double excess =
        link_utilization_factor - std::max(1.0, media_utilization_factor);
    if (!(excess > 0.0)) {
      return DataRate::Zero();
    }
    return excess * target_rate;
  }
};
}  // namespace
// Out-of-line definitions for static constexpr class constants (presumably
// declared `static constexpr` in the header; TODO confirm). Since C++17 such
// members are implicitly inline and these lines are redundant, but they are
// harmless and keep pre-C++17 ODR-use (e.g. binding by reference) linkable.
constexpr TimeDelta EncoderBitrateAdjuster::kWindowSize;
constexpr size_t EncoderBitrateAdjuster::kMinFramesSinceLayoutChange;
constexpr double EncoderBitrateAdjuster::kDefaultUtilizationFactor;

// Captures the codec configuration needed by the adjuster and precomputes,
// per active simulcast/spatial layer, the minimum bitrate (in bps) that the
// adjustment must never push a layer's target below.
EncoderBitrateAdjuster::EncoderBitrateAdjuster(
    const VideoCodec& codec_settings,
    const FieldTrialsView& field_trials,
    Clock& clock)
    : utilize_bandwidth_headroom_(RateControlSettings(field_trials)
                                      .BitrateAdjusterCanUseNetworkHeadroom()),
      use_newfangled_headroom_adjustment_(!field_trials.IsDisabled(
          "WebRTC-BitrateAdjusterUseNewfangledHeadroomAdjustment")),
      frames_since_layout_change_(0),
      min_bitrates_bps_{},
      codec_(codec_settings.codecType),
      codec_mode_(codec_settings.mode),
      clock_(clock) {
  // Stores the effective minimum of layer `index` in bps. This is the larger of
  // the codec-wide minimum and the layer's own minimum, both given in kbps.
  const auto set_min_bitrate = [&](size_t index, unsigned int layer_min_kbps) {
    min_bitrates_bps_[index] =
        std::max(codec_settings.minBitrate * 1000, layer_min_kbps * 1000);
  };

  // TODO(https://crbug.com/webrtc/14891): If we want to support simulcast of
  // SVC streams, EncoderBitrateAdjuster needs to be updated to care about both
  // `simulcastStream` and `spatialLayers` at the same time.
  //
  // Layer counts come from configuration and are clamped to the fixed array
  // sizes, so a misconfigured codec cannot cause out-of-bounds indexing.
  if (codec_settings.codecType == VideoCodecType::kVideoCodecAV1 &&
      codec_settings.numberOfSimulcastStreams <= 1 &&
      codec_settings.GetScalabilityMode().has_value()) {
    // Evaluate the scalability mode once rather than on every iteration.
    const size_t num_spatial_layers = std::min<size_t>(
        ScalabilityModeToNumSpatialLayers(*codec_settings.GetScalabilityMode()),
        kMaxSpatialLayers);
    for (size_t si = 0; si < num_spatial_layers; ++si) {
      if (codec_settings.spatialLayers[si].active) {
        set_min_bitrate(si, codec_settings.spatialLayers[si].minBitrate);
      }
    }
  } else if (codec_settings.codecType == VideoCodecType::kVideoCodecVP9 &&
             codec_settings.numberOfSimulcastStreams <= 1) {
    const size_t num_spatial_layers = std::min<size_t>(
        codec_settings.VP9().numberOfSpatialLayers, kMaxSpatialLayers);
    for (size_t si = 0; si < num_spatial_layers; ++si) {
      if (codec_settings.spatialLayers[si].active) {
        set_min_bitrate(si, codec_settings.spatialLayers[si].minBitrate);
      }
    }
  } else {
    const size_t num_streams = std::min<size_t>(
        codec_settings.numberOfSimulcastStreams, kMaxSimulcastStreams);
    for (size_t si = 0; si < num_streams; ++si) {
      if (codec_settings.simulcastStream[si].active) {
        set_min_bitrate(si, codec_settings.simulcastStream[si].minBitrate);
      }
    }
  }
}

// Defaulted out of line; the owned overshoot detectors and media rate trackers
// are released by their smart pointers. Presumably defined here so the header
// can get by with forward declarations of those types (TODO confirm).
EncoderBitrateAdjuster::~EncoderBitrateAdjuster() = default;

// Returns the bitrate allocation to hand to the encoder. It is derived from
// `rates.bitrate`, with each simulcast/spatial layer scaled down according to
// how much the encoder has been overshooting its targets. Spare link headroom,
// if enabled, may be used to let layers overshoot.
//
// Side effects:
// - stores `rates` for later re-use by OnEncoderInfo()/Reset();
// - creates or destroys overshoot detectors (and media rate trackers) so they
//   match the currently active layers;
// - updates every remaining detector with its new adjusted target.
VideoBitrateAllocation EncoderBitrateAdjuster::AdjustRateAllocation(
    const VideoEncoder::RateControlParameters& rates) {
  current_rate_control_parameters_ = rates;
  const Timestamp now = clock_.CurrentTime();

  // First check that overshoot detectors exist, and store per simulcast/spatial
  // layer how many active temporal layers we have.
  size_t active_tls[kMaxSpatialLayers] = {};
  for (size_t si = 0; si < kMaxSpatialLayers; ++si) {
    for (size_t ti = 0; ti < kMaxTemporalStreams; ++ti) {
      // Layer is enabled iff it has both positive bitrate and framerate target.
      if (rates.bitrate.GetBitrate(si, ti) > 0 &&
          current_fps_allocation_[si].size() > ti &&
          current_fps_allocation_[si][ti] > 0) {
        ++active_tls[si];
        if (!overshoot_detectors_[si][ti]) {
          overshoot_detectors_[si][ti] =
              std::make_unique<EncoderOvershootDetector>(
                  kWindowSize.ms(), codec_,
                  codec_mode_ == VideoCodecMode::kScreensharing);
          frames_since_layout_change_ = 0;
        }
      } else if (overshoot_detectors_[si][ti]) {
        // Layer removed, destroy overshoot detector.
        overshoot_detectors_[si][ti].reset();
        frames_since_layout_change_ = 0;
      }
    }
    if (use_newfangled_headroom_adjustment_) {
      // Keep one average media rate tracker per spatial layer that has a
      // non-zero target; drop it when the layer has no bitrate.
      const DataRate spatial_layer_rate =
          DataRate::BitsPerSec(rates.bitrate.GetSpatialLayerSum(si));
      if (spatial_layer_rate.IsZero()) {
        media_rate_trackers_[si].reset();
      } else {
        if (media_rate_trackers_[si] == nullptr) {
          constexpr int kMaxDataPointsInUtilizationTrackers = 100;
          media_rate_trackers_[si] = std::make_unique<RateUtilizationTracker>(
              kMaxDataPointsInUtilizationTrackers, kWindowSize);
        }
        // Media rate trackers use the unadjusted target rate.
        media_rate_trackers_[si]->OnDataRateChanged(spatial_layer_rate, now);
      }
    }
  }

  // Next poll the overshoot detectors and populate the adjusted allocation.
  VideoBitrateAllocation adjusted_allocation;
  std::vector<LayerRateInfo> layer_infos;
  layer_infos.reserve(kMaxSpatialLayers);
  DataRate wanted_overshoot_sum = DataRate::Zero();

  for (size_t si = 0; si < kMaxSpatialLayers; ++si) {
    layer_infos.emplace_back();
    LayerRateInfo& layer_info = layer_infos.back();

    layer_info.target_rate =
        DataRate::BitsPerSec(rates.bitrate.GetSpatialLayerSum(si));

    // Adjustment is done per simulcast/spatial layer only (not per temporal
    // layer).
    if (frames_since_layout_change_ < kMinFramesSinceLayoutChange) {
      // Too few frames since the layer layout changed for the detectors to
      // hold meaningful data; use the default factor.
      layer_info.link_utilization_factor = kDefaultUtilizationFactor;
      layer_info.media_utilization_factor = kDefaultUtilizationFactor;
    } else if (active_tls[si] == 0 ||
               layer_info.target_rate == DataRate::Zero()) {
      // No signaled temporal layers, or no bitrate set. Could either be unused
      // simulcast/spatial layer or bitrate dynamic mode; pass bitrate through
      // without any change.
      layer_info.link_utilization_factor = 1.0;
      layer_info.media_utilization_factor = 1.0;
    } else if (active_tls[si] == 1) {
      // A single active temporal layer, this might mean single layer or that
      // encoder does not support temporal layers. Merge target bitrates for
      // this simulcast/spatial layer.
      RTC_DCHECK(overshoot_detectors_[si][0]);
      layer_info.link_utilization_factor =
          overshoot_detectors_[si][0]
              ->GetNetworkRateUtilizationFactor(now.ms())
              .value_or(kDefaultUtilizationFactor);
      layer_info.media_utilization_factor =
          use_newfangled_headroom_adjustment_
              ? media_rate_trackers_[si]
                    ->GetRateUtilizationFactor(now)
                    .value_or(kDefaultUtilizationFactor)
              : overshoot_detectors_[si][0]
                    ->GetMediaRateUtilizationFactor(now.ms())
                    .value_or(kDefaultUtilizationFactor);
    } else if (layer_info.target_rate > DataRate::Zero()) {
      // Multiple temporal layers enabled for this simulcast/spatial layer.
      // Update rate for each of them and make a weighted average of utilization
      // factors, with bitrate fraction used as weight.
      // If any layer is missing a utilization factor, fall back to default.
      layer_info.link_utilization_factor = 0.0;
      layer_info.media_utilization_factor = 0.0;
      for (size_t ti = 0; ti < active_tls[si]; ++ti) {
        RTC_DCHECK(overshoot_detectors_[si][ti]);
        const std::optional<double> ti_link_utilization_factor =
            overshoot_detectors_[si][ti]->GetNetworkRateUtilizationFactor(
                now.ms());

        const std::optional<double> ti_media_utilization_factor =
            overshoot_detectors_[si][ti]->GetMediaRateUtilizationFactor(
                now.ms());
        if (!ti_link_utilization_factor || !ti_media_utilization_factor) {
          layer_info.link_utilization_factor = kDefaultUtilizationFactor;
          layer_info.media_utilization_factor = kDefaultUtilizationFactor;
          break;
        }
        const double weight =
            static_cast<double>(rates.bitrate.GetBitrate(si, ti)) /
            layer_info.target_rate.bps();
        layer_info.link_utilization_factor +=
            weight * ti_link_utilization_factor.value();
        layer_info.media_utilization_factor +=
            weight * ti_media_utilization_factor.value();
      }

      if (use_newfangled_headroom_adjustment_) {
        layer_info.media_utilization_factor =
            media_rate_trackers_[si]->GetRateUtilizationFactor(now).value_or(
                kDefaultUtilizationFactor);
      }
    } else {
      RTC_DCHECK_NOTREACHED();
    }

    if (layer_info.link_utilization_factor < 1.0) {
      // Don't boost target bitrate if encoder is under-using.
      layer_info.link_utilization_factor = 1.0;
    } else {
      // Don't reduce encoder target below 50%, in which case the frame dropper
      // should kick in instead.
      layer_info.link_utilization_factor =
          std::min(layer_info.link_utilization_factor, 2.0);

      // Keep track of sum of desired overshoot bitrate.
      wanted_overshoot_sum += layer_info.WantedOvershoot();
    }
  }

  // Available link headroom that can be used to fill wanted overshoot: the
  // part of the bandwidth allocation not already taken by the target bitrate.
  // Clamped at zero. A bandwidth allocation below the target sum leaves no
  // headroom. The subtraction would otherwise yield a negative DataRate, which
  // the unit type is presumably not meant to hold (and which would give a
  // negative granted ratio below).
  const DataRate target_sum =
      DataRate::BitsPerSec(rates.bitrate.get_sum_bps());
  DataRate available_headroom = DataRate::Zero();
  if (utilize_bandwidth_headroom_ && rates.bandwidth_allocation > target_sum) {
    available_headroom = rates.bandwidth_allocation - target_sum;
  }

  // All wanted overshoots are satisfied in the same proportion based on
  // available headroom.
  const double granted_overshoot_ratio =
      wanted_overshoot_sum == DataRate::Zero()
          ? 0.0
          : std::min(1.0, available_headroom.bps<double>() /
                              wanted_overshoot_sum.bps());

  for (size_t si = 0; si < kMaxSpatialLayers; ++si) {
    LayerRateInfo& layer_info = layer_infos[si];
    double utilization_factor = layer_info.link_utilization_factor;
    DataRate allowed_overshoot =
        granted_overshoot_ratio * layer_info.WantedOvershoot();
    if (allowed_overshoot > DataRate::Zero()) {
      // Pretend the target bitrate is higher by the allowed overshoot.
      // Since utilization_factor = actual_bitrate / target_bitrate, it can be
      // done by multiplying by old_target_bitrate / new_target_bitrate.
      utilization_factor *= layer_info.target_rate.bps<double>() /
                            (allowed_overshoot.bps<double>() +
                             layer_info.target_rate.bps<double>());
    }

    if (min_bitrates_bps_[si] > 0 &&
        layer_info.target_rate > DataRate::Zero() &&
        DataRate::BitsPerSec(min_bitrates_bps_[si]) < layer_info.target_rate) {
      // Make sure rate adjuster doesn't push target bitrate below minimum.
      utilization_factor =
          std::min(utilization_factor, layer_info.target_rate.bps<double>() /
                                           min_bitrates_bps_[si]);
    }

    if (layer_info.target_rate > DataRate::Zero()) {
      RTC_LOG(LS_VERBOSE)
          << "Utilization factors for simulcast/spatial index " << si
          << ": link = " << layer_info.link_utilization_factor
          << ", media = " << layer_info.media_utilization_factor
          << ", wanted overshoot = " << layer_info.WantedOvershoot().bps()
          << " bps, available headroom = " << available_headroom.bps()
          << " bps, total utilization factor = " << utilization_factor;
    }

    // Populate the adjusted allocation with determined utilization factor.
    // Adjusted rates are rounded to nearest and never exceed the original.
    if (active_tls[si] == 1 &&
        layer_info.target_rate >
            DataRate::BitsPerSec(rates.bitrate.GetBitrate(si, 0))) {
      // Bitrate allocation indicates temporal layer usage, but encoder
      // does not seem to support it. Pipe all bitrate into a single
      // overshoot detector.
      uint32_t adjusted_layer_bitrate_bps =
          std::min(static_cast<uint32_t>(
                       layer_info.target_rate.bps() / utilization_factor + 0.5),
                   layer_info.target_rate.bps<uint32_t>());
      adjusted_allocation.SetBitrate(si, 0, adjusted_layer_bitrate_bps);
    } else {
      for (size_t ti = 0; ti < kMaxTemporalStreams; ++ti) {
        if (rates.bitrate.HasBitrate(si, ti)) {
          uint32_t adjusted_layer_bitrate_bps = std::min(
              static_cast<uint32_t>(
                  rates.bitrate.GetBitrate(si, ti) / utilization_factor + 0.5),
              rates.bitrate.GetBitrate(si, ti));
          adjusted_allocation.SetBitrate(si, ti, adjusted_layer_bitrate_bps);
        }
      }
    }

    // In case of rounding errors, add bitrate to TL0 until min bitrate
    // constraint has been met.
    const uint32_t adjusted_spatial_layer_sum =
        adjusted_allocation.GetSpatialLayerSum(si);
    if (layer_info.target_rate > DataRate::Zero() &&
        adjusted_spatial_layer_sum < min_bitrates_bps_[si]) {
      adjusted_allocation.SetBitrate(si, 0,
                                     adjusted_allocation.GetBitrate(si, 0) +
                                         min_bitrates_bps_[si] -
                                         adjusted_spatial_layer_sum);
    }

    // Update all detectors with the new adjusted bitrate targets.
    for (size_t ti = 0; ti < kMaxTemporalStreams; ++ti) {
      const uint32_t layer_bitrate_bps = adjusted_allocation.GetBitrate(si, ti);
      // Overshoot detector may not exist, eg for ScreenshareLayers case.
      if (layer_bitrate_bps > 0 && overshoot_detectors_[si][ti]) {
        // Number of frames in this layer alone is not cumulative, so
        // subtract fps from any low temporal layer.
        const double fps_fraction =
            static_cast<double>(
                current_fps_allocation_[si][ti] -
                (ti == 0 ? 0 : current_fps_allocation_[si][ti - 1])) /
            VideoEncoder::EncoderInfo::kMaxFramerateFraction;

        if (fps_fraction <= 0.0) {
          RTC_LOG(LS_WARNING)
              << "Encoder config has temporal layer with non-zero bitrate "
                 "allocation but zero framerate allocation.";
          continue;
        }

        overshoot_detectors_[si][ti]->SetTargetRate(
            DataRate::BitsPerSec(layer_bitrate_bps),
            fps_fraction * rates.framerate_fps, now.ms());
      }
    }
  }

  // Since no simulcast/spatial layers or streams are toggled by the adjustment
  // bw-limited flag stays the same.
  adjusted_allocation.set_bw_limited(rates.bitrate.is_bw_limited());

  return adjusted_allocation;
}

void EncoderBitrateAdjuster::OnEncoderInfo(
    const VideoEncoder::EncoderInfo& encoder_info) {
  // Adopt the encoder's per spatial layer framerate allocation. This decides
  // which temporal layers AdjustRateAllocation() treats as active.
  for (size_t spatial_index = 0; spatial_index < kMaxSpatialLayers;
       ++spatial_index) {
    current_fps_allocation_[spatial_index] =
        encoder_info.fps_allocation[spatial_index];
  }

  // Re-run the adjustment with the last known rates, so that the set of
  // overshoot detectors and their targets reflect the new fps allocation.
  AdjustRateAllocation(current_rate_control_parameters_);
}

// Records an encoded frame of `size` for layer (`stream_index`,
// `temporal_index`). It is forwarded to that layer's overshoot detector and to
// the spatial layer's media rate tracker, when those exist.
void EncoderBitrateAdjuster::OnEncodedFrame(DataSize size,
                                            int stream_index,
                                            int temporal_index) {
  ++frames_since_layout_change_;

  // The per-layer tables are indexed up to kMaxSpatialLayers x
  // kMaxTemporalStreams. Reject anything outside that range instead of reading
  // out of bounds.
  if (stream_index < 0 || stream_index >= kMaxSpatialLayers ||
      temporal_index < 0 || temporal_index >= kMaxTemporalStreams) {
    RTC_LOG(LS_WARNING) << "Encoded frame with out-of-range layer index: "
                           "stream_index = "
                        << stream_index
                        << ", temporal_index = " << temporal_index;
    return;
  }

  // Sample the injected clock once. This gives the detector and the tracker
  // the same timestamp, on the same time base AdjustRateAllocation() uses for
  // SetTargetRate() and the utilization queries.
  const Timestamp now = clock_.CurrentTime();

  // Detectors may not exist, for instance if ScreenshareLayers is used.
  auto& detector = overshoot_detectors_[stream_index][temporal_index];
  if (detector) {
    detector->OnEncodedFrame(size.bytes(), now.ms());
  }
  if (media_rate_trackers_[stream_index]) {
    media_rate_trackers_[stream_index]->OnDataProduced(size, now);
  }
}

void EncoderBitrateAdjuster::Reset() {
  // Throw away all accumulated per-layer history: every overshoot detector ...
  for (auto& per_layer_detectors : overshoot_detectors_) {
    for (auto& detector : per_layer_detectors) {
      detector.reset();
    }
  }
  // ... and every media rate tracker.
  for (auto& tracker : media_rate_trackers_) {
    tracker.reset();
  }

  // Re-run the adjustment with the last known rates, so that detectors and
  // trackers for the currently active layers are recreated right away.
  AdjustRateAllocation(current_rate_control_parameters_);
}

}  // namespace webrtc
