/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "rtmp.h"

#include "amf0.h"
#include "buffer.h"
#include "testsocket.h"
#include "testutilities.h"
#include "twitchsdk/core/httprequestutils.h"
#include "twitchsdk/core/socket.h"
#include "twitchsdk/core/stringutilities.h"

#include "gtest/gtest.h"

using namespace ttv;
using namespace ttv::broadcast::test;
using namespace ttv::broadcast::test::amf0;

// http://wwwimages.adobe.com/content/dam/Adobe/en/devnet/rtmp/pdf/rtmp_specification_1.0.pdf

namespace {
using namespace ttv::broadcast::test::rtmp;

// Size of each handshake blob (C1/C2 received, and the blobs pushed back to the client).
const size_t kHandshakeBlobSize = 1536;

/**
 * Removes all the 1-byte fragment headers that are scattered in the chunk body.  I believe that we don't interleave
 * different types of data so the assumption of not having to split the data into different chunk bodies is relevant for
 * now.
 *
 * On the wire, a message body longer than @p fragmentSize is split into fragments of @p fragmentSize bytes, and every
 * fragment after the first is preceded by a 1-byte basic header (the 1-byte form is only used for chunk stream ids
 * 2-63).
 *
 * @param buffer       Chunk data starting right after the first chunk's header; the headers are erased in place.
 * @param chunkSize    Number of bytes at the front of @p buffer that belong to this message, including the
 *                     interleaved headers. Clamped to buffer.size().
 * @param fragmentSize Chunk size used by the sender (128 is the RTMP default). A value of 0 leaves @p buffer untouched.
 */
void DefragmentChunkBody(std::vector<uint8_t>& buffer, uint32_t chunkSize, uint32_t fragmentSize = 128) {
  if (fragmentSize == 0) {
    return;
  }

  if (chunkSize > buffer.size()) {
    chunkSize = static_cast<uint32_t>(buffer.size());
  }

  // After erasing the header at `pos`, the next fragment's payload shifts down into [pos, pos + fragmentSize), so the
  // next header sits exactly fragmentSize bytes further on. Stepping by fragmentSize - 1 would land on payload.
  for (uint32_t pos = fragmentSize; pos < chunkSize; pos += fragmentSize) {
    buffer.erase(buffer.begin() + pos);
    --chunkSize;
  }
}

/**
 * Returns the size in bytes of the chunk message header that follows the basic header for the given header type
 * (Large: 11, Medium: 7, Small: 3, Empty: 0). Records a test failure and returns 0 for any other value.
 */
size_t GetChunkMessageHeaderSize(ChunkMessageHeaderType type) {
  switch (type) {
    case ChunkMessageHeaderType::Large:
      return 11;
    case ChunkMessageHeaderType::Medium:
      return 7;
    case ChunkMessageHeaderType::Small:
      return 3;
    case ChunkMessageHeaderType::Empty:
      return 0;
    default:
      EXPECT_FALSE(true);
      return 0;
  }
}
}  // namespace

ttv::broadcast::test::rtmp::ConnectCommand::ConnectCommand() : transactionId(0) {}

ttv::broadcast::test::rtmp::ConnectCommand::~ConnectCommand() {}

ttv::broadcast::test::rtmp::CreateStreamCommand::CreateStreamCommand() : transactionId(0) {}

ttv::broadcast::test::rtmp::CreateStreamCommand::~CreateStreamCommand() {}

ttv::broadcast::test::rtmp::ChunkParser::ChunkParser() {
  Reset();
}

void ttv::broadcast::test::rtmp::ChunkParser::AddListener(const std::shared_ptr<IRtmpListener>& listener) {
  mListeners.AddListener(listener);
}

void ttv::broadcast::test::rtmp::ChunkParser::RemoveListener(const std::shared_ptr<IRtmpListener>& listener) {
  mListeners.RemoveListener(listener);
}

/**
 * Parses one chunk header (basic header, message header and optional extended timestamp) from the front of
 * @p buffer into mChunkHeader.
 *
 * @param buffer Received bytes. On success the header bytes are erased from the front; on failure nothing is erased.
 * @return true once the complete header was available and consumed; false if more bytes are needed.
 */
bool ttv::broadcast::test::rtmp::ChunkParser::ParseChunkHeader(std::vector<uint8_t>& buffer) {
  if (buffer.empty()) {
    return false;
  }

  // Chunk Basic Header: header type in the 2 highest order bits, chunk stream id (or a form marker) in the low 6 bits.
  const uint8_t id = buffer[0] & 0x3F;
  mChunkHeader.basicHeader.chunkMessageHeaderType = static_cast<ChunkMessageHeaderType>(buffer[0] & 0xC0);
  size_t consumed = 1;

  if (id == 0) {
    // 2 byte form: ids 64-319
    if (buffer.size() < 2) {
      return false;
    }

    mChunkHeader.basicHeader.chunkStreamId = buffer[1] + 64;
    consumed += 1;
  } else if (id == 1) {
    // 3 byte form: ids 64-65599
    if (buffer.size() < 3) {
      return false;
    }

    mChunkHeader.basicHeader.chunkStreamId = buffer[2] * 256 + buffer[1] + 64;
    consumed += 2;
  } else {
    // 1 byte form: ids 2-63. Id 2 carries low level control protocol commands.
    // TODO: Handle low level control protocol commands
    mChunkHeader.basicHeader.chunkStreamId = id;
  }

  // Chunk Message Header: wait until the basic header AND the full message header have arrived, so none of the reads
  // below can run past the end of the buffer.
  const size_t chunkMessageHeaderSize = GetChunkMessageHeaderSize(mChunkHeader.basicHeader.chunkMessageHeaderType);
  if (buffer.size() < consumed + chunkMessageHeaderSize) {
    return false;
  }

  // Large, Medium and Small headers carry a timestamp
  if (mChunkHeader.basicHeader.chunkMessageHeaderType <= ChunkMessageHeaderType::Small) {
    mChunkHeader.messageHeader.timestamp = ReadUInt24(buffer, consumed, false);
    consumed += 3;
  }

  // Large and Medium headers carry the message length and message type id
  if (mChunkHeader.basicHeader.chunkMessageHeaderType <= ChunkMessageHeaderType::Medium) {
    mChunkHeader.messageHeader.messageLength = ReadUInt24(buffer, consumed, false);
    consumed += 3;

    mChunkHeader.messageHeader.messageTypeId = ReadByte<MessageTypeId>(buffer, consumed, false);
    consumed += 1;
  }

  // Only Large headers carry the message stream id.
  // NOTE(review): the RTMP spec stores this field little-endian; confirm what ReadUInt32's last argument selects.
  if (mChunkHeader.basicHeader.chunkMessageHeaderType <= ChunkMessageHeaderType::Large) {
    mChunkHeader.messageHeader.messageStreamId = ReadUInt32(buffer, consumed, false);
    consumed += 4;
  }

  // Extended Timestamp: present when the 24-bit timestamp field is saturated
  if (mChunkHeader.messageHeader.timestamp == 0xFFFFFF) {
    if (buffer.size() < consumed + 4) {
      return false;
    }

    mChunkHeader.extendedTimestamp = ReadUInt32(buffer, consumed, false);
    consumed += 4;
  }

  // Remove the consumed data since we have the full chunk header; chunk data follows
  buffer.erase(buffer.begin(), buffer.begin() + consumed);

  return true;
}

/**
 * Waits until the whole body of the current message has been received, then removes exactly that message from
 * @p buffer (stripping the 1-byte headers between fragments) and dispatches it by message type.
 *
 * @param buffer Received bytes, starting right after the chunk header parsed by ParseChunkHeader().
 * @return true if a complete message was consumed and handled. false if more data is needed or the message type is
 *         not supported; in both cases @p buffer is left untouched.
 */
bool ttv::broadcast::test::rtmp::ChunkParser::ParseChunkBody(std::vector<uint8_t>& buffer) {
  // Fragment size the body is assumed to be split on (DefragmentChunkBody's default). Set Chunk Size messages are not
  // handled by this parser.
  const uint32_t kFragmentSize = 128;

  // Every fragment after the first is preceded by a 1-byte basic header, so the message occupies more bytes on the
  // wire than messageLength.
  const size_t messageLength = mChunkHeader.messageHeader.messageLength;
  const size_t fragmentHeaderCount = messageLength > 0 ? (messageLength - 1) / kFragmentSize : 0;
  const size_t wireLength = messageLength + fragmentHeaderCount;

  if (buffer.size() < wireLength) {
    return false;
  }

  switch (mChunkHeader.messageHeader.messageTypeId) {
    case MessageTypeId::Amf0: {
      // Take only this message's bytes so that anything received after it stays in `buffer`.
      std::vector<uint8_t> body(buffer.begin(), buffer.begin() + wireLength);
      buffer.erase(buffer.begin(), buffer.begin() + wireLength);

      DefragmentChunkBody(body, static_cast<uint32_t>(body.size()), kFragmentSize);
      return ParseChunkBody_Amf0(body);
    }
    default:
      return false;
  }
}

bool ttv::broadcast::test::rtmp::ChunkParser::ParseChunkBody_Amf0(std::vector<uint8_t>& buffer) {
  std::vector<Amf0Value*> values;

  while (!buffer.empty()) {
    Amf0Value* value = new Amf0Value();

    ParseAmf0Value(buffer, *value);

    values.push_back(value);
  }

  if (!values.empty()) {
    const auto& name = values[0];
    EXPECT_TRUE(name->type == Amf0Type::String);

    if (name->str == "connect") {
      EXPECT_TRUE(values.size() >= 3);

      std::shared_ptr<ConnectCommand> cmd = std::make_shared<ConnectCommand>();

      // Transaction id
      EXPECT_EQ(values[1]->type, Amf0Type::Number);
      cmd->transactionId = static_cast<uint32_t>(values[1]->number);

      // Command arguments
      cmd->commandArguments = values[2]->object;

      // User arguments
      if (values.size() > 3) {
        cmd->userArguments = values[3]->object;
      }

      mListeners.Invoke(
        [cmd](const std::shared_ptr<IRtmpListener>& listener) { listener->OnConnectCommandReceived(cmd); });
    } else if (name->str == "call") {
      EXPECT_TRUE(false);
    } else if (name->str == "close") {
      EXPECT_TRUE(false);
    } else if (name->str == "createStream") {
    } else {
      EXPECT_TRUE(false);
    }
  }

  return true;
}

/**
 * Runs the header/body state machine over @p buffer for as long as each step succeeds.
 *
 * Alternates between parsing a chunk header and parsing the matching body; after each completed body the parser is
 * reset to look for the next header. Stops as soon as the current step cannot complete.
 *
 * @return true if at least one full chunk was handled during this call.
 */
bool ttv::broadcast::test::rtmp::ChunkParser::Parse(std::vector<uint8_t>& buffer) {
  bool handledChunk = false;

  for (bool advanced = true; advanced;) {
    advanced = false;

    if (mState == State::ReadingChunkMessageHeader) {
      if (ParseChunkHeader(buffer)) {
        mState = State::WaitingForChunkBody;
        advanced = true;
      }
    } else if (mState == State::WaitingForChunkBody) {
      if (ParseChunkBody(buffer)) {
        // Done with this chunk; look for the next header.
        handledChunk = true;
        Reset();
        advanced = true;
      }
    }
  }

  return handledChunk;
}

/** Returns the parser to its initial state: waiting for a chunk header, with every header field cleared. */
void ttv::broadcast::test::rtmp::ChunkParser::Reset() {
  mState = State::ReadingChunkMessageHeader;
  mChunkHeader = {};
}

ttv::broadcast::test::rtmp::ChunkWriter::ChunkWriter() {
  Reset();
}

void ttv::broadcast::test::rtmp::ChunkWriter::SetChunkStreamId(uint32_t streamId) {
  mChunkHeader.basicHeader.chunkStreamId = streamId;
}

void ttv::broadcast::test::rtmp::ChunkWriter::SetTimestamp(uint32_t timestamp) {
  mChunkHeader.messageHeader.timestamp = timestamp;
}

void ttv::broadcast::test::rtmp::ChunkWriter::SetMessageTypeId(MessageTypeId type) {
  mChunkHeader.messageHeader.messageTypeId = type;
}

void ttv::broadcast::test::rtmp::ChunkWriter::SetMessageStreamId(uint32_t streamId) {
  mChunkHeader.messageHeader.messageStreamId = streamId;
}

void ttv::broadcast::test::rtmp::ChunkWriter::WriteBody(const Amf0Value& value) {
  WriteAmf0Value(mBodyBuffer, value, false);
}

void ttv::broadcast::test::rtmp::ChunkWriter::WriteBody(const std::vector<uint8_t>& bytes) {
  mBodyBuffer.insert(mBodyBuffer.end(), bytes.begin(), bytes.end());
}

void ttv::broadcast::test::rtmp::ChunkWriter::GetChunk(std::vector<uint8_t>& buffer) {
  // TODO: Handle low level control messages which are indicated by first byte being 2
  // TODO: Finish this

  char fmt = static_cast<char>(mChunkHeader.basicHeader.chunkMessageHeaderType);  // 2 highest order bits

  // 1-byte form
  if (mChunkHeader.basicHeader.chunkStreamId > 2 && mChunkHeader.basicHeader.chunkStreamId < 64) {
    buffer.push_back(fmt | static_cast<char>(mChunkHeader.basicHeader.chunkStreamId));
  }
  // 2-byte form
  else if (mChunkHeader.basicHeader.chunkStreamId < 319) {
    buffer.push_back(fmt | 0x0);
    buffer.push_back(static_cast<char>(mChunkHeader.basicHeader.chunkStreamId) - 64);
  }
  // 3-byte form
  else {
    buffer.push_back(fmt | 0x1);
    buffer.push_back(static_cast<char>(mChunkHeader.basicHeader.chunkStreamId) - 64);
  }

  ////  2 byte form
  // if (mChunkHeader.basicHeader.chunkMessageHeaderType == 0)
  //{
  //  mChunkHeader.basicHeader.chunkStreamId = buffer[1] + 64;
  //}
  //// 3 byte form
  // else if (id == 1)
  //{
  //  mChunkHeader.basicHeader.chunkStreamId = buffer[2] * 256 + buffer[1] + 64;
  //}

  // TODO: Take care of fragments
}

/** Discards the pending body bytes and clears every header field. */
void ttv::broadcast::test::rtmp::ChunkWriter::Reset() {
  mBodyBuffer.clear();
  mChunkHeader = {};
}

/**
 * Creates the server and installs the handlers that validate the SDK's commands and push the replies back through
 * the test socket. Validation failures are recorded with EXPECT_* and never dereference missing values.
 */
ttv::broadcast::test::rtmp::RtmpServer::RtmpServer() : mState(RtmpState::Uninitialized), mBandwidthTest(false) {
  mChunkListener = std::make_shared<RtmpListenerProxy>();

  // connect: validate the command object, then reply with `_result` and wait for createStream.
  mChunkListener->mOnConnectCommandReceivedFunc = [this](const std::shared_ptr<ConnectCommand>& cmd) {
    EXPECT_TRUE(mState == RtmpState::WaitingForConnect);
    EXPECT_TRUE(cmd->transactionId == 0);
    EXPECT_TRUE(cmd->commandArguments != nullptr);

    // TODO: Validate args
    if (cmd->commandArguments != nullptr) {
      for (const auto& prop : cmd->commandArguments->properties) {
        const auto& key = prop.first;
        const auto& value = prop.second;

        if (key != "tcUrl" && key != "app" && key != "type") {
          // We're not expecting any more params from the SDK right now
          EXPECT_TRUE(false);
          continue;
        }

        // Every expected property is a string; skip the value checks rather than dereference a null value.
        EXPECT_TRUE(value != nullptr);
        if (value == nullptr) {
          continue;
        }
        EXPECT_TRUE(value->type == Amf0Type::String);

        if (key == "tcUrl") {
          // rtmp://live-syd.twitch.tv/app/valid_stream_key?client_id=TEST_CLIENT_ID&sdk_version=sdk_DEV&video_encoder=TestVideoEncoder&os=Win32&bandwidthtest=true
          Uri url(value->str);
          EXPECT_TRUE(url.GetProtocol() == "rtmp");
          EXPECT_TRUE(url.GetHostName() == "live-syd.twitch.tv");

          // The stream key is the last path segment.
          std::vector<std::string> tokens;
          ttv::Split(url.GetPath(), tokens, '/', false);
          EXPECT_FALSE(tokens.empty());
          if (!tokens.empty()) {
            EXPECT_TRUE(tokens.back() == "valid_stream_key");
          }

          const auto& params = url.GetParams();

          // Checks that `name` is present and, when `expected` is non-null, that it has exactly that value.
          auto expectParam = [&params](const std::string& name, const char* expected) {
            const auto iter = params.find(name);
            EXPECT_TRUE(iter != params.end());
            if (iter != params.end() && expected != nullptr) {
              EXPECT_TRUE(iter->second == expected);
            }
          };

          // TODO: Merge master and use the global function to get sdk version so its value can be checked too
          expectParam("sdk_version", nullptr);
          expectParam("client_id", "TEST_CLIENT_ID");
          expectParam("video_encoder", "TestVideoEncoder");
          expectParam("bandwidthtest", mBandwidthTest ? "true" : "false");
        } else if (key == "app") {
          EXPECT_TRUE(value->str == "app");
        } else {  // "type"
          EXPECT_TRUE(value->str == "nonprivate");
        }
      }
    }

    // Send the reply to the connect request
    std::vector<uint8_t> chunkBody;
    WriteAmf0String(chunkBody, Amf0Value("_result"));
    WriteAmf0Number(chunkBody, Amf0Value(1u));  // transaction id
    // WriteAmf0Object(chunkBody, Amf0Object()); // Properties

    // Amf0Object info;
    // info.properties.emplace_back("code", std::make_shared<Amf0Value>("NetConnection.Connect.Success"));
    // WriteAmf0Object(sendBuffer, info); // Information

    // TODO: Create the chunk header; the AMF0 payload is currently pushed to the client without one.
    std::string str(chunkBody.begin(), chunkBody.end());
    mSocket->PushReceivedPayload(str, nullptr);

    SetState(RtmpState::WaitingForCreateStream);
  };

  // createStream: reply with `_result` carrying stream id 1, then wait for publish.
  mChunkListener->mOnCreateStreamCommandReceived = [this](const std::shared_ptr<CreateStreamCommand>& cmd) {
    EXPECT_TRUE(mState == RtmpState::WaitingForCreateStream);
    EXPECT_TRUE(cmd->transactionId == 0);

    // Send the reply to the create stream request
    std::vector<uint8_t> sendBuffer;
    WriteAmf0String(sendBuffer, Amf0Value("_result"));
    WriteAmf0Number(sendBuffer, Amf0Value(cmd->transactionId));  // transaction id
    WriteAmf0Object(sendBuffer, Amf0Object());                   // Command data
    WriteAmf0Number(sendBuffer, Amf0Value(1u));                  // Stream Id

    std::string str(sendBuffer.begin(), sendBuffer.end());
    mSocket->PushReceivedPayload(str, nullptr);

    SetState(RtmpState::WaitingForPublish);
  };

  mChunkParser.AddListener(mChunkListener);
}

ttv::broadcast::test::rtmp::RtmpServer::~RtmpServer() {}

void ttv::broadcast::test::rtmp::RtmpServer::AddListener(const std::shared_ptr<IRtmpListener>& listener) {
  mListeners.AddListener(listener);
  mChunkParser.AddListener(listener);
}

void ttv::broadcast::test::rtmp::RtmpServer::RemoveListener(const std::shared_ptr<IRtmpListener>& listener) {
  mListeners.RemoveListener(listener);
  mChunkParser.RemoveListener(listener);
}

/**
 * Attaches the test socket. Every byte the client sends through it is appended to mReceiveBuffer and immediately
 * drives Update().
 *
 * A null socket records a test failure and is stored as-is (no callback is installed).
 * NOTE(review): the installed callback captures `this`; the socket must not invoke it after this server is destroyed.
 */
void ttv::broadcast::test::rtmp::RtmpServer::SetSocket(const std::shared_ptr<ttv::test::TestSocket>& socket) {
  EXPECT_NE(socket, nullptr);

  mSocket = socket;
  if (mSocket == nullptr) {
    return;  // Nothing to hook up; the failure has been recorded above.
  }

  mSocket->SetSendCallback([this](const uint8_t* buffer, size_t length, size_t& sent) -> TTV_ErrorCode {
    // Fill up the local buffer
    sent = length;
    mReceiveBuffer.insert(mReceiveBuffer.end(), buffer, buffer + length);

    Update();

    return TTV_EC_SUCCESS;
  });
}

void ttv::broadcast::test::rtmp::RtmpServer::Initialize() {
  EXPECT_NE(mSocket, nullptr);
}

/**
 * Advances the server state machine using whatever the client has sent so far (buffered in mReceiveBuffer).
 *
 * Handshake: a 1-byte version (expected to be 0x03), then a kHandshakeBlobSize-byte blob that is answered with a
 * single byte plus a blob of the same size, then a second blob answered with one more blob. After that, chunks are
 * parsed while waiting for the connect command.
 */
void ttv::broadcast::test::rtmp::RtmpServer::Update() {
  // Keep stepping until a pass makes no progress (the current state needs more data or has nothing to do).
  for (bool progressed = true; progressed;) {
    progressed = false;

    switch (mState) {
      case RtmpState::Uninitialized:
        if (!mReceiveBuffer.empty()) {
          const uint8_t version = mReceiveBuffer.front();
          mReceiveBuffer.erase(mReceiveBuffer.begin());
          EXPECT_EQ(version, 0x3);

          progressed = true;
          SetState(RtmpState::WaitingForC1);
        }
        break;

      case RtmpState::WaitingForC1:
        if (mReceiveBuffer.size() >= kHandshakeBlobSize) {
          mReceiveBuffer.erase(mReceiveBuffer.begin(), mReceiveBuffer.begin() + kHandshakeBlobSize);

          // The client is waiting for a single byte...
          mSocket->PushReceivedPayload("1", nullptr);

          // ...followed by a blob of kHandshakeBlobSize bytes.
          std::string blob(kHandshakeBlobSize, '1');
          mSocket->PushReceivedPayload(blob, nullptr);

          progressed = true;
          SetState(RtmpState::WaitingForC2);
        }
        break;

      case RtmpState::WaitingForC2:
        if (mReceiveBuffer.size() >= kHandshakeBlobSize) {
          mReceiveBuffer.erase(mReceiveBuffer.begin(), mReceiveBuffer.begin() + kHandshakeBlobSize);

          // The client is waiting for one more blob of kHandshakeBlobSize bytes.
          std::string blob(kHandshakeBlobSize, '1');
          mSocket->PushReceivedPayload(blob, nullptr);

          progressed = true;
          SetState(RtmpState::HandshakeDone);
        }
        break;

      case RtmpState::HandshakeDone:
        progressed = true;
        SetState(RtmpState::WaitingForConnect);
        break;

      case RtmpState::WaitingForConnect:
        // Parse() reports whether at least one chunk was read.
        progressed = mChunkParser.Parse(mReceiveBuffer);
        break;

      case RtmpState::WaitingForCreateStream:
      case RtmpState::Broadcasting:
      default:
        // Nothing is consumed in these states yet; incoming bytes accumulate in mReceiveBuffer.
        break;
    }
  }
}

/**
 * Releases this server's reference to the test socket.
 *
 * NOTE(review): the send callback installed on the socket captures `this` and is not cleared here; if the socket
 * outlives the server, confirm it is not used afterwards.
 */
void ttv::broadcast::test::rtmp::RtmpServer::Shutdown() {
  mSocket = nullptr;
}

/** Switches to @p state and notifies listeners; does nothing when the server is already in that state. */
void ttv::broadcast::test::rtmp::RtmpServer::SetState(RtmpState state) {
  if (state != mState) {
    mState = state;
    mListeners.Invoke([state](const std::shared_ptr<IRtmpListener>& listener) { listener->OnStateChanged(state); });
  }
}
