/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/lameaudioencoder.h"

#include "twitchsdk/broadcast/audioconstants.h"
#include "twitchsdk/broadcast/audioframe.h"
#include "twitchsdk/broadcast/iframewriter.h"
#include "twitchsdk/broadcast/internal/pcmaudioframe.h"
#include "twitchsdk/broadcast/ipcmaudioframereceiver.h"
#include "twitchsdk/broadcast/packet.h"
#include "twitchsdk/core/assertion.h"

#include <lame/lame.h>

namespace {
using namespace ttv;
using namespace ttv::broadcast;

// Component name passed to the trace calls in this file and returned by GetName().
constexpr const char* kLoggerName = "LameAudioEncoder";

// Number of PCM samples (per channel) that make up one encoded MP3 frame.
constexpr uint32_t kPcmSamplesPerMp3Frame{1152};

// We currently assume 2 input channels.
constexpr uint32_t kNumAudioChannels{2};
}  // namespace

// Private state of LameAudioEncoder, kept out of the public header.
struct ttv::broadcast::LameAudioEncoderInternalData {
  // Receiver handed out by GetReceiverImplementation() for PCM submissions.
  std::shared_ptr<PcmAudioReceiver> mPcmReceiver = std::make_shared<PcmAudioReceiver>(kPcmSamplesPerMp3Frame);

  // LAME encoder handle; created in Start() and closed in Stop().
  lame_global_struct* mLame = nullptr;
  // Destination for finished MP3 packets.
  std::shared_ptr<IFrameWriter> mFrameWriter;
  // Stream index stamped on every outgoing packet.
  uint32_t mStreamIndex = 0;
  // Running sample count used to derive packet timestamps.
  uint64_t mTotalSamplesWritten = 0;

  // Size of one MP3 frame in bytes, not including the possible padding byte.
  uint32_t mMp3FrameSizeInBytes = 0;
  // Staging buffer holding encoder output until a whole frame is available.
  std::vector<unsigned char> mCache;
  // Number of valid bytes currently held in mCache.
  uint32_t mCacheWritePos = 0;
  // Set by Initialize(), cleared by Shutdown().
  bool mInitialized = false;
  // Set by a successful Start(), cleared by Stop().
  bool mStarted = false;
};

ttv::broadcast::LameAudioEncoder::LameAudioEncoder()
    : mInternalData(std::make_unique<LameAudioEncoderInternalData>()) {}

// Stops and shuts the encoder down on destruction.
ttv::broadcast::LameAudioEncoder::~LameAudioEncoder() {
  // Nothing can act on a failure at this point, so the error code is dropped.
  static_cast<void>(Shutdown());
}

// Returns the encoder's name (the same string used as the trace component).
std::string ttv::broadcast::LameAudioEncoder::GetName() const {
  return std::string(kLoggerName);
}

// Marks the encoder as initialized.
// Returns TTV_EC_INVALID_STATE if it was already initialized, TTV_EC_SUCCESS otherwise.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::Initialize() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::Initialize()");

  bool& initialized = mInternalData->mInitialized;
  if (!initialized) {
    initialized = true;
    return TTV_EC_SUCCESS;
  }

  return TTV_EC_INVALID_STATE;
}

// Sets the writer that receives encoded MP3 packets.
// The writer can only be changed while the encoder is stopped; otherwise
// TTV_EC_INVALID_STATE is returned and the current writer is kept.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::SetFrameWriter(const std::shared_ptr<IFrameWriter>& frameWriter) {
  auto& state = *mInternalData;

  if (!state.mStarted) {
    state.mFrameWriter = frameWriter;
    return TTV_EC_SUCCESS;
  }

  return TTV_EC_INVALID_STATE;
}

// Reports the format this encoder produces.
// `result` is always set to AudioFormat::MP3 and the call always returns TTV_EC_SUCCESS.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::GetAudioEncodingFormat(AudioFormat& result) {
  result = AudioFormat::MP3;

  return TTV_EC_SUCCESS;
}

// Creates and configures the LAME encoder for a new stream and allocates the
// output cache.
//
// streamIndex  Index stamped on every packet produced until Stop().
// audioParams  Currently unused; the encoder is configured from the constants
//              kAudioEncodeRate, kNumAudioChannels and kMp3Bitrate.
//
// Returns:
//   TTV_EC_SUCCESS                        - encoder is ready to accept samples.
//   TTV_EC_INVALID_STATE                  - not initialized, already started, or
//                                           no frame writer has been set.
//   TTV_EC_BROADCAST_LAMEMP3_FAILED_INIT  - LAME could not be created or rejected
//                                           one of the configuration calls.
//
// On a configuration failure the LAME handle is closed again, so a failed
// Start() leaves no encoder behind for a later Start() to leak or for Stop()
// to flush.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::Start(uint32_t streamIndex, const AudioParams& audioParams) {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::Start()");

  auto& data = *mInternalData;

  if (!data.mInitialized || data.mStarted) {
    return TTV_EC_INVALID_STATE;
  }

  TTV_ASSERT(data.mFrameWriter != nullptr);
  if (data.mFrameWriter == nullptr) {
    return TTV_EC_INVALID_STATE;
  }

  data.mStreamIndex = streamIndex;
  data.mTotalSamplesWritten = 0;

  data.mLame = lame_init();
  if (data.mLame == nullptr) {
    return TTV_EC_BROADCAST_LAMEMP3_FAILED_INIT;
  }

  // Track every configuration result, not just the last one, so that a failed
  // setter is reported in release builds as well.
  bool configured = true;
  const auto check = [&configured](int lameRet) {
    TTV_ASSERT(lameRet == LAME_NOERROR);
    configured = configured && (lameRet == LAME_NOERROR);
  };

  check(lame_set_in_samplerate(data.mLame, kAudioEncodeRate));
  check(lame_set_out_samplerate(data.mLame, kAudioEncodeRate));
  check(lame_set_num_channels(data.mLame, kNumAudioChannels));
  check(lame_set_brate(data.mLame, kMp3Bitrate));
  check(lame_set_mode(data.mLame, JOINT_STEREO));
  if (configured) {
    check(lame_init_params(data.mLame));
  }

  if (!configured) {
    // Do not keep a half-configured encoder around.
    lame_close(data.mLame);
    data.mLame = nullptr;
    return TTV_EC_BROADCAST_LAMEMP3_FAILED_INIT;
  }

  // from http://mpgedit.org/mpgedit/mpeg_format/mpeghdr.htm
  // FrameLengthInBytes = 144 * BitRate / SampleRate + Padding

  data.mMp3FrameSizeInBytes = 144 * (kMp3Bitrate * 1000) / kAudioEncodeRate;
  uint32_t cacheSize = 3 * (data.mMp3FrameSizeInBytes + 1);  // account for possible padding byte
                                                             // and save space for 3 frames of data

  // assign() both sizes and zero-fills the cache, and is safe for a zero size.
  data.mCache.assign(static_cast<size_t>(cacheSize), 0x00);
  data.mCacheWritePos = 0;

  data.mStarted = true;
  return TTV_EC_SUCCESS;
}

// Reports how many PCM samples go into one encoded frame.
// `result` is always set to kPcmSamplesPerMp3Frame and the call always returns TTV_EC_SUCCESS.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::GetNumInputSamplesPerEncodeFrame(uint32_t& result) {
  result = kPcmSamplesPerMp3Frame;

  return TTV_EC_SUCCESS;
}

// Emits at most one complete MP3 frame from the front of the cache.
//
// Returns TTV_EC_NOTENOUGHDATA when the cache does not yet hold a full frame;
// otherwise returns whatever the frame writer's WritePacket() returns. The
// emitted frame is removed from the cache and the remaining bytes are shifted
// to the front.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::WritePacket() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::WritePacket()");

  auto& state = *mInternalData;

  // An MP3 frame header is 4 bytes; without it the frame length is unknown.
  constexpr uint32_t kHeaderSizeInBytes = 4;
  if (state.mCacheWritePos < kHeaderSizeInBytes) {
    return TTV_EC_NOTENOUGHDATA;
  }

  // The cache is expected to start on a frame boundary: the first 11 bits of
  // the header (the sync word) must all be set.
  // See http://mpgedit.org/mpgedit/mpeg_format/mpeghdr.htm for details.
  TTV_ASSERT(state.mCache[0] == 0xff);
  TTV_ASSERT((state.mCache[1] & 0xE0) == 0xE0);
  // TODO: Search for a new frame boundary
  //       if this fails.

  // The padding bit in the header means this frame is one byte longer.
  const bool hasPaddingByte = (state.mCache[2] & 0x2) == 0x2;
  const uint32_t frameSize = state.mMp3FrameSizeInBytes + (hasPaddingByte ? 1u : 0u);

  // Wait until the whole frame has been cached.
  if (state.mCacheWritePos < frameSize) {
    return TTV_EC_NOTENOUGHDATA;
  }

  // Copy the frame out into its own packet.
  auto packet = std::make_unique<Packet>();
  packet->data.resize(static_cast<size_t>(frameSize));
  memcpy(packet->data.data(), state.mCache.data(), frameSize);

  // Slide whatever follows the frame down to the start of the cache.
  const uint32_t leftover = state.mCacheWritePos - frameSize;
  memmove(&state.mCache[0], &state.mCache[frameSize], leftover);
  state.mCacheWritePos = leftover;
#if _DEBUG
  memset(&state.mCache[state.mCacheWritePos], 0x00, frameSize);
#endif

  packet->keyframe = true;
  packet->streamIndex = state.mStreamIndex;

  // The timestamp (in milliseconds) is derived from the running sample count,
  // which is advanced before the timestamp is computed.
  uint32_t samplesInFrame = 0;
  GetNumInputSamplesPerEncodeFrame(samplesInFrame);
  state.mTotalSamplesWritten += static_cast<uint64_t>(samplesInFrame);
  packet->timestamp = state.mTotalSamplesWritten * 1000 / kAudioEncodeRate;

  return state.mFrameWriter->WritePacket(std::move(packet));
}

// Encodes a block of interleaved signed 16-bit PCM samples and forwards every
// complete MP3 frame that results to the frame writer.
//
// samples               Interleaved PCM samples; must not be null.
// numSamplesPerChannel  Number of samples per channel contained in `samples`.
//
// Returns:
//   TTV_EC_SUCCESS                  - samples were accepted (zero or more
//                                     packets may have been written).
//   TTV_EC_NOTENOUGHDATA            - the encoder produced output, but not yet
//                                     enough for a whole frame.
//   TTV_EC_INVALID_ARG              - `samples` is null.
//   TTV_EC_INVALID_STATE            - the encoder has not been started.
//   TTV_EC_BROADCAST_ENCODE_FAILED  - LAME reported an error.
//   Any error returned by the frame writer while writing a packet.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::SubmitPcmSamples(
  const int16_t* samples, uint32_t numSamplesPerChannel) {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::SubmitPcmSamples()");

  auto& data = *mInternalData;

  // Asserts catch misuse in debug builds; the explicit checks keep release
  // builds from dereferencing a null pointer.
  TTV_ASSERT(samples != nullptr);
  if (samples == nullptr) {
    return TTV_EC_INVALID_ARG;
  }

  TTV_ASSERT(data.mLame != nullptr);
  if (data.mLame == nullptr) {
    return TTV_EC_INVALID_STATE;
  }

  // LAME's interface takes a non-const sample pointer.
  short int* signedBuffer = const_cast<short int*>(reinterpret_cast<const short int*>(samples));
  uint32_t remainingSpace = static_cast<uint32_t>(data.mCache.size()) - data.mCacheWritePos;
  // data() + offset stays valid when the cache is completely full, where
  // indexing with operator[] would be out of range.
  unsigned char* outputBuffer = data.mCache.data() + data.mCacheWritePos;

  int bytesEncoded = lame_encode_buffer_interleaved(
    data.mLame, signedBuffer, static_cast<int>(numSamplesPerChannel), outputBuffer, static_cast<int>(remainingSpace));

  if (bytesEncoded < 0) {
    switch (bytesEncoded) {
      case -1:
        ttv::trace::Message(
          kLoggerName, MessageLevel::Error, "LameAudioEncoder::SubmitPcmSamples() error - mp3buf was too small");
        break;
      case -2:
        ttv::trace::Message(
          kLoggerName, MessageLevel::Error, "LameAudioEncoder::SubmitPcmSamples() error - malloc() problem");
        break;
      case -3:
        ttv::trace::Message(kLoggerName, MessageLevel::Error,
          "LameAudioEncoder::SubmitPcmSamples() error - lame_init_params() not called");
        break;
      case -4:
        ttv::trace::Message(
          kLoggerName, MessageLevel::Error, "LameAudioEncoder::SubmitPcmSamples() error - psycho acoustic problems");
        break;
      default:
        ttv::trace::Message(kLoggerName, MessageLevel::Error, "LameAudioEncoder::SubmitPcmSamples() error - unknown");
        break;
    }

    TTV_ASSERT(false);
    return TTV_EC_BROADCAST_ENCODE_FAILED;
  }

  if (bytesEncoded == 0) {
    // LAME buffered the input without producing output yet.
    return TTV_EC_SUCCESS;
  }

  data.mCacheWritePos += static_cast<uint32_t>(bytesEncoded);
  TTV_ASSERT(data.mCacheWritePos <= static_cast<uint32_t>(data.mCache.size()));

  // The result of the first attempt is what callers have always received.
  TTV_ErrorCode ec = WritePacket();

  // Drain any further complete frames so they do not pile up in the small
  // cache and eventually make LAME fail with "mp3buf was too small".
  TTV_ErrorCode drainEc = ec;
  while (drainEc == TTV_EC_SUCCESS) {
    drainEc = WritePacket();
  }

  // Running out of cached data ends the drain normally; anything else is a
  // real failure and must be reported.
  if (drainEc != TTV_EC_SUCCESS && drainEc != TTV_EC_NOTENOUGHDATA) {
    ec = drainEc;
  }

  return ec;
}

// Flushes and closes the LAME encoder (if one exists) and resets the
// per-stream state, including the frame writer and the output cache.
//
// Returns the result of FlushAudioData(), replaced by
// TTV_EC_BROADCAST_LAMEMP3_FAILED_SHUTDOWN if LAME fails to close, or
// TTV_EC_SUCCESS when there was no encoder to stop.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::Stop() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::Stop()");

  auto& state = *mInternalData;

  TTV_ErrorCode result = TTV_EC_SUCCESS;

  if (state.mLame != nullptr) {
    // Push out whatever the encoder still holds before tearing it down.
    result = FlushAudioData();

    const int closeResult = lame_close(state.mLame);
    TTV_ASSERT(closeResult == LAME_NOERROR);
    if (closeResult != LAME_NOERROR) {
      result = TTV_EC_BROADCAST_LAMEMP3_FAILED_SHUTDOWN;
    }

    state.mLame = nullptr;
  }

  // Reset per-stream state regardless of whether an encoder existed.
  state.mFrameWriter.reset();
  state.mStarted = false;
  state.mCache.clear();
  state.mCacheWritePos = 0;

  return result;
}

// Stops the encoder and marks it as uninitialized.
// Returns the result of Stop().
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::Shutdown() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::Shutdown()");

  const TTV_ErrorCode stopResult = Stop();
  mInternalData->mInitialized = false;
  return stopResult;
}

// Flushes the samples still buffered inside LAME and writes every complete
// MP3 frame that results to the frame writer.
//
// The flushed bytes are fed through the output cache in chunks, draining
// complete frames after each chunk, because a flush can be larger than the
// free space in the cache. Bytes left over that do not form a complete frame
// stay in the cache.
//
// Returns:
//   TTV_EC_SUCCESS                  - the flush completed. Running out of
//                                     cached data at the end of a drain is the
//                                     normal end condition and is not reported.
//   TTV_EC_INVALID_STATE            - there is no encoder to flush.
//   TTV_EC_BROADCAST_ENCODE_FAILED  - LAME reported an error, or the cache could
//                                     neither accept data nor yield a frame.
//   Any error returned by the frame writer while writing a packet.
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::FlushAudioData() {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "LameAudioEncoder::FlushAudioData()");

  auto& data = *mInternalData;

  TTV_ASSERT(data.mLame);
  if (data.mLame == nullptr) {
    return TTV_EC_INVALID_STATE;
  }

  std::vector<unsigned char> buffer(static_cast<size_t>(LAME_MAXMP3BUFFER));

  for (;;) {
    // Keep the result signed: LAME reports errors as negative values, which
    // must not be mistaken for a (huge) byte count.
    const int flushed = lame_encode_flush(data.mLame, buffer.data(), static_cast<int>(buffer.size()));
    if (flushed < 0) {
      ttv::trace::Message(
        kLoggerName, MessageLevel::Error, "LameAudioEncoder::FlushAudioData() error - lame_encode_flush failed");
      return TTV_EC_BROADCAST_ENCODE_FAILED;
    }
    if (flushed == 0) {
      break;
    }

    const size_t bufferSize = static_cast<size_t>(flushed);
    size_t bufferReadPos = 0;

    while (bufferReadPos < bufferSize) {
      // Move as much as currently fits into the cache.
      const size_t freeSpace = data.mCache.size() - data.mCacheWritePos;
      const size_t dataToCopy = std::min(bufferSize - bufferReadPos, freeSpace);
      if (dataToCopy > 0) {
        memcpy(data.mCache.data() + data.mCacheWritePos, buffer.data() + bufferReadPos, dataToCopy);
        data.mCacheWritePos += static_cast<uint32_t>(dataToCopy);
        bufferReadPos += dataToCopy;
      }

      // Drain every complete frame to make room for the next chunk.
      bool wrotePacket = false;
      TTV_ErrorCode writeResult = TTV_EC_SUCCESS;
      do {
        writeResult = WritePacket();
        wrotePacket = wrotePacket || (writeResult == TTV_EC_SUCCESS);
      } while (writeResult == TTV_EC_SUCCESS);

      // TTV_EC_NOTENOUGHDATA only means the cache holds no further complete
      // frame; anything else is a real failure.
      if (writeResult != TTV_EC_NOTENOUGHDATA) {
        return writeResult;
      }

      // Nothing copied and nothing written means no progress is possible
      // (e.g. an empty cache); bail out instead of spinning forever.
      if (dataToCopy == 0 && !wrotePacket) {
        ttv::trace::Message(
          kLoggerName, MessageLevel::Error, "LameAudioEncoder::FlushAudioData() error - output cache is stuck");
        return TTV_EC_BROADCAST_ENCODE_FAILED;
      }
    }
  }

  return TTV_EC_SUCCESS;
}

// Validates an incoming audio frame and passes its samples to the encoder.
//
// Only interleaved, signed 16-bit PCM frames with kNumAudioChannels channels
// are supported.
//
// Returns:
//   TTV_EC_INVALID_ARG                          - `audioFrame` is null.
//   TTV_EC_BROADCAST_INVALID_SUBMISSION_METHOD  - the frame is not a PCM frame.
//   TTV_EC_UNIMPLEMENTED                        - unsupported sample format,
//                                                 non-interleaved layout, or
//                                                 unsupported channel count.
//   Otherwise the result of SubmitPcmSamples().
TTV_ErrorCode ttv::broadcast::LameAudioEncoder::SubmitFrame(const std::shared_ptr<AudioFrame>& audioFrame) {
  if (audioFrame == nullptr) {
    return TTV_EC_INVALID_ARG;
  }

  // We only handle PCM frames
  TTV_ASSERT(audioFrame->GetReceiverTypeId() == IPcmAudioFrameReceiver::GetReceiverTypeId());
  if (audioFrame->GetReceiverTypeId() != IPcmAudioFrameReceiver::GetReceiverTypeId()) {
    return TTV_EC_BROADCAST_INVALID_SUBMISSION_METHOD;
  }

  // Safe only because the receiver type id was checked above.
  auto pcmFrame = std::static_pointer_cast<PcmAudioFrame>(audioFrame);

  // TODO: Add support for other than signed 16-bit audio samples
  TTV_ASSERT(pcmFrame->GetAudioSampleFormat() == AudioSampleFormat::TTV_ASF_PCM_S16);
  if (pcmFrame->GetAudioSampleFormat() != AudioSampleFormat::TTV_ASF_PCM_S16) {
    return TTV_EC_UNIMPLEMENTED;
  }

  // TODO: We only handle interleaved audio right now
  TTV_ASSERT(pcmFrame->GetInterleaved());
  if (!pcmFrame->GetInterleaved()) {
    return TTV_EC_UNIMPLEMENTED;
  }

  // The encoder is configured for kNumAudioChannels interleaved channels.
  // Reject any other channel count in release builds too, rather than letting
  // the encoder interpret the buffer with the wrong layout.
  TTV_ASSERT(pcmFrame->GetNumChannels() == kNumAudioChannels);
  if (pcmFrame->GetNumChannels() != kNumAudioChannels) {
    return TTV_EC_UNIMPLEMENTED;
  }

  const int16_t* sampleBuffer = reinterpret_cast<const int16_t*>(pcmFrame->GetSampleBuffer().data());

  return SubmitPcmSamples(sampleBuffer, pcmFrame->GetNumSamplesPerChannel());
}

// Returns true only for the PCM audio frame receiver type, the sole protocol
// this encoder accepts.
bool ttv::broadcast::LameAudioEncoder::SupportsReceiverProtocol(IAudioFrameReceiver::ReceiverTypeId typeId) const {
  const bool isPcmReceiver = (typeId == IPcmAudioFrameReceiver::GetReceiverTypeId());
  return isPcmReceiver;
}

// Returns the receiver object for the requested protocol: the shared PCM
// receiver for the PCM receiver type, nullptr for any other type.
std::shared_ptr<ttv::broadcast::IAudioFrameReceiver> ttv::broadcast::LameAudioEncoder::GetReceiverImplementation(
  IAudioFrameReceiver::ReceiverTypeId typeId) {
  const bool wantsPcmReceiver = (typeId == IPcmAudioFrameReceiver::GetReceiverTypeId());
  if (!wantsPcmReceiver) {
    return nullptr;
  }

  return mInternalData->mPcmReceiver;
}
