#include "testvideoencoder.h"

#include <memory>
#include <utility>

#include "twitchsdk/broadcast/irawvideoframereceiver.h"
#include "twitchsdk/broadcast/videoframe.h"

#include "gtest/gtest.h"

using namespace ttv;
using namespace ttv::broadcast;

namespace {
/**
 * Minimal VideoFrame wrapping a caller-supplied raw pixel buffer for the test encoder.
 *
 * The pixel format is accepted for interface compatibility but not stored. The frame records the
 * vertical-flip flag and timestamp. It also registers an unlock callback that hands the original
 * buffer pointer back to the caller's unlock function.
 */
class RawVideoFrame final : public VideoFrame {
 public:
  RawVideoFrame(const uint8_t* frameBuffer, PixelFormat /*pixelFormat*/, bool verticalFlip, uint64_t timeStamp,
    IRawVideoFrameReceiver::UnlockFunc unlockCallback)
      : VideoFrame(IRawVideoFrameReceiver::GetReceiverTypeId()) {
    SetVerticalFlip(verticalFlip);
    SetTimeStamp(timeStamp);

    // Capture by value: the registered callback may run after this constructor returns.
    // An empty unlock function is skipped rather than invoked. Invoking an empty one would throw
    // std::bad_function_call, assuming UnlockFunc is a std::function (TODO confirm in the header).
    SetUnlockCallback([unlockCallback, frameBuffer]() {
      if (unlockCallback) {
        unlockCallback(frameBuffer);
      }
    });
  }
};

/**
 * Raw-frame receiver handed out by TestVideoEncoder::GetReceiverImplementation().
 *
 * Packages each submitted buffer into a RawVideoFrame and always reports success.
 */
class RawReceiver final : public IRawVideoFrameReceiver {
 public:
  TTV_ErrorCode PackageFrame(const uint8_t* frameBuffer, PixelFormat pixelFormat, bool verticalFlip,
    uint64_t timeStamp, UnlockFunc unlockCallback, std::shared_ptr<VideoFrame>& result) override {
    // unlockCallback is a by-value parameter, so it can be moved instead of copied again.
    result = std::make_shared<RawVideoFrame>(
      frameBuffer, pixelFormat, verticalFlip, timeStamp, std::move(unlockCallback));

    return TTV_EC_SUCCESS;
  }
};
}  // namespace

ttv::broadcast::test::TestVideoEncoder::TestVideoEncoder() : mInitialized(false), mStarted(false) {}

/**
 * Destroys the test encoder.
 *
 * Records a (non-fatal) test failure if the encoder is still started, i.e. a Start() call was
 * never balanced by Stop().
 */
ttv::broadcast::test::TestVideoEncoder::~TestVideoEncoder() {
  EXPECT_FALSE(mStarted) << "TestVideoEncoder destroyed while still started (missing Stop())";
}

/**
 * Reports whether the encoder's target bit rate can be adjusted.
 *
 * @return Always false; SetTargetBitRate() correspondingly returns TTV_EC_UNSUPPORTED.
 */
bool ttv::broadcast::test::TestVideoEncoder::SupportsBitRateAdjustment() const {
  // TODO: Allow this to be customized
  const bool adjustable = false;
  return adjustable;
}

/**
 * Stores the frame writer, replacing any previously stored one.
 *
 * Nothing else in this file reads the stored writer.
 *
 * @param frameWriter Writer to retain (shared ownership).
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::SetFrameWriter(const std::shared_ptr<IFrameWriter>& frameWriter) {
  // TODO: Make sure not encoding
  const TTV_ErrorCode status = TTV_EC_SUCCESS;
  mFrameWriter = frameWriter;
  return status;
}

/**
 * Accepts any video parameters; the test encoder imposes no constraints.
 *
 * @param videoParams Ignored.
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::ValidateVideoParams(const VideoParams& videoParams) const {
  static_cast<void>(videoParams);  // Every configuration is treated as valid.
  return TTV_EC_SUCCESS;
}

/**
 * Marks the encoder as initialized.
 *
 * Records a test failure if it was already initialized (double Initialize()). The flag is set
 * regardless.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::Initialize() {
  EXPECT_FALSE(mInitialized) << "Initialize() called on an already-initialized TestVideoEncoder";

  mInitialized = true;

  return TTV_EC_SUCCESS;
}

/**
 * Marks the encoder as no longer initialized.
 *
 * Records a test failure if it was not initialized. The flag is cleared regardless.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::Shutdown() {
  EXPECT_TRUE(mInitialized) << "Shutdown() called on a TestVideoEncoder that is not initialized";

  mInitialized = false;

  return TTV_EC_SUCCESS;
}

/**
 * Starts the (fake) encoding session.
 *
 * Records test failures if any of these hold:
 * - streamIndex is not 0;
 * - the encoder is not initialized;
 * - the encoder is already started.
 *
 * The parameters are stored and the encoder is marked started regardless.
 *
 * @param streamIndex Expected to be 0; only a single stream is supported by this fake.
 * @param videoParams Parameters copied into mVideoParams.
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::Start(uint32_t streamIndex, const VideoParams& videoParams) {
  // 0u keeps both operands unsigned; comparing uint32_t with an int literal trips -Wsign-compare
  // inside gtest's EXPECT_EQ helper.
  EXPECT_EQ(streamIndex, 0u) << "TestVideoEncoder only expects stream index 0";
  EXPECT_TRUE(mInitialized) << "Start() called while TestVideoEncoder is not initialized";
  EXPECT_FALSE(mStarted) << "Start() called while TestVideoEncoder is already started";

  mVideoParams = videoParams;
  mStarted = true;

  return TTV_EC_SUCCESS;
}

/**
 * Accepts a frame for encoding and discards it.
 *
 * Records test failures if the encoder is not initialized or not started.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::SubmitFrame(const std::shared_ptr<VideoFrame>& /*videoFrame*/) {
  EXPECT_TRUE(mInitialized) << "SubmitFrame() called while TestVideoEncoder is not initialized";
  EXPECT_TRUE(mStarted) << "SubmitFrame() called while TestVideoEncoder is not started";

  return TTV_EC_SUCCESS;
}

/**
 * Stops the (fake) encoding session.
 *
 * Records test failures if the encoder is not initialized or not started. The started flag is
 * cleared regardless.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::Stop() {
  EXPECT_TRUE(mInitialized) << "Stop() called while TestVideoEncoder is not initialized";
  EXPECT_TRUE(mStarted) << "Stop() called while TestVideoEncoder is not started";

  mStarted = false;

  return TTV_EC_SUCCESS;
}

/**
 * Reports success without producing SPS/PPS data.
 *
 * Both output vectors are left exactly as the caller passed them.
 *
 * @param sps Untouched.
 * @param pps Untouched.
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::GetSpsPps(
  std::vector<uint8_t>& sps, std::vector<uint8_t>& pps) {
  static_cast<void>(sps);
  static_cast<void>(pps);
  return TTV_EC_SUCCESS;
}

/**
 * Accepts any frame as valid.
 *
 * Records a test failure if the encoder is not initialized.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::ValidateFrame(const std::shared_ptr<VideoFrame>& /*videoframe*/) {
  EXPECT_TRUE(mInitialized) << "ValidateFrame() called while TestVideoEncoder is not initialized";

  return TTV_EC_SUCCESS;
}

/**
 * Rejects bit-rate changes; consistent with SupportsBitRateAdjustment() returning false.
 *
 * @param kbps Ignored.
 * @return Always TTV_EC_UNSUPPORTED.
 */
TTV_ErrorCode ttv::broadcast::test::TestVideoEncoder::SetTargetBitRate(uint32_t kbps) {
  static_cast<void>(kbps);
  const TTV_ErrorCode unsupported = TTV_EC_UNSUPPORTED;
  return unsupported;
}

/**
 * @return The fixed name of this encoder, "TestVideoEncoder".
 */
std::string ttv::broadcast::test::TestVideoEncoder::GetName() const {
  const std::string name{"TestVideoEncoder"};
  return name;
}

/**
 * Reports whether frames of the given receiver type can be accepted.
 *
 * @param typeId Receiver type to query.
 * @return true only for the raw video frame receiver type.
 */
bool ttv::broadcast::test::TestVideoEncoder::SupportsReceiverProtocol(
  IVideoFrameReceiver::ReceiverTypeId typeId) const {
  const bool isRawReceiver = (typeId == IRawVideoFrameReceiver::GetReceiverTypeId());
  return isRawReceiver;
}

/**
 * Creates a receiver for the requested protocol.
 *
 * @param typeId Receiver type requested.
 * @return A new RawReceiver when typeId is the raw video frame receiver type; nullptr otherwise.
 */
std::shared_ptr<IVideoFrameReceiver> ttv::broadcast::test::TestVideoEncoder::GetReceiverImplementation(
  IVideoFrameReceiver::ReceiverTypeId typeId) {
  const bool isRawReceiver = (typeId == IRawVideoFrameReceiver::GetReceiverTypeId());
  if (!isRawReceiver) {
    // Only the raw protocol is implemented by this test encoder.
    return nullptr;
  }

  return std::make_shared<RawReceiver>();
}
