/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/internal/muxers/flvmuxer.h"

#include "twitchsdk/broadcast/internal/muxers/amf0encoder.h"
#include "twitchsdk/broadcast/internal/muxers/rtmpstream.h"
#include "twitchsdk/broadcast/internal/streamstats.h"
#include "twitchsdk/broadcast/packet.h"

#include <codecvt>
#include <iostream>
#include <utility>
#include <vector>

namespace {
const uint8_t kFlvMp3CodecId = 2;
const uint8_t kFLvAACCodecId = 10;
const uint8_t kFlvH264CodecId = 7;

const uint8_t kFlvAudioStereo = 0x1;
const uint8_t kFlvAudio16Bit = 0x2;
const uint8_t kFlvAudio11250 = 0x4;
const uint8_t kFlvAudio22500 = 0x8;
const uint8_t kFlvAudio44100 = 0x4 | 0x8;
const uint8_t kFlvAudioCodecShift = 4;
const uint8_t kFlvKeyFrame = 0x10;
const uint8_t kFlvRegularFrame = 0x20;

const uint8_t kFlvHasAudio = 0x04;
const uint8_t kFlvHasVideo = 0x01;

const uint8_t kAVCPacketSequenceHeader = 0;
const uint8_t kAVCPacketNALU = 1;

template <typename result_t>
result_t ToBigEndian(result_t param) {
  result_t result = 0;
  auto big = reinterpret_cast<uint8_t*>(&result);
  auto little = reinterpret_cast<uint8_t*>(&param);

  for (size_t i = 0; i < sizeof(result_t); i++) {
    big[i] = little[sizeof(result_t) - i - 1];
  }

  return result;
}

void ToBigEndian24(uint32_t param, ttv::broadcast::flv::uint24_t big) {
  auto little = reinterpret_cast<uint8_t*>(&param);
  big[0] = little[2];
  big[1] = little[1];
  big[2] = little[0];
}
}  // namespace

ttv::broadcast::FlvMuxer::FlvMuxer(std::shared_ptr<StreamStats> streamStats)
    : mOutputFile(nullptr),
      mStreamStats(streamStats),
      mRtmpStream(nullptr),
      mTotalVideoPacketsSent(0),
      mOutputConnected(false) {}

ttv::broadcast::FlvMuxer::~FlvMuxer() {
  Stop();
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::Start(const MuxerParameters& params) {
  mInitParams = params;

  if (!mRtmpUrl.empty()) {
    mRtmpStream = std::make_unique<RtmpStream>(mStreamStats);
    mRtmpStream->Start(mRtmpUrl);
  }

  uint8_t audioCodecId = 0;
  switch (params.audioFormat) {
    case AudioFormat::None:
    case AudioFormat::PCM:  // The backend doesn't support PCM
    case AudioFormat::MP3:
      audioCodecId = kFlvMp3CodecId;
      break;
    case AudioFormat::AAC:
      audioCodecId = kFLvAACCodecId;
      break;
    default:
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunreachable-code"
      return TTV_EC_BROADCAST_FLV_UNSUPPORTED_AUDIO_CODEC;
#pragma clang diagnostic pop
  }

  uint8_t audioFlags = static_cast<uint8_t>((audioCodecId << kFlvAudioCodecShift) | kFlvAudio16Bit);

  if (audioCodecId == kFLvAACCodecId) {
    // NOTE: For AAC, stereo flag is always set and the sample rate is 44100 even though
    // the data may contain different values.  See the FLV spec page 10
    // https://www.adobe.com/content/dam/Adobe/en/devnet/flv/pdfs/video_file_format_spec_v10.pdf
    audioFlags |= kFlvAudioStereo;
    audioFlags |= kFlvAudio44100;
  } else {
    if (params.audioStereo) {
      audioFlags |= kFlvAudioStereo;
    }

    switch (params.audioSampleRate) {
      case 11025:
        audioFlags |= kFlvAudio11250;
        break;
      case 22050:
        audioFlags |= kFlvAudio22500;
        break;
      case 44100:
        audioFlags |= kFlvAudio44100;
        break;
      default:
        return TTV_EC_BROADCAST_FLV_UNSUPPORTED_AUDIO_RATE;
    }
  }

  mAudioFlags.push_back(audioFlags);

  if (audioCodecId == kFLvAACCodecId) {
    // AAC packets have one byte for AACPacketType which is set to 1 for "AAC raw"
    mAudioFlags.push_back(1);
  }

  mVideoFlags.resize(5, 0);
  if (!mFlvPath.empty()) {
#if WIN32
    // TODO: Fix this
    mOutputFile = _wfopen(mFlvPath.c_str(), L"wb");
#else
    // POSIX functions in MacOS X support UTF-8 strings
    // So convert our path to UTF-8
    // std::wstring_convert<std::codecvt_utf8<wchar_t> > converter;
    std::string filepath(mFlvPath.begin(), mFlvPath.end());
    mOutputFile = fopen(filepath.c_str(), "wb");
#endif
    if (mOutputFile == nullptr) {
      return TTV_EC_BROADCAST_FLV_UNABLE_TO_OPEN_FILE;
    }
  }

  if (mOutputFile != nullptr) {
    flv::Header header;
    header.mSignature[0] = 'F';
    header.mSignature[1] = 'L';
    header.mSignature[2] = 'V';
    header.mVersion = 1;
    header.mFlags = kFlvHasVideo;
    if (params.audioEnabled) {
      header.mFlags |= kFlvHasAudio;
    }
    uint32_t sizeOf = sizeof(header);
    header.mHeaderSize = ToBigEndian(sizeOf);

    WriteToOutput(&header, false);
    fseek(mOutputFile, 4, SEEK_CUR);
  }

  std::shared_ptr<AMF0Encoder> encoder = std::make_shared<AMF0Encoder>();
  encoder->String("onMetaData");
  encoder->EcmaArray(params.audioEnabled ? 11 : 7);
  encoder->EcmaArrayKey("duration");
  encoder->Number(0);  // TODO: We can jump back to the start of the file and fix this
  encoder->EcmaArrayKey("width");
  encoder->Number(params.videoWidth);
  encoder->EcmaArrayKey("height");
  encoder->Number(params.videoHeight);
  encoder->EcmaArrayKey("framerate");
  encoder->Number(params.frameRate);
  encoder->EcmaArrayKey("videocodecid");
  encoder->Number(kFlvH264CodecId);  // TODO: HARDCODED TO h.264

  if (params.audioEnabled) {
    encoder->EcmaArrayKey("audiosamplerate");
    encoder->Number(params.audioSampleRate);
    encoder->EcmaArrayKey("audiosamplesize");
    encoder->Number(params.audioSampleSize);
    encoder->EcmaArrayKey("stereo");
    encoder->Boolean(params.audioStereo);
    encoder->EcmaArrayKey("audiocodecid");
    encoder->Number(audioCodecId);
  }
  encoder->EcmaArrayKey("filesize");
  encoder->Number(0);  // TODO: We can jump back to the start of the file and fix this
  encoder->EcmaArrayKey("appVersion");
  encoder->String(params.appVersion);
  encoder->ObjectEnd();  // TODO: Write objectend only for RTMP

  TTV_ErrorCode ec = mRtmpStream->GetError();

  if (TTV_SUCCEEDED(ec)) {
    ec = WriteMetaPacket(encoder);
  }

  if (TTV_SUCCEEDED(ec)) {
    ec = WriteAudioHeader(audioFlags, params.audioFormat);
  }

  if (TTV_SUCCEEDED(ec)) {
    if (!params.videoSps.empty() && !params.videoPps.empty()) {
      ec = WriteVideoSpsPps(params.videoSps, params.videoPps);
    }
  }

  if (TTV_SUCCEEDED(ec)) {
    mOutputConnected = true;
  }

  return ec;
}

void ttv::broadcast::FlvMuxer::Update() {
  if (nullptr != mRtmpStream) {
    mRtmpStream->Update();
  }
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::Stop() {
  mOutputConnected = false;

  if (mOutputFile != nullptr) {
    // TODO: Jump back to start of file and rewrite meta data
    // so we can have correct filesize and similar values
    fclose(mOutputFile);
    mOutputFile = nullptr;
  }

  if (nullptr != mRtmpStream) {
    mRtmpStream->Stop();
  }

  return TTV_EC_SUCCESS;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::BeginChunk(flv::TagTypes type, uint32_t timestamp, size_t length) {
  flv::TagHeader header;

  ToBigEndian24(static_cast<uint32_t>(length), header.mPacketLength);

  ToBigEndian24(timestamp, header.mPacketTimestamp);
  header.mPacketExtendedTimestamp = (timestamp >> 24);

  header.mPacketType = type;

  WriteToOutput(&header, false);

  TTV_ErrorCode ret = TTV_EC_SUCCESS;

  if (mRtmpStream != nullptr) {
    ret = mRtmpStream->BeginFLVChunk(type, timestamp, length);
  }

  return ret;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::EndChunk(size_t length) {
  flv::TagFooter footer;

  length += static_cast<uint32_t>(sizeof(flv::TagHeader));
  footer.mLength = ToBigEndian(static_cast<uint32_t>(length));

  WriteToOutput(&footer, false);

  TTV_ErrorCode ret = TTV_EC_SUCCESS;

  if (mRtmpStream != nullptr) {
    ret = mRtmpStream->EndFLVChunk();
  }

  return ret;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::WriteVideoPacket(const Packet& packet) {
  TTV_ASSERT(packet.timestamp <= UINT_MAX);
  uint32_t smallTimestamp = static_cast<uint32_t>(packet.timestamp);

  // TODO: Hardcoded to h.264
  uint padding = 5;

  const char nalStartSequence[3] = {0, 0, 1};  // Annex B escape
  using nalBlock_t = std::pair<const uint8_t*, size_t>;

  std::list<nalBlock_t> nalBlocks;

  auto position = std::search(packet.data.begin(), packet.data.end(), nalStartSequence, nalStartSequence + 3);

  // TODO: We should have the encoder tag the frame as Annex B or AVCC so we don't need to make this Darwin check here

  bool frameSizeEncoded = false;
#if TTV_PLATFORM_DARWIN
  // The Apple encoder on Mac/iOS sometimes places the size of the frame in the first 4 bytes of the data it returns. We
  // check if that's the case here so that we don't need to write it at the beginning of the frame
  //
  uint encodedFrameDataSize = ToBigEndian<uint>(*reinterpret_cast<uint*>(const_cast<uint8_t*>(packet.data.data())));
  if (encodedFrameDataSize == packet.data.size() - 4) {
    frameSizeEncoded = true;
  }
#endif

  // TODO: Move this conversion from Annex B to AVCC into a separate function
  if (!frameSizeEncoded) {
    while (position != packet.data.end()) {
      // Block start can be 2 or 3 0s followed by a 1, compensate for that
      if (position == packet.data.begin() ||
          packet.data[static_cast<size_t>(position - packet.data.begin() - 1)] != 0) {
        padding++;
      }
      position += 3;

      auto start = &packet.data[static_cast<size_t>(position - packet.data.begin())];

      auto end_position = std::search(position, packet.data.end(), nalStartSequence, nalStartSequence + 3);

      auto length = end_position - position;
      if (end_position == packet.data.end()) {
        length = packet.data.end() - position;
      } else if (packet.data[static_cast<size_t>(end_position - packet.data.begin() - 1)] == 0) {
        length--;
      }

      nalBlocks.push_back(nalBlock_t(start, length));

      position = end_position;
    }
  }

  auto packetSize = packet.data.size() + padding;

  TTV_ErrorCode ret = BeginChunk(flv::Video, smallTimestamp, packetSize);

  if (TTV_SUCCEEDED(ret)) {
    mVideoFlags[0] = (kFlvH264CodecId | (packet.keyframe ? kFlvKeyFrame : kFlvRegularFrame));
    mVideoFlags[1] = packet.sequenceHeader ? kAVCPacketSequenceHeader : kAVCPacketNALU;

    flv::uint24_t ctsBig;
    ToBigEndian24(packet.cts, ctsBig);
    mVideoFlags[2] = ctsBig[0];
    mVideoFlags[3] = ctsBig[1];
    mVideoFlags[4] = ctsBig[2];

    size_t written = WriteToOutput(mVideoFlags);

    if (nalBlocks.size() == 0) {
      written += WriteToOutput(packet.data);
    } else {
      for (auto it = nalBlocks.cbegin(); it != nalBlocks.cend(); ++it) {
        auto from = it->first;
        auto length = it->second;

        auto bigEndianLength = ToBigEndian<uint>(static_cast<uint>(length));
        written += WriteToOutput(&bigEndianLength);
        written += WriteToOutput(from, length);
      }
    }

// Dump as hex for debugging
#if 0
        if (written != packetSize)
        {
            for (size_t i = 0; i < packet.data.size(); ++i)
            {
                if ((i % 4) > 0)
                {
                    printf(" ");
                }

                printf("%04x", packet.data[i]);

                if ((i % 4) == 3)
                {
                    printf("\n");
                }
            }
        }
#endif

    assert(written == packetSize);
    ret = EndChunk(packetSize);
  }

  // Notify a video frame was submitted
  if (TTV_SUCCEEDED(ret) && mStreamStats != nullptr) {
    mTotalVideoPacketsSent++;
    mStreamStats->Add(StreamStats::StatType::TotalVideoPacketsSent, mTotalVideoPacketsSent);
  }

  return ret;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::WriteAudioPacket(const Packet& packet) {
  TTV_ASSERT(packet.timestamp <= UINT_MAX);
  uint32_t smallTimestamp = static_cast<uint32_t>(packet.timestamp);

  uint packetSize = static_cast<uint>(packet.data.size() + mAudioFlags.size());

  TTV_ErrorCode ret = BeginChunk(flv::Audio, smallTimestamp, packetSize);

  if (TTV_SUCCEEDED(ret)) {
    WriteToOutput(mAudioFlags);
    WriteToOutput(packet.data);

    ret = EndChunk(packetSize);
  }

  return ret;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::WriteVideoSpsPps(
  const std::vector<uint8_t>& sps, const std::vector<uint8_t>& pps) {
  const int kHeaderSkipValue = 4;
  const int kExtraPadding = 16;  // Padding for vector to accomodated the extra bytes below
  std::vector<uint8_t> videoDataheader;
  videoDataheader.reserve(sps.size() + pps.size() - kHeaderSkipValue * 2 + kExtraPadding);

  videoDataheader.push_back(kFlvH264CodecId | kFlvKeyFrame);

  // H.264 Sequence
  videoDataheader.push_back(0);

  // Composition Time
  videoDataheader.push_back(0);
  videoDataheader.push_back(0);
  videoDataheader.push_back(0);

  // Version
  videoDataheader.push_back(1);

  videoDataheader.push_back(sps[1 + kHeaderSkipValue]);
  videoDataheader.push_back(sps[2 + kHeaderSkipValue]);
  videoDataheader.push_back(sps[3 + kHeaderSkipValue]);

  videoDataheader.push_back(0xff);
  videoDataheader.push_back(0xe1);

  uint16_t spsLength = static_cast<uint16_t>(sps.size() - kHeaderSkipValue);
  uint16_t ppsLength = static_cast<uint16_t>(pps.size() - kHeaderSkipValue);

  videoDataheader.push_back(reinterpret_cast<uint8_t*>(&spsLength)[1]);
  videoDataheader.push_back(reinterpret_cast<uint8_t*>(&spsLength)[0]);

  videoDataheader.insert(videoDataheader.end(), sps.begin() + kHeaderSkipValue, sps.end());

  // Number of pps
  videoDataheader.push_back(1);

  videoDataheader.push_back(reinterpret_cast<uint8_t*>(&ppsLength)[1]);
  videoDataheader.push_back(reinterpret_cast<uint8_t*>(&ppsLength)[0]);

  videoDataheader.insert(videoDataheader.end(), pps.begin() + kHeaderSkipValue, pps.end());

  TTV_ErrorCode ret = BeginChunk(flv::Video, 0, videoDataheader.size());
  ASSERT_ON_ERROR(ret);

  if (TTV_SUCCEEDED(ret)) {
    WriteToOutput(videoDataheader);
    ret = EndChunk(videoDataheader.size());
    ASSERT_ON_ERROR(ret);
  }

  return ret;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::WriteAudioHeader(uint8_t audioFlags, AudioFormat encodingFormat) {
  TTV_ErrorCode ret = TTV_EC_SUCCESS;

  if (encodingFormat == AudioFormat::AAC) {
    std::vector<uint8_t> audioDataHeader;
    audioDataHeader.push_back(audioFlags);
    audioDataHeader.push_back(0);  // AAC Sequence Header

    struct {
      uint16_t gaSpecific : 3;
      uint16_t channels : 4;
      uint16_t sampleIndex : 4;
      uint16_t objectType : 5;
    } aacHeader;

    aacHeader.objectType = 2;  // AAC LC
    aacHeader.sampleIndex = 4;
    aacHeader.channels = mInitParams.audioStereo ? 2 : 1;
    aacHeader.gaSpecific = 0;

    audioDataHeader.push_back(reinterpret_cast<uint8_t*>(&aacHeader)[1]);
    audioDataHeader.push_back(reinterpret_cast<uint8_t*>(&aacHeader)[0]);

    ret = BeginChunk(flv::Audio, 0, audioDataHeader.size());
    ASSERT_ON_ERROR(ret);

    if (TTV_SUCCEEDED(ret)) {
      WriteToOutput(audioDataHeader);

      ret = EndChunk(audioDataHeader.size());
      ASSERT_ON_ERROR(ret);
    }
  }

  return ret;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::WriteMetaPacket(const std::shared_ptr<AMF0Encoder>& encoder) {
  TTV_ErrorCode ret = BeginChunk(flv::Meta, 0, encoder->GetBuffer().size());

  if (TTV_SUCCEEDED(ret)) {
    WriteToOutput(encoder->GetBuffer());

    ret = EndChunk(encoder->GetBuffer().size());
    ASSERT_ON_ERROR(ret);
  }

  return ret;
}

size_t ttv::broadcast::FlvMuxer::WriteToOutput(const uint8_t* data, size_t length, bool rtmpData) {
  size_t sizeWritten = 0;

  if (mOutputFile != nullptr) {
    sizeWritten = fwrite(data, 1, length, mOutputFile);
    assert(sizeWritten == length);
  }

  if (rtmpData && nullptr != mRtmpStream) {
    TTV_ErrorCode ret = mRtmpStream->AddFLVData(data, length);

    if (TTV_SUCCEEDED(ret)) {
      sizeWritten = length;
    }
  }

  return sizeWritten;
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::GetError() {
  if (mRtmpStream != nullptr) {
    return mRtmpStream->GetError();
  } else {
    return TTV_EC_SUCCESS;
  }
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::GetAverageSendBitRate(
  uint64_t measurementWindowMilliseconds, uint64_t& bitsPerSecond) const {
  if (mRtmpStream != nullptr) {
    return mRtmpStream->GetAverageSendBitRate(measurementWindowMilliseconds, bitsPerSecond);
  } else {
    return TTV_EC_NOT_AVAILABLE;
  }
}

TTV_ErrorCode ttv::broadcast::FlvMuxer::GetCongestionLevel(
  uint64_t measurementWindowMilliseconds, double& congestionLevel) const {
  if (mRtmpStream != nullptr) {
    return mRtmpStream->GetCongestionLevel(measurementWindowMilliseconds, congestionLevel);
  } else {
    return TTV_EC_NOT_AVAILABLE;
  }
}
