/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/passthroughvideoencoder.h"

#include "twitchsdk/broadcast/iframewriter.h"
#include "twitchsdk/broadcast/ipreencodedvideoframereceiver.h"
#include "twitchsdk/broadcast/packet.h"
#include "twitchsdk/broadcast/videoframe.h"
#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/thread.h"

namespace {
// Pull the broadcast namespace into this file-local scope so the helper classes below can name
// VideoFrame, IPreEncodedVideoFrameReceiver, etc. without qualification.
using namespace ttv::broadcast;
// Signature of a parameterless completion callback.
// NOTE(review): not referenced anywhere in the portion of this file that was reviewed - confirm before removing.
using ProcessedCallback = std::function<void()>;

// Category name passed to ttv::trace::Message() by this file; also the value returned by GetName().
const char* kLoggerName = "PassThroughVideoEncoder";

/**
 * The video frame used to carry a pre-encoded buffer through the encoder pipeline.
 *
 * The frame owns the encoded packet it was constructed with and remembers whether it has already
 * been handed to the frame writer, so that a resubmitted frame is not written twice.
 */
class PreEncodedVideoFrame : public VideoFrame {
 public:
  /**
   * @param frameData   Encoded packet; ownership is transferred into this frame.
   * @param keyFrame    Whether the packet is a key frame.
   * @param timeStamp   Timestamp forwarded to SetTimeStamp().
   * @param frameIndex  Index forwarded to SetFrameIndex().
   */
  PreEncodedVideoFrame(
    IPreEncodedVideoFrameReceiver::VideoPacket&& frameData, bool keyFrame, uint64_t timeStamp, uint32_t frameIndex)
      : VideoFrame(IPreEncodedVideoFrameReceiver::GetReceiverTypeId()),
        // A named rvalue-reference parameter is an lvalue; without std::move the packet would be copied.
        mFrameData(std::move(frameData)),
        mProcessed(false) {
    SetFrameIndex(frameIndex);
    SetIsKeyFrame(keyFrame);
    SetTimeStamp(timeStamp);
  }

  /** Marks the frame as already handed off so that a resubmission is skipped. */
  void SetProcessed() { mProcessed = true; }

  /** @return true once SetProcessed() has been called on this frame. */
  bool GetProcessed() const { return mProcessed; }

  /** @return Read-only access to the encoded packet owned by this frame. */
  const IPreEncodedVideoFrameReceiver::VideoPacket& GetFrameData() const { return mFrameData; }

 private:
  IPreEncodedVideoFrameReceiver::VideoPacket mFrameData;
  bool mProcessed;  // We don't want to process this frame again if it's resubmitted due to submission lag.
};

/**
 * Receiver that wraps pre-encoded packets into PreEncodedVideoFrame objects, numbering them
 * sequentially starting from 0.
 */
class PreEncodedReceiver : public IPreEncodedVideoFrameReceiver {
 public:
  PreEncodedReceiver() : mNextFrameIndex(0) {}

  /**
   * Packages an encoded packet into a new frame.
   *
   * @param videoPacket  Encoded packet, moved into the new frame.
   * @param keyFrame     Whether the packet is a key frame.
   * @param timeStamp    Timestamp to attach to the frame.
   * @param result       Receives the newly created frame.
   * @return TTV_EC_SUCCESS.
   */
  TTV_ErrorCode PackageFrame(
    VideoPacket&& videoPacket, bool keyFrame, uint64_t timeStamp, std::shared_ptr<VideoFrame>& result) override {
    const uint32_t frameIndex = mNextFrameIndex;
    result = std::make_shared<PreEncodedVideoFrame>(std::move(videoPacket), keyFrame, timeStamp, frameIndex);

    // Advance only after the frame has been created successfully.
    ++mNextFrameIndex;

    return TTV_EC_SUCCESS;
  }

 private:
  uint32_t mNextFrameIndex;  // Index assigned to the next packaged frame.
};
}  // namespace

/**
 * Creates an encoder in the stopped state with stream index 0 and logs its creation.
 */
ttv::broadcast::PassThroughVideoEncoder::PassThroughVideoEncoder()
    : IVideoEncoder(),
      mStreamIndex(0),
      mStarted(false) {
  ttv::trace::Message(kLoggerName, MessageLevel::Info, "PassThroughVideoEncoder created");
}

/**
 * Logs destruction of the encoder; the body performs no other work.
 */
ttv::broadcast::PassThroughVideoEncoder::~PassThroughVideoEncoder() {
  ttv::trace::Message(
    kLoggerName, MessageLevel::Info, "PassThroughVideoEncoder destroyed");
}

/**
 * Stores a copy of the supplied SPS bytes so they can later be returned by GetSpsPps().
 *
 * @param sps  Raw SPS bytes (presumably the H.264 sequence parameter set - confirm with caller).
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::SetSps(const std::vector<uint8_t>& sps) {
  // The bytes are stored as-is; no validation is performed here.
  mSps = sps;
  return TTV_EC_SUCCESS;
}

/**
 * Stores a copy of the supplied PPS bytes so they can later be returned by GetSpsPps().
 *
 * @param pps  Raw PPS bytes (presumably the H.264 picture parameter set - confirm with caller).
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::SetPps(const std::vector<uint8_t>& pps) {
  // The bytes are stored as-is; no validation is performed here.
  mPps = pps;
  return TTV_EC_SUCCESS;
}

/**
 * Installs the callback that SetTargetBitRate() forwards to.
 *
 * @param func  Callback to take ownership of; a null callback makes SupportsBitRateAdjustment() return false.
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::SetAdjustTargetBitRateFunc(AdjustTargetBitRateFunc&& func) {
  // Move rather than copy: the caller handed the callback over as an rvalue.
  mAdjustTargetBitRateFunc = std::move(func);
  return TTV_EC_SUCCESS;
}

/**
 * @return The encoder's name, which is the same string used as this file's trace category.
 */
std::string ttv::broadcast::PassThroughVideoEncoder::GetName() const {
  return std::string(kLoggerName);
}

/**
 * @return true when a bit rate adjustment callback has been installed via SetAdjustTargetBitRateFunc().
 */
bool ttv::broadcast::PassThroughVideoEncoder::SupportsBitRateAdjustment() const {
  const bool hasAdjustCallback = (mAdjustTargetBitRateFunc != nullptr);
  return hasAdjustCallback;
}

/**
 * Sets the writer that SubmitFrame() hands packets to.
 *
 * @param frameWriter  Writer to use; stored as given.
 * @return TTV_EC_INVALID_STATE if the encoder is currently started, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::SetFrameWriter(
  const std::shared_ptr<IFrameWriter>& frameWriter) {
  // The writer may only be swapped while the encoder is stopped.
  if (!mStarted) {
    mFrameWriter = frameWriter;
    return TTV_EC_SUCCESS;
  }

  return TTV_EC_INVALID_STATE;
}

/**
 * Accepts any video parameters; the argument is not inspected.
 *
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::ValidateVideoParams(const VideoParams& /*videoParams*/) const {
  // Nothing is checked: the parameters are ignored by this encoder.
  return TTV_EC_SUCCESS;
}

/**
 * Logs the call; no other initialization work is performed.
 *
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::Initialize() {
  ttv::trace::Message(
    kLoggerName, MessageLevel::Debug, "PassThroughVideoEncoder::Initialize()");
  return TTV_EC_SUCCESS;
}

/**
 * Logs the call; no other shutdown work is performed.
 *
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::Shutdown() {
  ttv::trace::Message(
    kLoggerName, MessageLevel::Debug, "PassThroughVideoEncoder::Shutdown()");
  return TTV_EC_SUCCESS;
}

/**
 * Puts the encoder into the started state for the given stream.
 *
 * @param streamIndex  Stream index stamped onto every packet written by SubmitFrame().
 * @return TTV_EC_INVALID_STATE if already started, TTV_EC_INVALID_ARG if no frame writer has been
 *         set, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::Start(uint32_t streamIndex, const VideoParams& /*videoParams*/) {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughVideoEncoder::Start()");

  // Starting twice without an intervening Stop() is rejected.
  if (mStarted) {
    return TTV_EC_INVALID_STATE;
  }

  // A frame writer must have been provided through SetFrameWriter() beforehand.
  const bool hasFrameWriter = (mFrameWriter != nullptr);
  if (!hasFrameWriter) {
    TTV_ASSERT(false && "mFrameWriter != nullptr");
    ttv::trace::Message(
      kLoggerName, MessageLevel::Error, "Inside PassThroughVideoEncoder::Start - Bad frame writer parameter");
    return TTV_EC_INVALID_ARG;
  }

  mStarted = true;
  mStreamIndex = streamIndex;

  return TTV_EC_SUCCESS;
}

/**
 * Forwards a bit rate request to the installed adjustment callback.
 *
 * @param kbps  Requested bit rate, passed to the callback unchanged.
 * @return TTV_EC_UNSUPPORTED when no callback is installed, otherwise the callback's result.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::SetTargetBitRate(uint32_t kbps) {
  if (mAdjustTargetBitRateFunc == nullptr) {
    return TTV_EC_UNSUPPORTED;
  }

  return mAdjustTargetBitRateFunc(kbps);
}

/**
 * Wraps a pre-encoded frame into a Packet and hands it to the frame writer.
 *
 * A frame that has already been processed (for example because it was resubmitted) is accepted
 * but not written again.
 *
 * @param videoFrame  Frame produced by this encoder's pre-encoded receiver.
 * @return TTV_EC_INVALID_STATE if the encoder is not started,
 *         TTV_EC_INVALID_ARG if videoFrame is null,
 *         TTV_EC_BROADCAST_INVALID_SUBMISSION_METHOD if the frame was not created for the
 *         pre-encoded receiver, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::SubmitFrame(const std::shared_ptr<VideoFrame>& videoFrame) {
  ttv::trace::Message(kLoggerName, MessageLevel::Debug, "PassThroughVideoEncoder::SubmitFrame()");

  if (!mStarted) {
    return TTV_EC_INVALID_STATE;
  }

  // Assert for debugging, but also fail gracefully instead of dereferencing a null frame.
  TTV_ASSERT(videoFrame != nullptr);
  if (videoFrame == nullptr) {
    return TTV_EC_INVALID_ARG;
  }

  TTV_ASSERT(videoFrame->GetReceiverTypeId() == IPreEncodedVideoFrameReceiver::GetReceiverTypeId());
  if (videoFrame->GetReceiverTypeId() != IPreEncodedVideoFrameReceiver::GetReceiverTypeId()) {
    return TTV_EC_BROADCAST_INVALID_SUBMISSION_METHOD;
  }

  // The receiver type id check above is what justifies the unchecked downcast.
  auto preEncodedFrame = std::static_pointer_cast<PreEncodedVideoFrame>(videoFrame);

  if (!preEncodedFrame->GetProcessed()) {
    auto packet = std::make_unique<Packet>();

    packet->streamIndex = mStreamIndex;
    packet->timestamp = SystemTimeToMs(preEncodedFrame->GetTimeStamp());  // The encoder needs to output milliseconds
    packet->keyframe = preEncodedFrame->IsKeyFrame();
    // GetFrameData() returns a const reference, so this is a copy; wrapping it in std::move would
    // not turn it into a move.
    packet->data = preEncodedFrame->GetFrameData();

    mFrameWriter->WritePacket(std::move(packet));

    preEncodedFrame->SetProcessed();
  }

  return TTV_EC_SUCCESS;
}

/**
 * Leaves the started state; calling it while already stopped is not treated as an error.
 *
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::Stop() {
  ttv::trace::Message(
    kLoggerName, MessageLevel::Debug, "PassThroughVideoEncoder::Stop()");

  // Only the started flag is cleared; the frame writer and stream index are left untouched.
  mStarted = false;
  return TTV_EC_SUCCESS;
}

/**
 * Copies out the SPS and PPS bytes previously stored with SetSps() / SetPps().
 *
 * @param sps  Receives a copy of the stored SPS bytes.
 * @param pps  Receives a copy of the stored PPS bytes.
 * @return TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::GetSpsPps(std::vector<uint8_t>& sps, std::vector<uint8_t>& pps) {
  // Both outputs are overwritten unconditionally, even if nothing was ever stored.
  sps = mSps;
  pps = mPps;
  return TTV_EC_SUCCESS;
}

/**
 * Checks that a frame is usable by this encoder, i.e. that it was created for the pre-encoded receiver.
 *
 * @param videoframe  Frame to validate.
 * @return TTV_EC_BROADCAST_INVALID_VIDEOFRAME if the frame is null or was created for a different
 *         receiver type, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::PassThroughVideoEncoder::ValidateFrame(const std::shared_ptr<VideoFrame>& videoframe) {
  // Assert for debugging, but report a null frame as invalid instead of dereferencing it.
  TTV_ASSERT(videoframe != nullptr);
  if (videoframe == nullptr) {
    return TTV_EC_BROADCAST_INVALID_VIDEOFRAME;
  }

  TTV_ASSERT(videoframe->GetReceiverTypeId() == IPreEncodedVideoFrameReceiver::GetReceiverTypeId());

  if (videoframe->GetReceiverTypeId() != IPreEncodedVideoFrameReceiver::GetReceiverTypeId()) {
    return TTV_EC_BROADCAST_INVALID_VIDEOFRAME;
  }

  return TTV_EC_SUCCESS;
}

/**
 * @param typeId  Receiver type being queried.
 * @return true only for the pre-encoded video frame receiver type.
 */
bool ttv::broadcast::PassThroughVideoEncoder::SupportsReceiverProtocol(
  IVideoFrameReceiver::ReceiverTypeId typeId) const {
  const bool isPreEncodedReceiver = (typeId == IPreEncodedVideoFrameReceiver::GetReceiverTypeId());
  return isPreEncodedReceiver;
}

/**
 * Returns the receiver used to package frames for this encoder, creating it on first request.
 *
 * @param typeId  Receiver type being requested.
 * @return The shared pre-encoded receiver instance, or nullptr for any other receiver type.
 */
std::shared_ptr<ttv::broadcast::IVideoFrameReceiver> ttv::broadcast::PassThroughVideoEncoder::GetReceiverImplementation(
  IVideoFrameReceiver::ReceiverTypeId typeId) {
  // Only the pre-encoded receiver protocol is implemented.
  if (typeId != IPreEncodedVideoFrameReceiver::GetReceiverTypeId()) {
    return nullptr;
  }

  // Create the receiver lazily and reuse the same instance on later calls.
  if (mReceiver == nullptr) {
    mReceiver = std::make_shared<PreEncodedReceiver>();
  }

  return mReceiver;
}
