/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/standardwebsocket.h"

#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/thread.h"

#include <websocketpp/client.hpp>
#include <websocketpp/config/core.hpp>
#include <websocketpp/connection.hpp>

#define ENABLE_WEBSOCKETPP_PRINT 0

// TODO: Cleanup windows includes
#ifdef SetPort
#undef SetPort
#endif

namespace {
using namespace ttv;

// The websocketpp client type used by this file, built from the "core" configuration.
using client = websocketpp::client<websocketpp::config::core>;

// A complete websocket message that has been received but not yet handed to the caller.
struct Message {
  IWebSocket::MessageType type;  //!< Whether the payload is text or binary.
  std::vector<char> message;     //!< The raw payload bytes.
};

// Maximum time Connect() will wait for the websocket handshake to complete.
constexpr uint64_t kConnectionTimeoutMilliseconds = 10000;

// Maps an SDK message type onto the matching websocketpp frame opcode.
// Only Text and Binary are handled; any other value asserts and falls back to the binary opcode.
websocketpp::frame::opcode::value Convert(ttv::IWebSocket::MessageType value) {
  switch (value) {
    case IWebSocket::MessageType::Text:
      return websocketpp::frame::opcode::value::text;
    case IWebSocket::MessageType::Binary:
      return websocketpp::frame::opcode::value::binary;
    default:
      // TODO: Handle other types
      TTV_ASSERT(false);
      return websocketpp::frame::opcode::value::binary;
  }
}
}  // namespace

namespace ttv {
// Internal state of a StandardWebSocket, kept out of the public header.
struct StandardWebSocket::SocketData {
  std::string mUri;                      //!< The URI handed to websocketpp when creating the connection.
  client mClient;                        //!< The websocketpp endpoint that owns the protocol logic.
  client::connection_ptr mConnection;    //!< The current websocketpp connection, or null when there is none.
  std::shared_ptr<ISocket> mBaseSocket;  //!< The socket used for the actual transport.
  std::vector<char> mReadBuffer;         //!< The temporary read buffer for reading from the base socket.
  std::vector<char> mRawBuffer;          //!< Data that has been received but not yet read by the client.
  std::vector<Message> mMessageBuffer;   //!< Messages that have been received but not read by the client.
  bool mConnected = false;               //!< Set by the open handler, cleared on close/fail/disconnect.
};
}  // namespace ttv

// Allocates the internal socket state and, unless ENABLE_WEBSOCKETPP_PRINT is set, silences all websocketpp logging.
ttv::StandardWebSocket::StandardWebSocket() : mInnerData(std::make_unique<SocketData>()) {
#if !ENABLE_WEBSOCKETPP_PRINT
  // Access channels are described by log::alevel and error channels by log::elevel; use the matching level set
  // for each call.
  mInnerData->mClient.clear_access_channels(websocketpp::log::alevel::all);
  mInnerData->mClient.clear_error_channels(websocketpp::log::elevel::all);
#endif
}

ttv::StandardWebSocket::~StandardWebSocket() {
  // Tear down any live connection. The result is deliberately discarded because a destructor cannot report it.
  static_cast<void>(Disconnect());
}

// Prepares the socket for the given websocket uri: sizes the internal buffers and creates the base transport socket.
//
// "ws" uris are carried over a "tcp" base socket (default port 80) and "wss" uris over a "tls" base socket
// (default port 443). Any other scheme yields TTV_EC_UNIMPLEMENTED. Otherwise the result of ttv::CreateSocket
// is returned.
TTV_ErrorCode ttv::StandardWebSocket::Initialize(const std::string& uri) {
  ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Creating WebSocket for uri: %s", uri.c_str());

  auto& state = *mInnerData;
  state.mUri = uri;
  state.mRawBuffer.reserve(4096);
  state.mReadBuffer.resize(4096);

  Uri transportUrl(uri);
  const auto scheme = transportUrl.GetProtocol();

  if (scheme == "ws") {
    // Use a raw socket as the underlying connection
    transportUrl.SetProtocol("tcp");
    if (transportUrl.GetPort() == "") {
      transportUrl.SetPort("80");
    }
  } else if (scheme == "wss") {
    // The websocket layer is given the non-secure form of the uri...
    transportUrl.SetProtocol("ws");
    state.mUri = transportUrl.GetUrl();

    // ...while a tls socket is used as the underlying connection
    transportUrl.SetProtocol("tls");
    if (transportUrl.GetPort() == "") {
      transportUrl.SetPort("443");
    }
  } else {
    return TTV_EC_UNIMPLEMENTED;
  }

  TTV_ErrorCode ec = ttv::CreateSocket(transportUrl.GetUrl(), state.mBaseSocket);
  if (TTV_FAILED(ec)) {
    ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Failed create base socket with protocol: %s",
      transportUrl.GetProtocol().c_str());
  }

  return ec;
}

// Connects the base socket, then performs the websocket handshake over it, blocking until the handshake succeeds,
// fails, a socket error occurs, or kConnectionTimeoutMilliseconds elapses.
//
// Returns TTV_EC_SOCKET_EALREADY when already connected and TTV_EC_UNIMPLEMENTED when no base socket was created.
// Other failures return TTV_EC_SOCKET_CREATE_FAILED (websocketpp could not create the connection),
// TTV_EC_SOCKET_CONNECT_FAILED (fail handler fired), TTV_EC_SOCKET_ETIMEDOUT, or the error reported by the base
// socket. On any failure after the initial checks Disconnect() is called before returning.
TTV_ErrorCode ttv::StandardWebSocket::Connect() {
  TTV_ASSERT(!Connected());
  if (Connected()) {
    return TTV_EC_SOCKET_EALREADY;
  } else if (mInnerData->mBaseSocket == nullptr) {
    return TTV_EC_UNIMPLEMENTED;
  }

  std::error_code err;

  TTV_ErrorCode ec = mInnerData->mBaseSocket->Connect();

  if (TTV_SUCCEEDED(ec)) {
    mInnerData->mConnection = mInnerData->mClient.get_connection(mInnerData->mUri, err);
    if (err) {
      ttv::trace::Message("Socket", MessageLevel::Error, "WebSocket: Failed get_connection: %s", err.message().c_str());
      ec = TTV_EC_SOCKET_CREATE_FAILED;
    }
  }

  if (TTV_SUCCEEDED(ec)) {
    // Set by the fail handler. The handlers are stored on the connection object, which can outlive this call (for
    // example after a timeout), so the flag is shared by value rather than capturing a stack local by reference.
    auto failed = std::make_shared<bool>(false);

    std::function<void(websocketpp::connection_hdl)> onSuccess = [this](websocketpp::connection_hdl /*handle*/) {
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Received onSuccess callback");

      mInnerData->mConnected = true;
    };

    std::function<void(websocketpp::connection_hdl)> onClose = [this](websocketpp::connection_hdl /*handle*/) {
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Received onClose callback");

      mInnerData->mConnected = false;

      mInnerData->mConnection.reset();
    };

    std::function<void(websocketpp::connection_hdl, client::message_ptr)> onIncoming =
      [this](websocketpp::connection_hdl /*handle*/, client::message_ptr msg) {
        const auto& data = msg->get_payload();

        // Copy the data into the socket buffer
        mInnerData->mMessageBuffer.emplace_back();
        Message& packet = mInnerData->mMessageBuffer.back();
        packet.message.assign(data.begin(), data.end());

        switch (msg->get_opcode()) {
          case websocketpp::frame::opcode::value::text:
            packet.type = MessageType::Text;
            break;
          case websocketpp::frame::opcode::value::binary:
            packet.type = MessageType::Binary;
            break;
          default:
            packet.type = MessageType::Unknown;
            TTV_ASSERT(false);
            break;
        }

        // %u expects unsigned int and the %.*s precision expects int, so size_t values are converted explicitly.
        ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Received onIncoming callback with %u bytes",
          static_cast<unsigned int>(data.size()));
        ttv::trace::Message(
          "Socket", MessageLevel::Debug, "WebSocket:   %.*s", static_cast<int>(data.size()), data.data());
      };

    std::function<std::error_code(websocketpp::connection_hdl, char const*, size_t)> onOutgoing =
      [this](websocketpp::connection_hdl /*handle*/, char const* buffer, size_t length) -> std::error_code {
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Received onOutgoing callback with %u bytes",
        static_cast<unsigned int>(length));
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket:   %.*s", static_cast<int>(length), buffer);

      // Feed the outgoing websocket bytes through the base socket
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Sending %u outgoing bytes over base socket",
        static_cast<unsigned int>(length));
      TTV_ErrorCode outgoingEc = mInnerData->mBaseSocket->Send(reinterpret_cast<const uint8_t*>(buffer), length);
      if (TTV_SUCCEEDED(outgoingEc)) {
        return std::error_code();
      } else {
        ttv::trace::Message(
          "Socket", MessageLevel::Error, "WebSocket: Failed to send bytes on base socket, disconnecting");
        (void)Disconnect();
        return std::error_code(-1, std::system_category());
      }
    };

    std::function<void(websocketpp::connection_hdl)> onFail = [this, failed](websocketpp::connection_hdl handle) {
      auto connection = mInnerData->mClient.get_con_from_hdl(handle);
      if (connection != nullptr) {
        std::string reason = connection->get_ec().message();
        ttv::trace::Message("Socket", MessageLevel::Error, "WebSocket: Received onFail callback: %s", reason.c_str());
      }

      mInnerData->mConnected = false;
      *failed = true;

      mInnerData->mConnection.reset();
    };

    std::function<bool(websocketpp::connection_hdl, std::string)> onPing = [](websocketpp::connection_hdl /*handle*/,
                                                                             std::string /*data*/) -> bool {
      // Allow the pong to be sent back
      return true;
    };

    mInnerData->mConnection->set_open_handler(onSuccess);
    mInnerData->mConnection->set_close_handler(onClose);
    mInnerData->mConnection->set_message_handler(onIncoming);
    mInnerData->mConnection->set_write_handler(onOutgoing);
    mInnerData->mConnection->set_fail_handler(onFail);
    mInnerData->mConnection->set_ping_handler(onPing);

    // Kick off an async connect
    ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Initiating connection");
    (void)mInnerData->mClient.connect(mInnerData->mConnection);

    // We need to pump data until we know whether the connection succeeded or not
    uint64_t start = GetSystemTimeMilliseconds();
    while (!*failed) {
      if (mInnerData->mConnected) {
        // A "would block" result from the last poll is not an error once the handshake has completed
        if (ec == TTV_EC_SOCKET_EWOULDBLOCK) {
          ec = TTV_EC_SUCCESS;
        }
        break;
      } else if ((GetSystemTimeMilliseconds() - start) >= kConnectionTimeoutMilliseconds) {
        ttv::trace::Message("Socket", MessageLevel::Error, "WebSocket: Timed out trying to connect");
        ec = TTV_EC_SOCKET_ETIMEDOUT;
        break;
      }

      ec = PollBaseSocket();

      // Break during an error
      if (SOCKET_FAILED(ec)) {
        break;
      }

      // Sleep a little to wait for the response to come in on the socket
      Sleep(10);
    }

    if (*failed) {
      ec = TTV_EC_SOCKET_CONNECT_FAILED;
    }
  }

  if (TTV_SUCCEEDED(ec)) {
    ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Connection successful");
  } else {
    ttv::trace::Message("Socket", MessageLevel::Error, "WebSocket: Failed to connect: %s", ErrorToString(ec));
    (void)Disconnect();
  }

  return ec;
}

// Closes the websocket connection (if one is open) and then disconnects the base socket.
//
// Returns TTV_EC_SOCKET_ERR if closing the websocket connection reported an error; otherwise returns the result of
// disconnecting the base socket, or TTV_EC_SUCCESS when there is no base socket.
TTV_ErrorCode ttv::StandardWebSocket::Disconnect() {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  if (mInnerData->mConnected) {
    // Clear the flag first so that a re-entrant call (e.g. from the write handler) skips the websocket close
    mInnerData->mConnected = false;

    if (mInnerData->mConnection != nullptr) {
      std::error_code err;

      try {
        // This may cause onClose callback to be called
        mInnerData->mClient.close(mInnerData->mConnection, websocketpp::close::status::normal, "", err);
      } catch (...) {
        ttv::trace::Message("Socket", MessageLevel::Error,
          "WebSocket: Exception thrown while closing websocket connection, catching and continuing");
      }

      if (err) {
        ttv::trace::Message(
          "Socket", MessageLevel::Error, "WebSocket: Error closing connection: %s", err.message().c_str());
        ec = TTV_EC_SOCKET_ERR;
      }

      mInnerData->mConnection.reset();
    }
  }

  if (mInnerData->mBaseSocket != nullptr) {
    // Always disconnect the transport, but don't let its result mask an earlier websocket close error
    TTV_ErrorCode baseEc = mInnerData->mBaseSocket->Disconnect();
    if (TTV_SUCCEEDED(ec)) {
      ec = baseEc;
    }
  }

  return ec;
}

// Reads whatever bytes are available on the base socket and feeds them to the websocketpp connection, which in turn
// invokes the registered handlers (e.g. the message handler for completed messages).
//
// Returns TTV_EC_SOCKET_ENOTCONN when there is no connected base socket or no websocket connection; otherwise
// returns the result of the base socket's Recv call.
TTV_ErrorCode ttv::StandardWebSocket::PollBaseSocket() {
  TTV_ErrorCode ec = TTV_EC_SOCKET_ENOTCONN;

  if (mInnerData->mBaseSocket != nullptr && mInnerData->mBaseSocket->Connected() &&
      mInnerData->mConnection != nullptr) {
    // Read incoming bytes from the underlying socket
    size_t received = 0;
    ec = mInnerData->mBaseSocket->Recv(
      reinterpret_cast<uint8_t*>(mInnerData->mReadBuffer.data()), mInnerData->mReadBuffer.size(), received);

    if (received > 0) {
      // %u expects unsigned int and the %.*s precision expects int, so size_t values are converted explicitly.
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Received %u incoming bytes from base socket",
        static_cast<unsigned int>(received));
      ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket:   %.*s", static_cast<int>(received),
        mInnerData->mReadBuffer.data());
    }

    if (SOCKET_SUCCEEDED(ec)) {
      // Pass the bytes through the websocket
      if (received > 0) {
        // Keep an extra reference to ensure that a callback like onFail or onClose that releases the reference to
        // mConnection doesn't cause the connection to be released while it is on the stack
        auto keepAlive = mInnerData->mConnection;

        // This should call the message handler (onIncoming)
        size_t read = keepAlive->read_all(mInnerData->mReadBuffer.data(), received);

        ttv::trace::Message("Socket", MessageLevel::Debug, "WebSocket: Passed %u incoming bytes to websocket",
          static_cast<unsigned int>(read));

        TTV_ASSERT(read == received);
      }
    }
  }

  return ec;
}

// Sends one websocket message of the given type.
//
// Returns TTV_EC_SOCKET_ENOTCONN when there is no websocket connection, TTV_EC_SOCKET_SEND_ERROR (after
// disconnecting) when websocketpp reports a send error, and TTV_EC_SUCCESS otherwise.
TTV_ErrorCode ttv::StandardWebSocket::Send(MessageType type, const uint8_t* buffer, size_t length) {
  if (mInnerData->mConnection == nullptr) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  const websocketpp::frame::opcode::value frameType = Convert(type);

  // This will call the write handler (onOutgoing) and send the bytes out over the base socket
  std::error_code sendError;
  mInnerData->mClient.send(mInnerData->mConnection, buffer, length, frameType, sendError);

  if (!sendError) {
    return TTV_EC_SUCCESS;
  }

  ttv::trace::Message(
    "Socket", MessageLevel::Error, "WebSocket: Error sending from a socket: %s", sendError.message().c_str());
  (void)Disconnect();

  return TTV_EC_SOCKET_SEND_ERROR;
}

// Pops the oldest buffered message into the caller's buffer. This does not poll the base socket; it only consumes
// messages that have already been buffered.
//
// On return, `received` is the number of bytes copied and `type` the message type; both are 0 / MessageType::None
// when no message is buffered. Returns TTV_EC_INVALID_BUFFER (leaving the message queued) when `length` is smaller
// than the message, and TTV_EC_SUCCESS otherwise.
TTV_ErrorCode ttv::StandardWebSocket::Recv(MessageType& type, uint8_t* buffer, size_t length, size_t& received) {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  received = 0;
  type = MessageType::None;

  auto& pending = mInnerData->mMessageBuffer;
  if (!pending.empty()) {
    const auto& msg = pending.front();
    const size_t size = msg.message.size();

    if (length < size) {
      ec = TTV_EC_INVALID_BUFFER;
    } else {
      // An empty message is valid; skip the copy since memcpy must not be handed null pointers even for 0 bytes
      if (size > 0) {
        memcpy(buffer, msg.message.data(), size);
      }
      received = size;
      type = msg.type;

      // msg refers into the container, so it must not be used after this point
      pending.erase(pending.begin());
    }
  }

  return ec;
}

// Polls the base socket for new data and reports the type and size of the oldest buffered message without
// consuming it.
//
// `length` and `type` are always assigned: they are 0 / MessageType::None when no message is buffered or when
// polling fails. If polling fails the socket is disconnected and the polling error is returned. If a message is
// available TTV_EC_SUCCESS is returned; otherwise the (successful) polling result is returned unchanged.
TTV_ErrorCode ttv::StandardWebSocket::Peek(MessageType& type, size_t& length) {
  // Give the outputs defined values on every path, matching Recv()
  length = 0;
  type = MessageType::None;

  TTV_ErrorCode ec = PollBaseSocket();

  if (SOCKET_SUCCEEDED(ec)) {
    if (!mInnerData->mMessageBuffer.empty()) {
      const auto& next = mInnerData->mMessageBuffer.front();
      length = next.message.size();
      type = next.type;

      ec = TTV_EC_SUCCESS;
    }
  } else {
    (void)Disconnect();
  }

  return ec;
}

// Reports the connected flag, which is set by the open handler and cleared on close, failure or Disconnect().
bool ttv::StandardWebSocket::Connected() {
  const bool isConnected = mInnerData->mConnected;
  return isConnected;
}

// The factory holds nothing that needs explicit cleanup here.
ttv::StandardWebSocketFactory::~StandardWebSocketFactory() = default;

// Returns true for the two websocket schemes this factory can create sockets for: "ws" and "wss".
bool ttv::StandardWebSocketFactory::IsProtocolSupported(const std::string& protocol) {
  if (protocol == "ws") {
    return true;
  }

  return protocol == "wss";
}

// Creates and initializes a StandardWebSocket for a "ws" or "wss" uri.
//
// `result` is always reset first. It receives the new socket and TTV_EC_SUCCESS is returned only when the scheme is
// supported and Initialize() succeeds; in every other case `result` stays empty and TTV_EC_UNIMPLEMENTED is returned.
TTV_ErrorCode ttv::StandardWebSocketFactory::CreateWebSocket(
  const std::string& uri, std::shared_ptr<ttv::IWebSocket>& result) {
  result.reset();

  Uri parsed(uri);
  const auto scheme = parsed.GetProtocol();
  const bool supported = (scheme == "ws") || (scheme == "wss");
  if (!supported) {
    return TTV_EC_UNIMPLEMENTED;
  }

  auto socket = std::make_shared<StandardWebSocket>();
  const TTV_ErrorCode initEc = socket->Initialize(uri);
  if (TTV_SUCCEEDED(initEc)) {
    result = socket;
    return TTV_EC_SUCCESS;
  }

  return TTV_EC_UNIMPLEMENTED;
}
