/*
 *  Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "webrtc/modules/audio_processing/voice_detection_impl.h"

#include <assert.h>

#include "webrtc/base/criticalsection.h"
#include "webrtc/base/thread_checker.h"
#include "webrtc/common_audio/vad/include/webrtc_vad.h"
#include "webrtc/modules/audio_processing/audio_buffer.h"

namespace webrtc {

// Concrete type behind the opaque void* handles used in this file; the
// handle callbacks below cast to Handle* before calling the WebRtcVad_* API.
typedef VadInst Handle;

namespace {
// Translates a VoiceDetection::Likelihood value into the integer mode that
// ConfigureHandle() forwards to WebRtcVad_set_mode().
//
// The mapping is inverted: kVeryLowLikelihood -> 3 ... kHighLikelihood -> 0.
// A value outside the enum trips an assert; when asserts are compiled out,
// -1 is returned instead.
int MapSetting(VoiceDetection::Likelihood likelihood) {
  int vad_mode = -1;  // Stays -1 when no enumerator matches.
  switch (likelihood) {
    case VoiceDetection::kHighLikelihood:
      vad_mode = 0;
      break;
    case VoiceDetection::kModerateLikelihood:
      vad_mode = 1;
      break;
    case VoiceDetection::kLowLikelihood:
      vad_mode = 2;
      break;
    case VoiceDetection::kVeryLowLikelihood:
      vad_mode = 3;
      break;
  }
  if (vad_mode == -1) {
    assert(false);
  }
  return vad_mode;
}
}  // namespace

// Stores |apm| and |crit| for later use; the destructor in this file does not
// delete either pointer. The component starts with no voice detected, no
// external VAD decision pending, kLowLikelihood, and a 10 ms frame size.
// frame_size_samples_ stays 0 until Initialize() computes it.
VoiceDetectionImpl::VoiceDetectionImpl(const AudioProcessing* apm,
                                       rtc::CriticalSection* crit)
    : ProcessingComponent(),
      apm_(apm),
      crit_(crit),
      stream_has_voice_(false),
      using_external_vad_(false),
      likelihood_(kLowLikelihood),
      frame_size_ms_(10),
      frame_size_samples_(0) {
  // Both pointers are dereferenced by the other methods, so they must be set.
  RTC_DCHECK(apm);
  RTC_DCHECK(crit);
}

// Empty body: this class releases nothing itself.
// NOTE(review): VAD handles are presumably freed by ProcessingComponent via
// DestroyHandle() - confirm against the base class.
VoiceDetectionImpl::~VoiceDetectionImpl() {}

// Runs the VAD over one capture frame and records the decision.
//
// Returns apm_->kNoError without touching |audio| when the component is
// disabled, or when a decision was supplied through set_stream_has_voice()
// since the previous call (that flag is consumed here). Otherwise the mixed
// low-pass data of |audio| is handed to WebRtcVad_Process() and the outcome
// is stored both in stream_has_voice_ and as the activity of |audio|.
//
// Returns apm_->kBadDataLengthError when the configured frame size asks for
// more samples than |audio| reports per band, and apm_->kUnspecifiedError
// when the VAD returns anything other than 0 or 1.
int VoiceDetectionImpl::ProcessCaptureAudio(AudioBuffer* audio) {
  rtc::CritScope cs(crit_);
  if (!is_component_enabled()) {
    return apm_->kNoError;
  }

  if (using_external_vad_) {
    // An externally supplied decision is only valid for a single frame.
    using_external_vad_ = false;
    return apm_->kNoError;
  }
  assert(audio->num_frames_per_band() <= 160);

  // set_frame_size_ms() only asserts that the size is 10 ms; with asserts
  // compiled out it also accepts 20 and 30 ms, which makes
  // frame_size_samples_ two or three times larger. Refuse to process rather
  // than ask the VAD to read more samples than |audio| reports per band.
  // NOTE(review): assumes mixed_low_pass_data() holds num_frames_per_band()
  // samples - confirm against AudioBuffer.
  if (frame_size_samples_ > audio->num_frames_per_band()) {
    return apm_->kBadDataLengthError;
  }

  // TODO(ajm): concatenate data in frame buffer here.

  int vad_ret = WebRtcVad_Process(static_cast<Handle*>(handle(0)),
                                  apm_->proc_split_sample_rate_hz(),
                                  audio->mixed_low_pass_data(),
                                  frame_size_samples_);
  if (vad_ret == 0) {
    stream_has_voice_ = false;
    audio->set_activity(AudioFrame::kVadPassive);
  } else if (vad_ret == 1) {
    stream_has_voice_ = true;
    audio->set_activity(AudioFrame::kVadActive);
  } else {
    // Any other return value is treated as a VAD failure.
    return apm_->kUnspecifiedError;
  }

  return apm_->kNoError;
}

// Enables or disables the component while holding crit_. The return value is
// whatever EnableComponent() reports.
int VoiceDetectionImpl::Enable(bool enable) {
  rtc::CritScope lock(crit_);
  const int result = EnableComponent(enable);
  return result;
}

// Reports, under crit_, whether the component is currently enabled.
bool VoiceDetectionImpl::is_enabled() const {
  rtc::CritScope lock(crit_);
  const bool enabled = is_component_enabled();
  return enabled;
}

// Records an externally computed voice decision. Setting using_external_vad_
// makes the next ProcessCaptureAudio() call skip the internal VAD and keep
// |has_voice| as the reported state. Always returns apm_->kNoError.
int VoiceDetectionImpl::set_stream_has_voice(bool has_voice) {
  rtc::CritScope lock(crit_);
  stream_has_voice_ = has_voice;
  using_external_vad_ = true;
  return apm_->kNoError;
}

// Returns the most recently stored voice decision, read under crit_. The
// value is returned even if the component is disabled and no external
// decision was supplied (see the disabled assertion below).
bool VoiceDetectionImpl::stream_has_voice() const {
  rtc::CritScope lock(crit_);
  // TODO(ajm): enable this assertion?
  //assert(using_external_vad_ || is_component_enabled());
  const bool has_voice = stream_has_voice_;
  return has_voice;
}

// Stores a new likelihood setting and re-applies the configuration.
//
// Returns apm_->kBadParameterError when MapSetting() cannot translate
// |likelihood| (MapSetting() asserts first when asserts are enabled);
// otherwise returns the result of Configure().
int VoiceDetectionImpl::set_likelihood(VoiceDetection::Likelihood likelihood) {
  rtc::CritScope lock(crit_);
  const int mapped_mode = MapSetting(likelihood);
  if (mapped_mode == -1) {
    return apm_->kBadParameterError;
  }

  likelihood_ = likelihood;
  const int result = Configure();
  return result;
}

// Returns the currently stored likelihood setting, read under crit_.
VoiceDetection::Likelihood VoiceDetectionImpl::likelihood() const {
  rtc::CritScope lock(crit_);
  const VoiceDetection::Likelihood current = likelihood_;
  return current;
}

// Stores a new frame size in milliseconds and re-initializes the component.
//
// Only 10 ms is asserted as supported for now. With asserts compiled out,
// 10, 20 and 30 are accepted and any other value yields
// apm_->kBadParameterError. On acceptance the result of Initialize() is
// returned.
int VoiceDetectionImpl::set_frame_size_ms(int size) {
  rtc::CritScope lock(crit_);
  assert(size == 10); // TODO(ajm): remove when supported.
  const bool size_is_allowed = (size == 10) || (size == 20) || (size == 30);
  if (!size_is_allowed) {
    return apm_->kBadParameterError;
  }

  frame_size_ms_ = size;
  const int result = Initialize();
  return result;
}

// Returns the currently stored frame size in milliseconds, read under crit_.
int VoiceDetectionImpl::frame_size_ms() const {
  rtc::CritScope lock(crit_);
  const int current_size_ms = frame_size_ms_;
  return current_size_ms;
}

// Initializes the base component, then resets the state owned by this class.
//
// The base-class Initialize() runs before crit_ is taken. If it fails, or if
// the component is disabled, its result is returned unchanged and no state is
// touched. Otherwise any pending external VAD decision is dropped and the
// frame size in samples is recomputed from frame_size_ms_ and the split
// sample rate reported by apm_.
int VoiceDetectionImpl::Initialize() {
  const int base_result = ProcessingComponent::Initialize();

  rtc::CritScope lock(crit_);
  if (base_result != apm_->kNoError) {
    return base_result;
  }
  if (!is_component_enabled()) {
    return base_result;
  }

  // Integer arithmetic: milliseconds * samples-per-second / 1000.
  frame_size_samples_ = static_cast<size_t>(
      frame_size_ms_ * apm_->proc_split_sample_rate_hz() / 1000);
  using_external_vad_ = false;
  // TODO(ajm): initialize frame buffer here.

  return apm_->kNoError;
}

// Creates a new VAD instance and returns it as an opaque handle.
void* VoiceDetectionImpl::CreateHandle() const {
  void* vad = WebRtcVad_Create();
  return vad;
}

void VoiceDetectionImpl::DestroyHandle(void* handle) const {
  WebRtcVad_Free(static_cast<Handle*>(handle));
}

int VoiceDetectionImpl::InitializeHandle(void* handle) const {
  return WebRtcVad_Init(static_cast<Handle*>(handle));
}

int VoiceDetectionImpl::ConfigureHandle(void* handle) const {
  rtc::CritScope cs(crit_);
  return WebRtcVad_set_mode(static_cast<Handle*>(handle),
                            MapSetting(likelihood_));
}

// This component reports that it needs exactly one handle;
// ProcessCaptureAudio() only ever uses handle(0).
int VoiceDetectionImpl::num_handles_required() const {
  const int required_handles = 1;
  return required_handles;
}

// Returns the error code to report for a failing handle. No per-handle error
// is queried, so the answer is always apm_->kUnspecifiedError; |handle| is
// only checked for being non-null when asserts are enabled.
int VoiceDetectionImpl::GetHandleError(void* handle) const {
  assert(handle != NULL);
  // The VAD has no get_error() function.
  return apm_->kUnspecifiedError;
}
}  // namespace webrtc
