/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/passthroughaudioencoder.h"

#include "twitchsdk/broadcast/audioframe.h"
#include "twitchsdk/broadcast/iframewriter.h"
#include "twitchsdk/broadcast/ipreencodedaudioframereceiver.h"
#include "twitchsdk/broadcast/packet.h"
#include "twitchsdk/core/systemclock.h"

namespace {
using namespace ttv::broadcast;

/**
 * Trace category for this encoder's log messages; also the value returned by GetName().
 * constexpr so the pointer itself cannot be reassigned at runtime.
 */
constexpr const char* kLoggerName = "PassThroughAudioEncoder";

/**
 * The audio frame required to pack pre-encoded data.
 *
 * Owns the already-encoded payload handed to PreEncodedReceiver::PackageFrame(). The frame is tagged with the
 * IPreEncodedAudioFrameReceiver receiver type id so that the encoder can recognize it on submission.
 *
 * NOTE(review): the sample rate is always reported as 44.1 kHz, whatever the encoded payload really uses;
 * PackageFrame() has no sample-rate input that could be propagated here. Confirm this holds for every producer.
 */
class PreEncodedAudioFrame final : public AudioFrame {
 public:
  /**
   * @param frameData   Encoded payload; moved into the frame.
   * @param audioFormat Format of the encoded payload.
   * @param numChannels Channel count reported by the frame.
   * @param timeStamp   Timestamp reported by the frame.
   */
  PreEncodedAudioFrame(IPreEncodedAudioFrameReceiver::AudioPacket&& frameData, AudioFormat audioFormat,
    uint32_t numChannels, uint64_t timeStamp)
      : AudioFrame(IPreEncodedAudioFrameReceiver::GetReceiverTypeId()), mFrameData(std::move(frameData)) {
    SetNumChannels(numChannels);
    SetSampleRateHz(static_cast<uint32_t>(AudioSampleRate::Hz44100));
    SetTimeStamp(timeStamp);
    SetAudioFormat(audioFormat);
  }

  /** Read-only access to the encoded payload. */
  const IPreEncodedAudioFrameReceiver::AudioPacket& GetFrameData() const { return mFrameData; }

 private:
  IPreEncodedAudioFrameReceiver::AudioPacket mFrameData;  //!< Encoded payload owned by this frame.
};

/**
 * Receiver implementation exposed by the encoder for the pre-encoded protocol: it wraps each incoming
 * encoded packet into a PreEncodedAudioFrame.
 */
class PreEncodedReceiver final : public IPreEncodedAudioFrameReceiver {
 public:
  /**
   * Moves @p packet into a newly allocated PreEncodedAudioFrame and stores it in @p result.
   *
   * @return Always TTV_EC_SUCCESS.
   */
  TTV_ErrorCode PackageFrame(AudioPacket&& packet, AudioFormat audioFormat, uint32_t numChannels,
    uint64_t timeStamp, std::shared_ptr<AudioFrame>& result) override {
    result = std::make_shared<PreEncodedAudioFrame>(std::move(packet), audioFormat, numChannels, timeStamp);

    return TTV_EC_SUCCESS;
  }
};
}  // namespace

class ttv::broadcast::PassThroughAudioEncoderInternalData {
 public:
  /**
   * Creates the state for a freshly constructed encoder: not initialized, not started, no frame writer.
   *
   * Every member is initialized explicitly, in declaration order. In particular, mAudioFormat is
   * value-initialized (zero for an enum type). Without this, GetAudioEncodingFormat() called before
   * SetAudioFormat() would read an indeterminate value.
   */
  PassThroughAudioEncoderInternalData()
      : mReceiver(std::make_shared<PreEncodedReceiver>()),
        mFrameWriter(),
        mAudioFormat(),
        mStreamIndex(0),
        mSamplesPerFrame(0),
        mTotalSamplesWritten(0),
        mInitialized(false),
        mStarted(false) {}

 public:
  std::shared_ptr<PreEncodedReceiver> mReceiver;  //!< Receiver handed out for the pre-encoded protocol.
  std::shared_ptr<IFrameWriter> mFrameWriter;     //!< Destination for packets; required before Start().
  AudioFormat mAudioFormat;                       //!< This encoder can only accept one type of audio data.
  uint32_t mStreamIndex;                          //!< Stream index supplied to Start().
  uint32_t mSamplesPerFrame;                      //!< Value configured via SetSamplesPerFrame().
  uint64_t mTotalSamplesWritten;                  //!< Running sum of mSamplesPerFrame over submitted frames.
  bool mInitialized;                              //!< Set by Initialize(), cleared by Shutdown().
  bool mStarted;                                  //!< Set by Start(), cleared by Stop().
};

ttv::broadcast::PassThroughAudioEncoder::PassThroughAudioEncoder()
    : mInternalData(std::make_shared<PassThroughAudioEncoderInternalData>()) {}

ttv::broadcast::PassThroughAudioEncoder::~PassThroughAudioEncoder() {
  // Leave the encoder in the stopped state on destruction; the status code has no consumer here.
  static_cast<void>(Stop());
}

/**
 * Chooses the single audio format this encoder accepts.
 *
 * @return TTV_EC_INVALID_STATE while started, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::SetAudioFormat(AudioFormat format) {
  auto& state = *mInternalData;

  if (!state.mStarted) {
    state.mAudioFormat = format;
    return TTV_EC_SUCCESS;
  }

  // The format is frozen while the encoder is running.
  return TTV_EC_INVALID_STATE;
}

/**
 * Configures the per-frame sample count reported by GetNumInputSamplesPerEncodeFrame().
 *
 * @return TTV_EC_INVALID_STATE while started, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::SetSamplesPerFrame(uint32_t samplesPerFrame) {
  auto& state = *mInternalData;

  const bool running = state.mStarted;
  if (running) {
    return TTV_EC_INVALID_STATE;
  }

  state.mSamplesPerFrame = samplesPerFrame;
  return TTV_EC_SUCCESS;
}

std::string ttv::broadcast::PassThroughAudioEncoder::GetName() const {
  return kLoggerName;
}

/**
 * Marks the encoder as initialized. No resources are acquired.
 *
 * @return TTV_EC_INVALID_STATE if already initialized, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::Initialize() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughAudioEncoder::Initialize()");

  auto& state = *mInternalData;

  if (!state.mInitialized) {
    state.mInitialized = true;
    return TTV_EC_SUCCESS;
  }

  // A second Initialize() without an intervening Shutdown() is rejected.
  return TTV_EC_INVALID_STATE;
}

/**
 * Sets the writer that receives the packets produced by SubmitFrame().
 *
 * @return TTV_EC_INVALID_STATE while started, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::SetFrameWriter(
  const std::shared_ptr<IFrameWriter>& frameWriter) {
  auto& state = *mInternalData;

  if (!state.mStarted) {
    state.mFrameWriter = frameWriter;
    return TTV_EC_SUCCESS;
  }

  // Swapping the writer mid-stream is not allowed.
  return TTV_EC_INVALID_STATE;
}

/** Reports the configured audio format through @p result; always succeeds. */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::GetAudioEncodingFormat(AudioFormat& result) {
  const auto& state = *mInternalData;
  result = state.mAudioFormat;
  return TTV_EC_SUCCESS;
}

/**
 * Transitions the encoder to the started state and records @p streamIndex. The audio params are ignored.
 *
 * @return TTV_EC_INVALID_STATE if not initialized, already started, or no frame writer has been set;
 *         otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::Start(uint32_t streamIndex, const AudioParams& /*audioParams*/) {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughAudioEncoder::Start()");

  auto& state = *mInternalData;

  const bool canStart = state.mInitialized && !state.mStarted;
  if (!canStart) {
    return TTV_EC_INVALID_STATE;
  }

  // A frame writer is mandatory: every submitted frame is written through it.
  TTV_ASSERT(state.mFrameWriter != nullptr);
  if (!state.mFrameWriter) {
    return TTV_EC_INVALID_STATE;
  }

  state.mStreamIndex = streamIndex;
  state.mStarted = true;

  return TTV_EC_SUCCESS;
}

/** Reports the value configured via SetSamplesPerFrame() through @p result; always succeeds. */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::GetNumInputSamplesPerEncodeFrame(uint32_t& result) {
  const auto& state = *mInternalData;
  result = state.mSamplesPerFrame;
  return TTV_EC_SUCCESS;
}

/**
 * Clears the started flag. This is unconditional, so it also succeeds when the encoder was never started.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::Stop() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughAudioEncoder::Stop()");

  mInternalData->mStarted = false;
  return TTV_EC_SUCCESS;
}

/**
 * Stops the encoder, then clears the initialized flag if stopping succeeded.
 *
 * @return The result of Stop().
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::Shutdown() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughAudioEncoder::Shutdown()");

  const TTV_ErrorCode stopResult = Stop();

  // Only forget the initialized state once the encoder is known to be stopped.
  if (TTV_SUCCEEDED(stopResult)) {
    mInternalData->mInitialized = false;
  }

  return stopResult;
}

/**
 * Copies the encoded payload of a pre-encoded frame into a new Packet and hands it to the frame writer.
 *
 * The packet is flagged as a keyframe, tagged with the stream index given to Start(), and stamped with the
 * frame's timestamp.
 *
 * @param audioFrame Frame produced for the pre-encoded receiver protocol.
 * @return TTV_EC_INVALID_STATE if not started; TTV_EC_INVALID_ARG for a null frame;
 *         TTV_EC_BROADCAST_INVALID_SUBMISSION_METHOD if the frame belongs to another receiver protocol;
 *         otherwise the result of IFrameWriter::WritePacket().
 */
TTV_ErrorCode ttv::broadcast::PassThroughAudioEncoder::SubmitFrame(const std::shared_ptr<AudioFrame>& audioFrame) {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughAudioEncoder::SubmitFrame()");

  auto& data = *mInternalData;

  if (!data.mStarted) {
    return TTV_EC_INVALID_STATE;
  }

  if (audioFrame == nullptr) {
    return TTV_EC_INVALID_ARG;
  }

  TTV_ASSERT(audioFrame->GetReceiverTypeId() == IPreEncodedAudioFrameReceiver::GetReceiverTypeId());
  if (audioFrame->GetReceiverTypeId() != IPreEncodedAudioFrameReceiver::GetReceiverTypeId()) {
    return TTV_EC_BROADCAST_INVALID_SUBMISSION_METHOD;
  }

  // Assumes every frame tagged with the pre-encoded receiver type id is a PreEncodedAudioFrame (as created by
  // PreEncodedReceiver). Casting the reference avoids an extra shared_ptr refcount round-trip.
  const auto& preEncodedFrame = static_cast<const PreEncodedAudioFrame&>(*audioFrame);
  const auto& buffer = preEncodedFrame.GetFrameData();

  auto packet = std::make_unique<Packet>();
  packet->data.resize(buffer.size());
  // memcpy with a null pointer is undefined even for zero bytes, and an empty buffer may have data() == nullptr.
  if (!buffer.empty()) {
    memcpy(packet->data.data(), buffer.data(), buffer.size());
  }

  packet->keyframe = true;
  // Use the stream index supplied to Start() instead of a hard-coded value.
  packet->streamIndex = data.mStreamIndex;

  data.mTotalSamplesWritten += static_cast<uint64_t>(data.mSamplesPerFrame);
  packet->timestamp = audioFrame->GetTimeStamp();  // TODO: Is this the right thing to do?

  return data.mFrameWriter->WritePacket(std::move(packet));
}

/** True only for the pre-encoded receiver protocol, which is the sole protocol this encoder understands. */
bool ttv::broadcast::PassThroughAudioEncoder::SupportsReceiverProtocol(
  IAudioFrameReceiver::ReceiverTypeId typeId) const {
  const auto preEncodedId = IPreEncodedAudioFrameReceiver::GetReceiverTypeId();
  const bool supported = typeId == preEncodedId;
  return supported;
}

/**
 * Returns the shared pre-encoded receiver when @p typeId names that protocol; otherwise returns nullptr.
 */
std::shared_ptr<ttv::broadcast::IAudioFrameReceiver> ttv::broadcast::PassThroughAudioEncoder::GetReceiverImplementation(
  IAudioFrameReceiver::ReceiverTypeId typeId) {
  if (typeId != IPreEncodedAudioFrameReceiver::GetReceiverTypeId()) {
    // Any other protocol is unsupported by this encoder.
    return nullptr;
  }

  return mInternalData->mReceiver;
}
