/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "webrtc/modules/audio_device/include/audio_device_data_observer.h"
#include "webrtc/rtc_base/checks.h"
#include "webrtc/rtc_base/refcountedobject.h"

namespace webrtc {

namespace {

// A wrapper over AudioDeviceModule that registers itself as the AudioTransport
// callback of the wrapped module and mirrors the PCM data passing through it
// to an AudioDeviceDataObserver.
//
// The transport handed to RegisterAudioCallback() is remembered by the wrapper
// and is never forwarded to the wrapped module; the wrapper relays every
// capture/render callback to it. All remaining AudioDeviceModule methods are
// forwarded unchanged to the wrapped module.
class ADMWrapper : public AudioDeviceModule, public AudioTransport {
 public:
  // Creates the underlying module via AudioDeviceModule::Create(id,
  // audio_layer) and registers this wrapper as its audio transport callback.
  // |observer| may be null, in which case no PCM data is mirrored. The raw
  // pointer is stored as-is and is never deleted by this class.
  // Use IsValid() afterwards to find out whether construction succeeded.
  ADMWrapper(const int32_t id,
             const AudioLayer audio_layer,
             AudioDeviceDataObserver* observer)
      : impl_(AudioDeviceModule::Create(id, audio_layer)), observer_(observer) {
    // Creation of the underlying module can fail; it must not be dereferenced
    // in that case. |is_valid_| keeps its default value (false) so that the
    // factory can reject this instance.
    if (impl_.get() == nullptr) {
      return;
    }
    // Register self as the audio transport callback for underlying ADM impl.
    const int32_t res = impl_->RegisterAudioCallback(this);
    is_valid_ = (res == 0);
  }
  virtual ~ADMWrapper() {
    // Clear the raw pointers so that no callback is forwarded any more.
    audio_transport_ = nullptr;
    observer_ = nullptr;
  }

  // Make sure we have a valid ADM before returning it to user.
  // True only when the underlying module was created and accepted this
  // wrapper as its audio callback.
  bool IsValid() const { return is_valid_; }

  // RefCountedModule methods overrides.
  int64_t TimeUntilNextProcess() override {
    return impl_->TimeUntilNextProcess();
  }
  void Process() override { impl_->Process(); }

  // AudioTransport methods overrides.

  // Hands the captured samples to the observer first (when one is set), then
  // forwards the call to the registered audio transport (when one is set).
  // Returns the transport's result, or 0 when no transport is registered.
  int32_t RecordedDataIsAvailable(const void* audioSamples,
                                  const size_t nSamples,
                                  const size_t nBytesPerSample,
                                  const size_t nChannels,
                                  const uint32_t samples_per_sec,
                                  const uint32_t total_delay_ms,
                                  const int32_t clockDrift,
                                  const uint32_t currentMicLevel,
                                  const bool keyPressed,
                                  uint32_t& newMicLevel) override {
    int32_t res = 0;
    // Capture PCM data of locally captured audio.
    if (observer_) {
      observer_->OnCaptureData(audioSamples, nSamples, nBytesPerSample,
                               nChannels, samples_per_sec);
    }

    // Send to the actual audio transport.
    if (audio_transport_) {
      res = audio_transport_->RecordedDataIsAvailable(
          audioSamples, nSamples, nBytesPerSample, nChannels, samples_per_sec,
          total_delay_ms, clockDrift, currentMicLevel, keyPressed, newMicLevel);
    }

    return res;
  }

  // Asks the registered audio transport (when one is set) to fill
  // |audioSamples|, then hands that buffer to the observer (when one is set).
  // Returns the transport's result, or 0 when no transport is registered.
  // Note that the observer is notified with |nSamples|, not |nSamplesOut|.
  int32_t NeedMorePlayData(const size_t nSamples,
                           const size_t nBytesPerSample,
                           const size_t nChannels,
                           const uint32_t samples_per_sec,
                           void* audioSamples,
                           size_t& nSamplesOut,
                           int64_t* elapsed_time_ms,
                           int64_t* ntp_time_ms) override {
    int32_t res = 0;
    // Request data from audio transport.
    if (audio_transport_) {
      res = audio_transport_->NeedMorePlayData(
          nSamples, nBytesPerSample, nChannels, samples_per_sec, audioSamples,
          nSamplesOut, elapsed_time_ms, ntp_time_ms);
    }

    // Capture rendered data.
    if (observer_) {
      observer_->OnRenderData(audioSamples, nSamples, nBytesPerSample,
                              nChannels, samples_per_sec);
    }

    return res;
  }

  // Not expected to be called on this wrapper.
  void PushCaptureData(int voe_channel,
                       const void* audio_data,
                       int bits_per_sample,
                       int sample_rate,
                       size_t number_of_channels,
                       size_t number_of_frames) override {
    RTC_NOTREACHED();
  }

  // Not expected to be called on this wrapper.
  void PullRenderData(int bits_per_sample,
                      int sample_rate,
                      size_t number_of_channels,
                      size_t number_of_frames,
                      void* audio_data,
                      int64_t* elapsed_time_ms,
                      int64_t* ntp_time_ms) override {
    RTC_NOTREACHED();
  }

  // Override AudioDeviceModule's RegisterAudioCallback method to remember the
  // actual audio transport (e.g.: voice engine).
  // The callback is only stored here; the wrapped module keeps this wrapper as
  // its callback. Passing null stops forwarding. Always returns 0.
  int32_t RegisterAudioCallback(AudioTransport* audio_callback) override {
    // Remember the audio callback to forward PCM data
    audio_transport_ = audio_callback;
    return 0;
  }

  // AudioDeviceModule pass through method overrides.
  // Each method below forwards its arguments to the wrapped module and
  // returns the wrapped module's result unchanged.
  int32_t ActiveAudioLayer(AudioLayer* audio_layer) const override {
    return impl_->ActiveAudioLayer(audio_layer);
  }
  ErrorCode LastError() const override { return impl_->LastError(); }
  int32_t RegisterEventObserver(AudioDeviceObserver* event_callback) override {
    return impl_->RegisterEventObserver(event_callback);
  }
  int32_t Init() override { return impl_->Init(); }
  int32_t Terminate() override { return impl_->Terminate(); }
  bool Initialized() const override { return impl_->Initialized(); }
  int16_t PlayoutDevices() override { return impl_->PlayoutDevices(); }
  int16_t RecordingDevices() override { return impl_->RecordingDevices(); }
  int32_t PlayoutDeviceName(uint16_t index,
                            char name[kAdmMaxDeviceNameSize],
                            char guid[kAdmMaxGuidSize]) override {
    return impl_->PlayoutDeviceName(index, name, guid);
  }
  int32_t RecordingDeviceName(uint16_t index,
                              char name[kAdmMaxDeviceNameSize],
                              char guid[kAdmMaxGuidSize]) override {
    return impl_->RecordingDeviceName(index, name, guid);
  }
  int32_t SetPlayoutDevice(uint16_t index) override {
    return impl_->SetPlayoutDevice(index);
  }
  int32_t SetPlayoutDevice(WindowsDeviceType device) override {
    return impl_->SetPlayoutDevice(device);
  }
  int32_t SetRecordingDevice(uint16_t index) override {
    return impl_->SetRecordingDevice(index);
  }
  int32_t SetRecordingDevice(WindowsDeviceType device) override {
    return impl_->SetRecordingDevice(device);
  }
  int32_t PlayoutIsAvailable(bool* available) override {
    return impl_->PlayoutIsAvailable(available);
  }
  int32_t InitPlayout() override { return impl_->InitPlayout(); }
  bool PlayoutIsInitialized() const override {
    return impl_->PlayoutIsInitialized();
  }
  int32_t RecordingIsAvailable(bool* available) override {
    return impl_->RecordingIsAvailable(available);
  }
  int32_t InitRecording() override { return impl_->InitRecording(); }
  bool RecordingIsInitialized() const override {
    return impl_->RecordingIsInitialized();
  }
  int32_t StartPlayout() override { return impl_->StartPlayout(); }
  int32_t StopPlayout() override { return impl_->StopPlayout(); }
  bool Playing() const override { return impl_->Playing(); }
  int32_t StartRecording() override { return impl_->StartRecording(); }
  int32_t StopRecording() override { return impl_->StopRecording(); }
  bool Recording() const override { return impl_->Recording(); }
  int32_t SetAGC(bool enable) override { return impl_->SetAGC(enable); }
  bool AGC() const override { return impl_->AGC(); }
  int32_t InitSpeaker() override { return impl_->InitSpeaker(); }
  bool SpeakerIsInitialized() const override {
    return impl_->SpeakerIsInitialized();
  }
  int32_t InitMicrophone() override { return impl_->InitMicrophone(); }
  bool MicrophoneIsInitialized() const override {
    return impl_->MicrophoneIsInitialized();
  }
  int32_t SpeakerVolumeIsAvailable(bool* available) override {
    return impl_->SpeakerVolumeIsAvailable(available);
  }
  int32_t SetSpeakerVolume(uint32_t volume) override {
    return impl_->SetSpeakerVolume(volume);
  }
  int32_t SpeakerVolume(uint32_t* volume) const override {
    return impl_->SpeakerVolume(volume);
  }
  int32_t MaxSpeakerVolume(uint32_t* max_volume) const override {
    return impl_->MaxSpeakerVolume(max_volume);
  }
  int32_t MinSpeakerVolume(uint32_t* min_volume) const override {
    return impl_->MinSpeakerVolume(min_volume);
  }
  int32_t SpeakerVolumeStepSize(uint16_t* step_size) const override {
    return impl_->SpeakerVolumeStepSize(step_size);
  }
  int32_t MicrophoneVolumeIsAvailable(bool* available) override {
    return impl_->MicrophoneVolumeIsAvailable(available);
  }
  int32_t SetMicrophoneVolume(uint32_t volume) override {
    return impl_->SetMicrophoneVolume(volume);
  }
  int32_t MicrophoneVolume(uint32_t* volume) const override {
    return impl_->MicrophoneVolume(volume);
  }
  int32_t MaxMicrophoneVolume(uint32_t* max_volume) const override {
    return impl_->MaxMicrophoneVolume(max_volume);
  }
  int32_t MinMicrophoneVolume(uint32_t* min_volume) const override {
    return impl_->MinMicrophoneVolume(min_volume);
  }
  int32_t MicrophoneVolumeStepSize(uint16_t* step_size) const override {
    return impl_->MicrophoneVolumeStepSize(step_size);
  }
  int32_t SpeakerMuteIsAvailable(bool* available) override {
    return impl_->SpeakerMuteIsAvailable(available);
  }
  int32_t SetSpeakerMute(bool enable) override {
    return impl_->SetSpeakerMute(enable);
  }
  int32_t SpeakerMute(bool* enabled) const override {
    return impl_->SpeakerMute(enabled);
  }
  int32_t MicrophoneMuteIsAvailable(bool* available) override {
    return impl_->MicrophoneMuteIsAvailable(available);
  }
  int32_t SetMicrophoneMute(bool enable) override {
    return impl_->SetMicrophoneMute(enable);
  }
  int32_t MicrophoneMute(bool* enabled) const override {
    return impl_->MicrophoneMute(enabled);
  }
  int32_t MicrophoneBoostIsAvailable(bool* available) override {
    return impl_->MicrophoneBoostIsAvailable(available);
  }
  int32_t SetMicrophoneBoost(bool enable) override {
    return impl_->SetMicrophoneBoost(enable);
  }
  int32_t MicrophoneBoost(bool* enabled) const override {
    return impl_->MicrophoneBoost(enabled);
  }
  int32_t StereoPlayoutIsAvailable(bool* available) const override {
    return impl_->StereoPlayoutIsAvailable(available);
  }
  int32_t SetStereoPlayout(bool enable) override {
    return impl_->SetStereoPlayout(enable);
  }
  int32_t StereoPlayout(bool* enabled) const override {
    return impl_->StereoPlayout(enabled);
  }
  int32_t StereoRecordingIsAvailable(bool* available) const override {
    return impl_->StereoRecordingIsAvailable(available);
  }
  int32_t SetStereoRecording(bool enable) override {
    return impl_->SetStereoRecording(enable);
  }
  int32_t StereoRecording(bool* enabled) const override {
    return impl_->StereoRecording(enabled);
  }
  int32_t SetRecordingChannel(const ChannelType channel) override {
    return impl_->SetRecordingChannel(channel);
  }
  int32_t RecordingChannel(ChannelType* channel) const override {
    return impl_->RecordingChannel(channel);
  }
  int32_t SetPlayoutBuffer(const BufferType type, uint16_t size_ms) override {
    return impl_->SetPlayoutBuffer(type, size_ms);
  }
  int32_t PlayoutBuffer(BufferType* type, uint16_t* size_ms) const override {
    return impl_->PlayoutBuffer(type, size_ms);
  }
  int32_t PlayoutDelay(uint16_t* delay_ms) const override {
    return impl_->PlayoutDelay(delay_ms);
  }
  int32_t RecordingDelay(uint16_t* delay_ms) const override {
    return impl_->RecordingDelay(delay_ms);
  }
  int32_t CPULoad(uint16_t* load) const override {
    return impl_->CPULoad(load);
  }
  int32_t StartRawOutputFileRecording(
      const char pcm_file_name_utf8[kAdmMaxFileNameSize]) override {
    return impl_->StartRawOutputFileRecording(pcm_file_name_utf8);
  }
  int32_t StopRawOutputFileRecording() override {
    return impl_->StopRawOutputFileRecording();
  }
  int32_t StartRawInputFileRecording(
      const char pcm_file_name_utf8[kAdmMaxFileNameSize]) override {
    return impl_->StartRawInputFileRecording(pcm_file_name_utf8);
  }
  int32_t StopRawInputFileRecording() override {
    return impl_->StopRawInputFileRecording();
  }
  int32_t SetRecordingSampleRate(const uint32_t samples_per_sec) override {
    return impl_->SetRecordingSampleRate(samples_per_sec);
  }
  int32_t RecordingSampleRate(uint32_t* samples_per_sec) const override {
    return impl_->RecordingSampleRate(samples_per_sec);
  }
  int32_t SetPlayoutSampleRate(const uint32_t samples_per_sec) override {
    return impl_->SetPlayoutSampleRate(samples_per_sec);
  }
  int32_t PlayoutSampleRate(uint32_t* samples_per_sec) const override {
    return impl_->PlayoutSampleRate(samples_per_sec);
  }
  int32_t ResetAudioDevice() override { return impl_->ResetAudioDevice(); }
  int32_t SetLoudspeakerStatus(bool enable) override {
    return impl_->SetLoudspeakerStatus(enable);
  }
  int32_t GetLoudspeakerStatus(bool* enabled) const override {
    return impl_->GetLoudspeakerStatus(enabled);
  }
  bool BuiltInAECIsAvailable() const override {
    return impl_->BuiltInAECIsAvailable();
  }
  bool BuiltInAGCIsAvailable() const override {
    return impl_->BuiltInAGCIsAvailable();
  }
  bool BuiltInNSIsAvailable() const override {
    return impl_->BuiltInNSIsAvailable();
  }
  int32_t EnableBuiltInAEC(bool enable) override {
    return impl_->EnableBuiltInAEC(enable);
  }
  int32_t EnableBuiltInAGC(bool enable) override {
    return impl_->EnableBuiltInAGC(enable);
  }
  int32_t EnableBuiltInNS(bool enable) override {
    return impl_->EnableBuiltInNS(enable);
  }
// Only supported on iOS.
#if defined(WEBRTC_IOS)
  int GetPlayoutAudioParameters(AudioParameters* params) const override {
    return impl_->GetPlayoutAudioParameters(params);
  }
  int GetRecordAudioParameters(AudioParameters* params) const override {
    return impl_->GetRecordAudioParameters(params);
  }
#endif  // WEBRTC_IOS

 protected:
  // The wrapped module that every pass-through method forwards to. Null when
  // AudioDeviceModule::Create() failed, in which case IsValid() is false.
  rtc::scoped_refptr<AudioDeviceModule> impl_;
  // Receiver of the mirrored capture/render PCM data; may be null.
  AudioDeviceDataObserver* observer_ = nullptr;
  // The transport registered through RegisterAudioCallback(); may be null.
  AudioTransport* audio_transport_ = nullptr;
  // Set by the constructor once the wrapped module exists and has accepted
  // this wrapper as its audio callback.
  bool is_valid_ = false;
};

}  // namespace

// Builds a ref-counted ADMWrapper from |id|, |audio_layer| and |observer| and
// returns it as an AudioDeviceModule. Returns nullptr when the freshly built
// wrapper reports itself as not valid.
rtc::scoped_refptr<AudioDeviceModule> CreateAudioDeviceWithDataObserver(
    const int32_t id,
    const AudioDeviceModule::AudioLayer audio_layer,
    AudioDeviceDataObserver* observer) {
  rtc::scoped_refptr<ADMWrapper> wrapper(
      new rtc::RefCountedObject<ADMWrapper>(id, audio_layer, observer));

  // Only hand out wrappers that were set up successfully.
  if (wrapper->IsValid()) {
    return wrapper;
  }

  return nullptr;
}

}  // namespace webrtc
