/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/winaudiocapture.h"

#include "twitchsdk/broadcast/audioconstants.h"
#include "twitchsdk/broadcast/audioframe.h"
#include "twitchsdk/broadcast/iaudiomixer.h"
#include "twitchsdk/broadcast/internal/audioconvert/audioconverter.h"
#include "twitchsdk/broadcast/ipcmaudioframereceiver.h"
#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/comobjectdeleter.h"
#include "twitchsdk/core/raiiwrapper.h"
#include "twitchsdk/core/systemclock.h"

#include <audioclient.h>
#include <mmdeviceapi.h>

#include <bitset>

// Jumps to the enclosing function's `Exit:` label when `hr` is a failing HRESULT.
// Wrapped in do/while(0) so `EXIT_ON_ERROR(hr);` is exactly one statement and stays correct inside an
// unbraced if/else.
#define EXIT_ON_ERROR(hr) \
  do {                    \
    if (FAILED(hr)) {     \
      goto Exit;          \
    }                     \
  } while (0)

// Calls Release() on a non-null interface pointer and then sets the pointer to nullptr.
// Wrapped in do/while(0) for the same single-statement reason as EXIT_ON_ERROR.
#define SAFE_RELEASE(pUnk)     \
  do {                         \
    if ((pUnk) != nullptr) {   \
      (void)(pUnk)->Release(); \
      (pUnk) = nullptr;        \
    }                          \
  } while (0)

namespace {
const bool kPcmInterleaved = true;
const uint32_t kNumAudioChannels = 2;

template <size_t inputChannelCount, size_t inputSampleCount>
std::unique_ptr<ttv::IAudioConverter> BuildAudioConverter() {
  using InputBufferFormat = ttv::BufferFormat<int16_t, inputSampleCount, inputChannelCount>;
  using OutputBufferFormat = ttv::BufferFormat<int16_t, ttv::broadcast::kAudioEncodeRate, kNumAudioChannels>;

  return std::make_unique<ttv::AudioConverter<InputBufferFormat, OutputBufferFormat>>();
}

template <size_t inputChannelCount>
std::unique_ptr<ttv::IAudioConverter> BuildAudioConverter(size_t inputSampleCount) {
  // 8000, 11025, 16000, 22050, 32000, 44100, 48000, 96000 and 192000 are our supported input sample rates.
  switch (inputSampleCount) {
    case 8000:
      return BuildAudioConverter<inputChannelCount, 8000>();
    case 11025:
      return BuildAudioConverter<inputChannelCount, 11025>();
    case 16000:
      return BuildAudioConverter<inputChannelCount, 16000>();
    case 22050:
      return BuildAudioConverter<inputChannelCount, 22050>();
    case 32000:
      return BuildAudioConverter<inputChannelCount, 32000>();
    case 44100:
      return BuildAudioConverter<inputChannelCount, 44100>();
    case 48000:
      return BuildAudioConverter<inputChannelCount, 48000>();
    case 96000:
      return BuildAudioConverter<inputChannelCount, 96000>();
    case 192000:
      return BuildAudioConverter<inputChannelCount, 192000>();
    default:
      return nullptr;
  }
}

std::unique_ptr<ttv::IAudioConverter> BuildAudioConverter(size_t inputChannelCount, size_t inputSampleCount) {
  switch (inputChannelCount) {
    case 1:
      return BuildAudioConverter<1>(inputSampleCount);
    case kNumAudioChannels:
      return BuildAudioConverter<kNumAudioChannels>(inputSampleCount);
    default:
      return nullptr;
  }
}
}  // namespace

/**
 * State behind WinAudioCapture: the WASAPI client objects, the optional format converter and the
 * position bookkeeping used while capturing.
 */
class ttv::broadcast::WinAudioCaptureInternalData {
 public:
  explicit WinAudioCaptureInternalData(WinAudioCapture::CaptureType type)
      : mType(type),
        mInitialDevicePosition(0),
        mOutputSamplePosition(0),
        mNumInputSamplesPerSecond(0),
        mComInitialized(false) {}

  /**
   * Opens the default capture endpoint (Microphone) or the default render endpoint in loopback mode
   * (System). It requests 16-bit PCM, creates a converter if the device format differs from the broadcast
   * format, and starts the stream. On success the new clients are stored in mAudioClient/mCaptureClient;
   * on failure those two members are not modified.
   *
   * NOTE(review): CoInitialize()/CoUninitialize() act on the calling thread; this assumes Create() and
   * Destroy() always run on the same thread — confirm.
   *
   * @return TTV_EC_SUCCESS; TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED if a COM/WASAPI call fails; or
   *         TTV_EC_BROADCAST_UNSUPPORTED_INPUT_FORMAT if no converter exists for the device format.
   */
  TTV_ErrorCode Create() {
    const REFERENCE_TIME kRefTimesPerSec = 10000000;

    HRESULT hr = S_OK;

    // COM initialization is reference counted per thread. Take at most one reference and remember it, so
    // Destroy() only releases a reference this object actually holds. Releasing one it does not hold (e.g.
    // after CoInitialize() failed with RPC_E_CHANGED_MODE) would undo the calling thread's own COM setup.
    if (!mComInitialized) {
      hr = CoInitialize(nullptr);
      if (FAILED(hr)) {
        ttv::trace::Message(
          "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - Call to CoInitialize failed");
        return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
      }
      mComInitialized = true;
    }

    // Create the device enumerator
    IMMDeviceEnumerator* tmpEnum = nullptr;
    hr = CoCreateInstance(
      CLSID_MMDeviceEnumerator, nullptr, CLSCTX_ALL, IID_IMMDeviceEnumerator, reinterpret_cast<void**>(&tmpEnum));
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - CoCreateInstance failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }
    std::unique_ptr<IMMDeviceEnumerator, COMObjectDeleter<IMMDeviceEnumerator>> enumerator(tmpEnum);

    // Get the default capture device (microphone) or render device (system audio, captured via loopback)
    EDataFlow dataFlow = mType == WinAudioCapture::CaptureType::Microphone ? eCapture : eRender;
    IMMDevice* tmpDevice = nullptr;
    hr = enumerator->GetDefaultAudioEndpoint(dataFlow, eConsole, &tmpDevice);
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Warning, "Inside WinAudioCapture::Init - Audio device not found");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }
    std::unique_ptr<IMMDevice, COMObjectDeleter<IMMDevice>> device(tmpDevice);

    // Activate the device
    IAudioClient* audioClientPtr = nullptr;
    hr = device->Activate(IID_IAudioClient, CLSCTX_ALL, nullptr, reinterpret_cast<void**>(&audioClientPtr));
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - pDevice->Activate failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }
    std::unique_ptr<IAudioClient, COMObjectDeleter<IAudioClient>> audioClient(audioClientPtr);

    // Get the device mix format. It is freed with CoTaskMemFree when mixFormatRaii goes out of scope.
    tWAVEFORMATEX* pWfx = nullptr;

    const std::function<HRESULT()> RaiiCtor = [&audioClient, &pWfx]() -> HRESULT {
      return audioClient->GetMixFormat(&pWfx);
    };
    // std::ref makes the cleanup read pWfx when it runs, i.e. after GetMixFormat() has filled it in.
    // Binding pWfx by value would copy its current nullptr and leak the mix format.
    RaiiWrapper<HRESULT> mixFormatRaii(RaiiCtor, std::bind(&CoTaskMemFree, std::ref(pWfx)));
    if (FAILED(mixFormatRaii.mCtorRet)) {
      ttv::trace::Message("WinAudioCapture", MessageLevel::Error,
        "Inside WinAudioCapture::Init - audioClient.get()->GetMixFormat failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }
    TTV_ASSERT(pWfx);
    TTV_ASSERT(pWfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE);

    // Make sure we get PCM 16-bit samples. SubFormat/Samples exist only in the WAVEFORMATEXTENSIBLE layout,
    // so the header is reinterpreted only when its tag says so (TTV_ASSERT presumably does not stop release
    // builds).
    // NOTE(review): a non-extensible float mix format would be passed to Initialize() unconverted.
    if (pWfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE) {
      PWAVEFORMATEXTENSIBLE pEx = reinterpret_cast<PWAVEFORMATEXTENSIBLE>(pWfx);
      if (IsEqualGUID(KSDATAFORMAT_SUBTYPE_IEEE_FLOAT, pEx->SubFormat)) {
        pEx->SubFormat = KSDATAFORMAT_SUBTYPE_PCM;
        pEx->Samples.wValidBitsPerSample = 16;
        pWfx->wBitsPerSample = 16;
        pWfx->nBlockAlign = pWfx->nChannels * pWfx->wBitsPerSample / 8;
        pWfx->nAvgBytesPerSec = pWfx->nBlockAlign * pWfx->nSamplesPerSec;
      }
    }

    // Initialize the audio client based on whether it's mic or speakers
    DWORD streamFlags = 0;
    switch (mType) {
      case WinAudioCapture::CaptureType::Microphone:
        streamFlags = 0;
        break;
      case WinAudioCapture::CaptureType::System:
        streamFlags = AUDCLNT_STREAMFLAGS_LOOPBACK;
        break;
      default:
        ttv::trace::Message("WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - mType was invalid");
        return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }

    hr = audioClient->Initialize(AUDCLNT_SHAREMODE_SHARED, streamFlags, kRefTimesPerSec, 0, pWfx, nullptr);
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - audioClient->Initialize failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }

    // Get the size of the allocated buffer.
    UINT32 bufferFrameCount = 0;
    hr = audioClient->GetBufferSize(&bufferFrameCount);
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - audioClient->GetBufferSize failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }

    // Get the capture client interface
    IAudioCaptureClient* captureClientPtr = nullptr;
    hr = audioClient->GetService(IID_IAudioCaptureClient, reinterpret_cast<void**>(&captureClientPtr));
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - audioClient->GetService failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }
    std::unique_ptr<IAudioCaptureClient, COMObjectDeleter<IAudioCaptureClient>> captureClient(captureClientPtr);

    mNumInputSamplesPerSecond = pWfx->nSamplesPerSec;

    // If we need to do any re-channeling or rate re-sampling, then create an audio converter.
    // nChannels is the number of interleaved channels in each captured frame, which is the layout the
    // converter must read. dwChannelMask is not used: it may be 0 when channels have no speaker assignment.
    const size_t inputChannelCount = pWfx->nChannels;
    if (inputChannelCount != kNumAudioChannels || pWfx->nSamplesPerSec != kAudioEncodeRate) {
      // Preallocate enough space for the number of samples required for 1 second of the output format
      mResampleBuffer.resize(kNumAudioChannels * kAudioEncodeRate, 0);

      mAudioConverter = BuildAudioConverter(inputChannelCount, pWfx->nSamplesPerSec);
      if (mAudioConverter == nullptr) {
        ttv::trace::Message("WinAudioCapture", MessageLevel::Error,
          "Inside WinAudioCapture::Init - unsupported audio format from capture device - ChannelCount: %zu, SampleRate: %lu",
          inputChannelCount, pWfx->nSamplesPerSec);
        return TTV_EC_BROADCAST_UNSUPPORTED_INPUT_FORMAT;
      }
    }

    // Start capturing
    hr = audioClient->Start();
    if (FAILED(hr)) {
      ttv::trace::Message(
        "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Init - audioClient->Start failed");
      return TTV_EC_BROADCAST_AUDIO_DEVICE_INIT_FAILED;
    }

    // Now that everything succeeded, assign the audioclient and captureclient to the member pointers
    mAudioClient = std::move(audioClient);
    mCaptureClient = std::move(captureClient);

    return TTV_EC_SUCCESS;
  }

  /**
   * Stops and releases the audio client, drops the converter, and releases this object's COM
   * initialization if it holds one.
   */
  void Destroy() {
    if (mAudioClient != nullptr) {
      (void)mAudioClient->Stop();
      mAudioClient.reset();
    }

    // NOTE(review): mCaptureClient is not released here. Process() calls into it again after a failed
    // re-Create(), so clearing it would need a null check there. As a result it is only released when this
    // object is destroyed, which can be after CoUninitialize() below.

    mAudioConverter.reset();

    if (mComInitialized) {
      CoUninitialize();
      mComInitialized = false;
    }
  }

 public:
  // Selects the endpoint Create() opens: default capture device, or default render device in loopback mode.
  WinAudioCapture::CaptureType mType;

  std::shared_ptr<IPcmAudioFrameReceiver> mReceiver;  // The receiver to use to send the audio data to the mixer.
  std::unique_ptr<IAudioClient, COMObjectDeleter<IAudioClient>>
    mAudioClient;  // The Windows audio client which supplies PCM audio data.
  // Reads captured packets from mAudioClient's shared buffer.
  std::unique_ptr<IAudioCaptureClient, COMObjectDeleter<IAudioCaptureClient>> mCaptureClient;

  // Set by Create() only when the device's channel count or rate differs from the broadcast format.
  std::unique_ptr<IAudioConverter> mAudioConverter;
  // Device position of the first captured packet; 0 also means "not recorded yet" in Process().
  uint64_t mInitialDevicePosition;
  // Output sample position reached so far; reported through Process()'s lastSampleTime.
  size_t mOutputSamplePosition;
  // One second of broadcast-format samples, used as the converter's output buffer.
  std::vector<int16_t> mResampleBuffer;
  uint32_t mNumInputSamplesPerSecond;  // The Hz of the input.  This value is per-channel
  // True while this object holds a CoInitialize() reference that Destroy() must release.
  bool mComInitialized;
};

/**
 * Constructs a capture source of the given type. No Windows audio objects are opened here; that happens
 * in Start().
 */
ttv::broadcast::WinAudioCapture::WinAudioCapture(CaptureType type) {
  ttv::trace::Message("WinAudioCapture", MessageLevel::Info, "WinAudioCapture created");

  auto internalData = std::make_unique<WinAudioCaptureInternalData>(type);
  mInternalData = std::move(internalData);
}

/**
 * Only logs. Tearing down the audio client and COM happens in Stop() (via WinAudioCaptureInternalData::Destroy()).
 * NOTE(review): if the object is destroyed while still started, Destroy() is never called — confirm that
 * callers always Stop() first.
 */
ttv::broadcast::WinAudioCapture::~WinAudioCapture() {
  ttv::trace::Message("WinAudioCapture", MessageLevel::Info, "WinAudioCapture destroyed");
}

// Definitions of the Core Audio class/interface IDs used by WinAudioCaptureInternalData::Create()
// (CoCreateInstance, IMMDevice::Activate, IAudioClient::GetService), taken from the types via __uuidof.
// NOTE(review): presumably defined here so the build need not link the SDK's GUID library; this relies on
// the Windows SDK headers having already declared these names with C linkage, which gives these const
// definitions external linkage — confirm.
extern "C" {
const CLSID CLSID_MMDeviceEnumerator = __uuidof(MMDeviceEnumerator);
const IID IID_IMMDeviceEnumerator = __uuidof(IMMDeviceEnumerator);
const IID IID_IAudioClient = __uuidof(IAudioClient);
const IID IID_IAudioCaptureClient = __uuidof(IAudioCaptureClient);
}

/** Human-readable name, distinguishing system (loopback) capture from microphone capture. */
std::string ttv::broadcast::WinAudioCapture::GetName() const {
  const bool capturesSystemAudio = mInternalData->mType == CaptureType::System;
  return capturesSystemAudio ? "WinAudioCapture - System" : "WinAudioCapture - Microphone";
}

/** Channel count of the frames submitted to the mixer; the device's own layout is converted to this. */
uint32_t ttv::broadcast::WinAudioCapture::GetNumChannels() const {
  const uint32_t submittedChannels = kNumAudioChannels;
  return submittedChannels;
}

/**
 * Starts the base capture, looks up the mixer's PCM frame receiver, then opens and starts the device.
 *
 * @return The first failure among: base Start(), a missing PCM receiver (TTV_EC_BROADCAST_INVALID_ENCODER),
 *         or WinAudioCaptureInternalData::Create(); otherwise TTV_EC_SUCCESS.
 * NOTE(review): a failure after AudioCaptureBase::Start() succeeded is not rolled back here — confirm
 * callers Stop() after a failed Start().
 */
TTV_ErrorCode ttv::broadcast::WinAudioCapture::Start() {
  const TTV_ErrorCode baseResult = AudioCaptureBase::Start();
  if (!TTV_SUCCEEDED(baseResult)) {
    return baseResult;
  }

  std::shared_ptr<IAudioFrameReceiver> mixerReceiver =
    mAudioMixer->GetReceiverImplementation(IPcmAudioFrameReceiver::GetReceiverTypeId());
  if (mixerReceiver == nullptr) {
    return TTV_EC_BROADCAST_INVALID_ENCODER;
  }
  mInternalData->mReceiver = std::static_pointer_cast<IPcmAudioFrameReceiver>(mixerReceiver);

  // Create audio components and start
  return mInternalData->Create();
}

/**
 * Stops the base capture and, only if that succeeded, tears down the WASAPI objects.
 *
 * @return The result of AudioCaptureBase::Stop().
 */
TTV_ErrorCode ttv::broadcast::WinAudioCapture::Stop() {
  const TTV_ErrorCode baseResult = AudioCaptureBase::Stop();
  if (!TTV_SUCCEEDED(baseResult)) {
    return baseResult;
  }

  mInternalData->Destroy();
  return baseResult;
}

// TODO: This method is too long and complex. Needs to be broken up into
// smaller parts
/**
 * Drains every packet currently queued on the WASAPI capture client and submits it to the mixer. Packets
 * go through the converter first when one is in use. If the device was invalidated, the current default
 * device is reopened.
 *
 * @param mixer Only asserted to be non-null; frames are submitted to mAudioMixer via SubmitSamples().
 * @param lastSampleTime Set to the output sample position reached whenever the common exit is reached.
 * @return TTV_EC_INVALID_STATE if not started or no capture client exists;
 *         TTV_EC_BROADCAST_INVALID_SAMPLERATE if no input rate is known; TTV_EC_BROADCAST_NOMOREDATA if no
 *         packet was pending (or the capture clock has not started); an error from re-creating the device
 *         or submitting samples; otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::WinAudioCapture::Process(
  const std::shared_ptr<IAudioMixer>& mixer, uint64_t& lastSampleTime) {
  if (!mStarted) {
    return TTV_EC_INVALID_STATE;
  }

  TTV_ASSERT(mixer != nullptr);
  TTV_ASSERT(mInternalData->mNumInputSamplesPerSecond > 0);

  if (mInternalData->mNumInputSamplesPerSecond == 0) {
    return TTV_EC_BROADCAST_INVALID_SAMPLERATE;
  }

  // Create() can fail after it has recorded the input sample rate (e.g. unsupported format or Start()
  // failure). In that case no capture client was ever stored.
  if (mInternalData->mCaptureClient == nullptr) {
    return TTV_EC_INVALID_STATE;
  }

  TTV_ErrorCode ret = TTV_EC_BROADCAST_NOMOREDATA;

  uint64_t initialQpc =
    mInitialSysTime;  // Store this locally since it can be set from another thread in SetInitialTime()

  uint64_t curSysQpc = GetSystemClockTime();
  if (initialQpc == 0 || curSysQpc < initialQpc) {
    // Initial QPC hasn't been set yet or the current time is before it; don't do anything
    ttv::trace::Message("WinAudioCapture", MessageLevel::Warning,
      "WinAudioCapture::Process called before init...returning with output sample position = %zu",
      mInternalData->mOutputSamplePosition);
  } else {
    // The length of a data packet is expressed as the number of audio frames in the packet.
    // The size in bytes of an audio frame equals the number of channels in the stream multiplied by the sample size per
    // channel.
    UINT32 packetLength = 0;
    HRESULT hr = mInternalData->mCaptureClient->GetNextPacketSize(&packetLength);
    if (AUDCLNT_E_DEVICE_INVALIDATED == hr) {
      // Get the new default audio device
      mInternalData->Destroy();
      ret = mInternalData->Create();

      if (TTV_FAILED(ret)) {
        return ret;
      }

      // Nothing has been captured from the new device yet; report "no data" unless packets follow.
      ret = TTV_EC_BROADCAST_NOMOREDATA;
      hr = mInternalData->mCaptureClient->GetNextPacketSize(&packetLength);
    }
    if (FAILED(hr)) {
      ttv::trace::Message("WinAudioCapture", MessageLevel::Error,
        "Inside WinAudioCapture::Process - captureClient->GetNextPacketSize() failed");
    }
    EXIT_ON_ERROR(hr);

    while (packetLength != 0) {
      ret = TTV_EC_SUCCESS;

      // Get the available data in the shared buffer.
      uint8_t* captureData = nullptr;
      uint32_t numSamplesAvailable = 0;
      DWORD flags = 0;

      UINT64 devicePosition = 0;
      hr =
        mInternalData->mCaptureClient->GetBuffer(&captureData, &numSamplesAvailable, &flags, &devicePosition, nullptr);
      if (FAILED(hr)) {
        ttv::trace::Message(
          "WinAudioCapture", MessageLevel::Error, "Inside WinAudioCapture::Process - captureClient->GetBuffer failed");
      }
      EXIT_ON_ERROR(hr);

      // Re-base device positions so the first captured packet is at position 0.
      // NOTE(review): 0 doubles as "no initial position recorded yet". If the first packet really is at device
      // position 0, the following packet is re-based to 0 as well: two packets share a timestamp and later
      // ones are shifted back by one packet. A separate "initial position recorded" flag would fix this.
      if (mInternalData->mInitialDevicePosition == 0) {
        mInternalData->mInitialDevicePosition = devicePosition;
        devicePosition = 0;
      } else {
        devicePosition -= mInternalData->mInitialDevicePosition;
      }

      TTV_ASSERT(numSamplesAvailable > 0);
      TTV_ASSERT(captureData != nullptr);

      // NOTE(review): AUDCLNT_BUFFERFLAGS_SILENT in `flags` is not honoured; such packets are meant to be
      // treated as silence regardless of their contents.
      if (captureData != nullptr) {
        if (mInternalData->mAudioConverter != nullptr) {
          mInternalData->mAudioConverter->BindInputBuffer(
            captureData, {static_cast<size_t>(devicePosition), numSamplesAvailable});
          SampleRange outputRange = mInternalData->mAudioConverter->GetOutputSampleRange();

          mInternalData->mOutputSamplePosition = outputRange.startIndex;

          do {
            void* outputBuffer = mInternalData->mResampleBuffer.data();
            size_t maxOutputSampleCount = mInternalData->mResampleBuffer.size() / kNumAudioChannels;

            SampleRange writtenRange = mInternalData->mAudioConverter->TransferToOutputBuffer(
              outputBuffer, {mInternalData->mOutputSamplePosition, maxOutputSampleCount});

            ret = SubmitSamples(outputBuffer, writtenRange.sampleCount, mInternalData->mOutputSamplePosition);
            mInternalData->mOutputSamplePosition = writtenRange.startIndex + writtenRange.sampleCount;

            // Stop on a submit failure, or when the converter produced nothing (otherwise the position would
            // never advance and this loop would spin forever).
            if (ret != TTV_EC_SUCCESS || writtenRange.sampleCount == 0) {
              break;
            }
          } while (mInternalData->mOutputSamplePosition < outputRange.startIndex + outputRange.sampleCount);

          // Unbind on every path so the converter never keeps a pointer into a released WASAPI buffer.
          mInternalData->mAudioConverter->UnbindInputBuffer();
        } else {
          mInternalData->mOutputSamplePosition = static_cast<size_t>(devicePosition);

          ret = SubmitSamples(captureData, numSamplesAvailable, mInternalData->mOutputSamplePosition);
          if (ret == TTV_EC_SUCCESS) {
            mInternalData->mOutputSamplePosition += numSamplesAvailable;
          }
        }
      }

      // Always hand the packet back, even when submitting it failed. Until ReleaseBuffer() is called, every
      // later GetBuffer() call fails (AUDCLNT_E_OUT_OF_ORDER) and capture would stall permanently.
      hr = mInternalData->mCaptureClient->ReleaseBuffer(numSamplesAvailable);
      if (FAILED(hr)) {
        ttv::trace::Message("WinAudioCapture", MessageLevel::Error,
          "Inside WinAudioCapture::Process - captureClient->ReleaseBuffer failed");
      }

      EXIT_ON_ERROR((ret == TTV_EC_SUCCESS ? S_OK : E_FAIL));
      EXIT_ON_ERROR(hr);

      hr = mInternalData->mCaptureClient->GetNextPacketSize(&packetLength);
      if (FAILED(hr)) {
        ttv::trace::Message("WinAudioCapture", MessageLevel::Error,
          "Inside WinAudioCapture::Process - captureClient->GetNextPacketSize failed on second call");
      }
      EXIT_ON_ERROR(hr);
    }
  }

Exit:

  lastSampleTime = mInternalData->mOutputSamplePosition;
  return ret;
}

/**
 * Packages `sampleCount` interleaved 16-bit PCM frames of kNumAudioChannels channels, starting at
 * `samples`, into an AudioFrame stamped with `timestamp`, and submits it to the mixer on mAudioLayer.
 *
 * @return The packaging error if PackageFrame() fails, otherwise the mixer's SubmitFrame() result.
 */
TTV_ErrorCode ttv::broadcast::WinAudioCapture::SubmitSamples(void* samples, size_t sampleCount, uint64_t timestamp) {
  uint8_t* const sampleBytes = static_cast<uint8_t*>(samples);
  const uint32_t frameCount = static_cast<uint32_t>(sampleCount);

  std::shared_ptr<AudioFrame> packagedFrame;
  const TTV_ErrorCode packageResult = mInternalData->mReceiver->PackageFrame(sampleBytes, frameCount,
    kNumAudioChannels, kPcmInterleaved, AudioSampleFormat::TTV_ASF_PCM_S16, timestamp, packagedFrame);

  if (!TTV_FAILED(packageResult)) {
    return mAudioMixer->SubmitFrame(mAudioLayer, packagedFrame);
  }
  return packageResult;
}
