/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/pubsub/pubsubclientconnection.h"

#include "twitchsdk/core/json/reader.h"
#include "twitchsdk/core/json/value.h"
#include "twitchsdk/core/json/writer.h"
#include "twitchsdk/core/pubsub/pubsubclient.h"
#include "twitchsdk/core/pubsub/pubsubclientmessages.h"
#include "twitchsdk/core/stringutilities.h"
#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/user/oauthtoken.h"
#include "twitchsdk/core/user/user.h"

#include <ctime>
#include <sstream>

// https://twitch.quip.com/ptTfAnm3oa9t

namespace {
// Index handed to the next PubSubClientConnection that is constructed; it is incremented once per
// connection and shows up in this file's log prefix so concurrent connections can be told apart.
uint32_t gNextPubSubClientConnectionIndex = 0;
}  // namespace

using namespace ttv::pubsub;

namespace {
//! Component name used for all ttv::trace output from this file.
const char* const kLogger = "PubSubClientConnection";

//! Upper bound on how long a single PollSocket() call keeps draining queued messages.
const uint64_t kMaxPollTimeMilliseconds = 500;
const uint64_t kPingIntervalMilliseconds =
  5 * 60 * 1000;  //!< The server will disconnect on us if we don't ping at least every 10 minutes
const uint64_t kPingJitterMilliseconds =
  1 * 60 * 1000;  //!< Jitter passed to the ping timer along with kPingIntervalMilliseconds; interval + jitter must stay under the server's 10 minute limit
const uint64_t kPongMilliseconds =
  5000;  //!< The server should send a pong in response to our ping so disconnect after 5 seconds and reconnect if we don't get it
const uint64_t kRequestTimeoutMilliseconds = 5000;  //!< The amount of time to wait for a response to a message.

//! Error strings the server may place in the "error" field of a RESPONSE.
namespace PubSubErrors {
const char* const kBadMessage = "ERR_BADMESSAGE";
const char* const kBadAuth = "ERR_BADAUTH";
const char* const kBadTopic = "ERR_BADTOPIC";
}  // namespace PubSubErrors
}  // namespace

// Stamps each new request with the current time so Update() can expire requests the server never answers.
ttv::PubSubClientConnection::OutstandingRequest::OutstandingRequest()
    : requestTime(GetCurrentTimeAsUnixTimestamp()) {}

// Creates a disconnected connection for `user`. The websocket endpoint defaults to the production PubSub
// edge unless `settingsRepository` (optional, may be null) provides a non-empty override under
// kPubSubEndpointUriKey. The socket is created here but not connected; call Connect() for that.
ttv::PubSubClientConnection::PubSubClientConnection(std::shared_ptr<User> user, SettingRepository* settingsRepository)
    : mUser(user), mState(PubSubState::Disconnected), mConnectionIndex(gNextPubSubClientConnectionIndex++) {
  Log(MessageLevel::Debug, "PubSubClientConnection()");

  // Pre-size the receive buffer; PollSocket() resizes it per message.
  mReadBuffer.reserve(4096);

  std::string uri;
  if (settingsRepository != nullptr) {
    settingsRepository->GetSetting(kPubSubEndpointUriKey, uri);
  }

  if (!uri.empty()) {
    Log(MessageLevel::Info, "Using overridden PubSub endpoint %s", uri.c_str());
  } else {
    uri = "wss://pubsub-edge.twitch.tv";
  }

  // NOTE(review): the result of CreateWebSocket() is ignored; if it leaves mSocket null, Connect() reports
  // TTV_EC_SHUT_DOWN.
  ttv::CreateWebSocket(uri, mSocket);
}

// Returns the tracked subscription state of `topic`. Topics are only kept in mTopicStates while a
// subscription exists or a request for it is in flight, so an untracked topic reports Unsubscribed.
ttv::PubSubClientConnection::TopicSubscriptionState::Enum ttv::PubSubClientConnection::GetTopicState(
  const std::string& topic) const {
  const auto found = mTopicStates.find(topic);
  return (found != mTopicStates.end()) ? found->second : TopicSubscriptionState::Unsubscribed;
}

// Registers `listener` to receive connection-state, topic-subscription, message, pong-timeout and
// reconnect notifications from this connection (forwarded to mListeners).
void ttv::PubSubClientConnection::AddListener(const std::shared_ptr<IListener>& listener) {
  mListeners.AddListener(listener);
}

// Unregisters a listener previously passed to AddListener() (forwarded to mListeners).
void ttv::PubSubClientConnection::RemoveListener(const std::shared_ptr<IListener>& listener) {
  mListeners.RemoveListener(listener);
}

// Opens the websocket and, on success, moves to PubSubState::Connected (which arms the ping timer).
// Returns TTV_EC_SHUT_DOWN when there is no socket: Disconnect() releases it, so a connection that has
// been disconnected cannot be reconnected through this object.
TTV_ErrorCode ttv::PubSubClientConnection::Connect() {
  Log(MessageLevel::Debug, "Connect()");

  if (mSocket == nullptr) {
    return TTV_EC_SHUT_DOWN;
  }

  const TTV_ErrorCode ec = mSocket->Connect();
  if (TTV_SUCCEEDED(ec)) {
    SetConnectionState(PubSubState::Connected, ec);
  }

  return ec;
}

// Closes and releases the websocket (if any) and transitions to PubSubState::Disconnected, which aborts
// outstanding requests and reports every tracked topic as unsubscribed. Always returns TTV_EC_SUCCESS.
TTV_ErrorCode ttv::PubSubClientConnection::Disconnect() {
  Log(MessageLevel::Debug, "Disconnect()");

  if (mSocket) {
    mSocket->Disconnect();
    // Drop the socket so later Connect() calls fail with TTV_EC_SHUT_DOWN.
    mSocket.reset();
  }

  // A no-op if we are already disconnected (SetConnectionState ignores same-state transitions).
  SetConnectionState(PubSubState::Disconnected, TTV_EC_SUCCESS);
  return TTV_EC_SUCCESS;
}

// Sends `text` verbatim as a single text websocket message.
// Returns TTV_EC_SOCKET_ENOTCONN when not connected; a send failure tears the connection down.
TTV_ErrorCode ttv::PubSubClientConnection::Send(const std::string& text) {
  Log(MessageLevel::Debug, "Send(): %s", text.c_str());

  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(text.c_str());
  const TTV_ErrorCode ec = mSocket->Send(IWebSocket::MessageType::Text, bytes, text.size());
  if (TTV_FAILED(ec)) {
    Log(MessageLevel::Error, "Failed to send: %s", ErrorToString(ec));
    SetConnectionState(PubSubState::Disconnected, ec);
  }

  return ec;
}

// True only while the connection state is PubSubState::Connected.
bool ttv::PubSubClientConnection::Connected() {
  const bool isConnected = (mState == PubSubState::Connected);
  return isConnected;
}

// Subscribes to `topic` by sending a LISTEN request authenticated with `user`'s OAuth token.
//
// Returns:
//   TTV_EC_INVALID_ARG      - `user` is null or has no valid OAuth token.
//   TTV_EC_SUCCESS          - already subscribed/subscribing, or the LISTEN was sent.
//   TTV_EC_REQUEST_PENDING  - an UNLISTEN for the topic is still in flight; nothing was sent.
//   other                   - the send failed (see SendMessageOverSocket()).
//
// When the server answers (or the request times out / is aborted), the topic moves to Subscribed or is
// dropped, OnTopicSubscriptionChanged is fired, and an authentication failure additionally fires
// OnAuthenticationError.
TTV_ErrorCode ttv::PubSubClientConnection::Listen(const std::string& topic, std::shared_ptr<User> user) {
  Log(MessageLevel::Debug, "Listen(): %s", topic.c_str());

  // Guard before dereferencing; a null user is an invalid argument just like a missing token.
  if (user == nullptr) {
    return TTV_EC_INVALID_ARG;
  }

  auto oauth = user->GetOAuthToken();
  if (oauth == nullptr || !oauth->GetValid()) {
    return TTV_EC_INVALID_ARG;
  }

  TopicSubscriptionState::Enum state = GetTopicState(topic);

  // Already in progress
  if (state == TopicSubscriptionState::Subscribed || state == TopicSubscriptionState::Subscribing) {
    return TTV_EC_SUCCESS;
  }
  // We are currently unsubscribing so wait for the unsub to finish then issue a sub
  else if (state == TopicSubscriptionState::Unsubscribing) {
    return TTV_EC_REQUEST_PENDING;
  }

  std::string nonce = GetGuid();

  json::Value root;
  root["type"] = "LISTEN";
  root["nonce"] = nonce;
  root["data"] = json::Value();

  json::Value& data = root["data"];
  data["topics"] = json::Value(json::arrayValue);
  data["topics"].append(topic);
  data["auth_token"] = oauth->GetToken();

  TTV_ErrorCode ec = SendMessageOverSocket(root);

  if (TTV_SUCCEEDED(ec)) {
    mTopicStates[topic] = TopicSubscriptionState::Subscribing;

    OutstandingRequest request;
    request.nonce = nonce;
    // `user` is captured (though unused) so it stays alive until the response arrives.
    request.callback = [this, user, oauth, topic](
                         TTV_ErrorCode callbackEc, const std::string& /*error*/, const json::Value& /*result*/) {
      Log(MessageLevel::Debug, "Listen callback: %s", ErrorToString(callbackEc));

      TopicSubscriptionState::Enum topicState = GetTopicState(topic);
      TTV_ASSERT(topicState == TopicSubscriptionState::Subscribing);

      if (TTV_SUCCEEDED(callbackEc)) {
        mTopicStates[topic] = TopicSubscriptionState::Subscribed;
        topicState = TopicSubscriptionState::Subscribed;
      } else {
        if (callbackEc == TTV_EC_AUTHENTICATION) {
          mListeners.Invoke([this, oauth, callbackEc](const std::shared_ptr<IListener>& listener) {
            listener->OnAuthenticationError(this, callbackEc, oauth);
          });
        }

        // Forget the failed subscription so a later Listen() can retry it.
        auto iter = mTopicStates.find(topic);
        TTV_ASSERT(iter != mTopicStates.end());
        if (iter != mTopicStates.end()) {
          if (topicState == TopicSubscriptionState::Subscribing) {
            mTopicStates.erase(iter);
          }
        }

        topicState = TopicSubscriptionState::Unsubscribed;
      }

      mListeners.Invoke([this, topic, topicState, callbackEc](std::shared_ptr<IListener> listener) {
        listener->OnTopicSubscriptionChanged(this, topic, topicState, callbackEc);
      });
    };

    mOutstandingRequests[nonce] = request;
  }

  return ec;
}

// Unsubscribes from `topic` by sending an UNLISTEN request.
//
// Returns TTV_EC_SUCCESS when the topic is already unsubscribed/unsubscribing or the request was sent,
// TTV_EC_REQUEST_PENDING when a LISTEN for the topic is still in flight (nothing is sent in that case),
// or the send error. When the response arrives (or the request times out / is aborted) the topic is
// dropped and reported as Unsubscribed regardless of the result code.
TTV_ErrorCode ttv::PubSubClientConnection::Unlisten(const std::string& topic) {
  Log(MessageLevel::Debug, "Unlisten(): %s", topic.c_str());

  switch (GetTopicState(topic)) {
    case TopicSubscriptionState::Unsubscribed:
    case TopicSubscriptionState::Unsubscribing:
      // Nothing to do, or an unsubscribe is already underway.
      return TTV_EC_SUCCESS;
    case TopicSubscriptionState::Subscribing:
      // The subscribe has not completed yet; the unsubscribe is not sent now.
      return TTV_EC_REQUEST_PENDING;
    default:
      break;
  }

  const std::string nonce = GetGuid();

  json::Value topics(json::arrayValue);
  topics.append(topic);

  json::Value root;
  root["type"] = "UNLISTEN";
  root["nonce"] = nonce;
  root["data"] = json::Value();
  root["data"]["topics"] = topics;

  const TTV_ErrorCode ec = SendMessageOverSocket(root);
  if (!TTV_SUCCEEDED(ec)) {
    return ec;
  }

  mTopicStates[topic] = TopicSubscriptionState::Unsubscribing;

  OutstandingRequest request;
  request.nonce = nonce;
  request.callback = [this, topic](
                       TTV_ErrorCode callbackEc, const std::string& /*error*/, const json::Value& /*result*/) {
    Log(MessageLevel::Debug, "Unlisten callback: %s", ErrorToString(callbackEc));

    TTV_ASSERT(GetTopicState(topic) == TopicSubscriptionState::Unsubscribing);

    // The result code is not inspected: whatever happened, the topic is treated as unsubscribed.
    auto found = mTopicStates.find(topic);
    TTV_ASSERT(found != mTopicStates.end());
    if (found != mTopicStates.end()) {
      mTopicStates.erase(found);
    }

    mListeners.Invoke([this, topic, callbackEc](std::shared_ptr<IListener> listener) {
      listener->OnTopicSubscriptionChanged(this, topic, TopicSubscriptionState::Unsubscribed, callbackEc);
    });
  };

  mOutstandingRequests[nonce] = request;

  return ec;
}

// Serializes `root` with mJsonWriter, appends "\r\n", and sends it as one Text websocket message.
// Returns TTV_EC_SOCKET_ENOTCONN when not connected; on a send failure the connection is torn down.
TTV_ErrorCode ttv::PubSubClientConnection::SendMessageOverSocket(const json::Value& root) {
  if (Connected()) {
    const std::string text = mJsonWriter.write(root) + "\r\n";
    // NOTE(review): LISTEN messages carry "auth_token", so this debug line prints the user's OAuth token.
    Log(MessageLevel::Debug, "SendMessageOverSocket(): %s", text.c_str());

    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(text.data());
    const TTV_ErrorCode ec = mSocket->Send(IWebSocket::MessageType::Text, bytes, text.size());
    if (TTV_FAILED(ec)) {
      Log(MessageLevel::Error, "Error sending message: %s", ErrorToString(ec));
      SetConnectionState(PubSubState::Disconnected, ec);
    }

    return ec;
  }

  return TTV_EC_SOCKET_ENOTCONN;
}

// Sends a message of the given `type` (e.g. "PING") whose "data" member is JSON null.
TTV_ErrorCode ttv::PubSubClientConnection::SendNullDataMessage(const std::string& type) {
  json::Value message;
  message["type"] = type;
  message["data"] = json::Value(json::nullValue);
  return SendMessageOverSocket(message);
}

// Sends a PING. On success, re-arms the (jittered) ping interval timer and starts the pong timer that
// Update() watches to detect an unresponsive server. On failure the timers are left untouched.
TTV_ErrorCode ttv::PubSubClientConnection::InitiatePing() {
  Log(MessageLevel::Debug, "InitiatePing()");

  const TTV_ErrorCode ec = SendNullDataMessage("PING");
  if (!TTV_SUCCEEDED(ec)) {
    return ec;
  }

  mPingTimer.SetWithJitter(kPingIntervalMilliseconds, kPingJitterMilliseconds);
  mPongTimer.Set(kPongMilliseconds);
  return ec;
}

// Parses the server message currently held in mReadBuffer and dispatches it by its "type":
//   RESPONSE  - completes the outstanding request matching "nonce" (if any) with the mapped error code.
//   MESSAGE   - forwards data.topic / data.message to listeners; the message string is parsed as JSON
//               when possible, otherwise passed through as the original string value.
//   PONG      - clears the pong timeout.
//   RECONNECT - notifies listeners that the server asked us to reconnect.
// Returns TTV_EC_INVALID_JSON for unparseable, non-object, type-less or unknown-type messages; otherwise the
// error code mapped from the "error" field (TTV_EC_SUCCESS when there is none).
TTV_ErrorCode ttv::PubSubClientConnection::HandleIncomingMessage() {
  Log(MessageLevel::Debug, "HandleIncomingMessage() Received message: %s", mReadBuffer.c_str());

  json::Value jRoot;
  bool parseRet = mJsonReader.parse(mReadBuffer.data(), mReadBuffer.data() + mReadBuffer.size(), jRoot);

  if (!parseRet) {
    Log(MessageLevel::Error, "Failed to parse message as json: %s", mReadBuffer.c_str());
    return TTV_EC_INVALID_JSON;
  }

  // Keyed lookups on a json::Value are only valid for object (or null) values under JsonCpp semantics, and
  // the payload is untrusted: reject e.g. a bare string or array before indexing into it.
  if (!jRoot.isObject()) {
    Log(MessageLevel::Error, "Message is not a json object: %s", mReadBuffer.c_str());
    return TTV_EC_INVALID_JSON;
  }

  const json::Value& jType = jRoot["type"];
  const json::Value& jError = jRoot["error"];

  std::string type;
  std::string nonce;
  std::string error;

  if (jType.isNull() || !jType.isString()) {
    Log(MessageLevel::Error, "'type' missing from json: %s", mReadBuffer.c_str());
    return TTV_EC_INVALID_JSON;
  }

  type = jType.asString();

  // See if we have an outstanding request we need to reply to the client
  const json::Value& jNonce = jRoot["nonce"];
  if (!jNonce.isNull() && jNonce.isString()) {
    nonce = jNonce.asString();
  }

  if (!jError.isNull() && jError.isString()) {
    error = jError.asString();
  }

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // Map the server's error string to an SDK error code and log it
  if (!error.empty()) {
    if (error == PubSubErrors::kBadAuth) {
      ec = TTV_EC_AUTHENTICATION;
    } else if (error == PubSubErrors::kBadMessage) {
      ec = TTV_EC_PUBSUB_BAD_MESSAGE;
    } else if (error == PubSubErrors::kBadTopic) {
      ec = TTV_EC_PUBSUB_BAD_TOPIC;
    } else {
      ec = TTV_EC_PUBSUB_RESPONSE_ERROR;
    }

    Log(MessageLevel::Error, "Error from pubsub: %s, %s, %s", error.c_str(), type.c_str(), nonce.c_str());
  }

  // A response to a sent message
  if (type == "RESPONSE") {
    if (!nonce.empty()) {
      // See if we have an outstanding request in progress for this nonce
      auto iter = mOutstandingRequests.find(nonce);
      if (iter != mOutstandingRequests.end()) {
        // Copy and erase before invoking: the callback may itself touch mOutstandingRequests.
        auto request = iter->second;
        mOutstandingRequests.erase(iter);

        if (request.callback != nullptr) {
          request.callback(ec, error, jRoot);
        }
      } else {
        Log(MessageLevel::Error, "Unable to find outstanding request for nonce: %s", nonce.c_str());
      }
    }
  }
  // An unsolicited message
  else if (type == "MESSAGE") {
    const json::Value& jData = jRoot["data"];

    // Check the shape of "data" before looking up keys inside it.
    if (!jData.isObject()) {
      Log(MessageLevel::Error, "MESSAGE has invalid format, skipping");
    } else {
      const json::Value& jTopic = jData["topic"];
      const json::Value& jMessage = jData["message"];

      if (!jTopic.isString() || !jMessage.isString()) {
        Log(MessageLevel::Error, "MESSAGE has invalid format, skipping");
      } else {
        std::string topic = jTopic.asString();
        std::string message = jMessage.asString();

        // We expect the message string to be json so try and parse it
        json::Value jParsedMessage;
        parseRet = mJsonReader.parse(message.data(), message.data() + message.size(), jParsedMessage);

        const json::Value* data;

        // Parsed as valid json
        if (parseRet) {
          data = &jParsedMessage;
        }
        // Wasn't valid json so pass as the original string
        else {
          Log(MessageLevel::Debug, "Couldn't parse message as json, passing as string");
          data = &jMessage;
        }

        mListeners.Invoke([this, &topic, data](std::shared_ptr<IListener> listener) {
          listener->OnTopicMessageReceived(this, topic, *data);
        });
      }
    }
  } else if (type == "PONG") {
    mPongTimer.Clear();
  } else if (type == "RECONNECT") {
    mListeners.Invoke([this](std::shared_ptr<IListener> listener) { listener->OnReconnectReceived(this); });
  } else {
    Log(MessageLevel::Error, "Received unhandled message type: %s", type.c_str());
    ec = TTV_EC_INVALID_JSON;
  }

  return ec;
}

// Drains and handles queued websocket messages, for at most kMaxPollTimeMilliseconds per call.
// Returns TTV_EC_SOCKET_ENOTCONN if not (or no longer) connected. On a socket-level failure (peek or
// receive) the connection is moved to Disconnected and that error is returned; otherwise returns the
// result of the last operation performed.
TTV_ErrorCode ttv::PubSubClientConnection::PollSocket() {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;
  const uint64_t start = GetSystemTimeMilliseconds();
  for (;;) {
    // A handled message may have triggered a disconnect (which also releases mSocket).
    if (!Connected()) {
      return TTV_EC_SOCKET_ENOTCONN;
    }

    // Check the socket for messages
    IWebSocket::MessageType type = IWebSocket::MessageType::None;
    size_t messageSize = 0;
    ec = mSocket->Peek(type, messageSize);

    if (type == IWebSocket::MessageType::None || SOCKET_FAILED(ec)) {
      break;
    }

    // Read the message
    mReadBuffer.resize(messageSize);
    ec = mSocket->Recv(type, reinterpret_cast<uint8_t*>(&mReadBuffer[0]), mReadBuffer.size(), messageSize);

    // Stop on a receive failure right away; otherwise the next Peek() would overwrite ec and the broken
    // connection would never be torn down below.
    if (SOCKET_FAILED(ec)) {
      break;
    }

    // Got a message
    if (TTV_SUCCEEDED(ec) && messageSize > 0) {
      // Assumes Recv() reports the number of bytes it actually wrote; drop any unused tail so the parser
      // only sees received data.
      if (messageSize < mReadBuffer.size()) {
        mReadBuffer.resize(messageSize);
      }

      ec = HandleIncomingMessage();
    }

    // Make sure we don't run for too long
    uint64_t now = GetSystemTimeMilliseconds();
    if (now - start >= kMaxPollTimeMilliseconds) {
      Log(MessageLevel::Debug, "Processing messages for a long time, breaking");
      break;
    }
  }

  if (SOCKET_FAILED(ec)) {
    SetConnectionState(PubSubState::Disconnected, ec);
  }

  return ec;
}

// Transitions to `state` (no-op if already there) and notifies listeners via OnConnectionStateChanged,
// passing `ec` as the reason.
//   Disconnected: clears both timers, aborts every outstanding request with TTV_EC_SOCKET_ECONNABORTED,
//                 reports each still-tracked topic as Unsubscribed, then calls Disconnect() to release
//                 the socket.
//   Connected:    clears the pong timer and arms the jittered ping timer.
void ttv::PubSubClientConnection::SetConnectionState(PubSubState state, TTV_ErrorCode ec) {
  if (mState == state) {
    return;
  }

  // Set before running side effects so re-entrant calls (e.g. via Disconnect() below) return early.
  mState = state;

  Log(MessageLevel::Debug, "SetConnectionState(): %d", static_cast<int>(state));

  switch (state) {
    case PubSubState::Disconnected: {
      mPingTimer.Clear();
      mPongTimer.Clear();

      // Abort all outstanding requests
      std::vector<InternalCallback> abortedCallbacks;
      abortedCallbacks.reserve(mOutstandingRequests.size());
      for (auto iter = mOutstandingRequests.begin(); iter != mOutstandingRequests.end(); ++iter) {
        auto& request = iter->second;

        if (request.callback != nullptr) {
          abortedCallbacks.push_back(request.callback);
        }
      }

      mOutstandingRequests.clear();

      // Call the callbacks outside the loop, since they can potentially mutate our
      // collection of outstanding requests.
      for (const auto& abortedCallback : abortedCallbacks) {
        abortedCallback(TTV_EC_SOCKET_ECONNABORTED, "", json::Value(json::nullValue));
      }

      // Fire implicit UNLISTENs for whatever the aborted callbacks did not already clean up
      auto statesCopy = mTopicStates;
      mTopicStates.clear();

      // Iterate by reference: copying each map entry would copy its topic string for nothing.
      for (const auto& kvp : statesCopy) {
        const auto& topicName = kvp.first;

        if (kvp.second != TopicSubscriptionState::Unsubscribed) {
          mListeners.Invoke([this, topicName, ec](std::shared_ptr<IListener> listener) {
            listener->OnTopicSubscriptionChanged(this, topicName, TopicSubscriptionState::Unsubscribed, ec);
          });
        }
      }

      Disconnect();

      break;
    }
    case PubSubState::Connected: {
      mPongTimer.Clear();
      mPingTimer.SetWithJitter(kPingIntervalMilliseconds, kPingJitterMilliseconds);

      break;
    }
    default: {
      TTV_ASSERT(false);
      break;
    }
  }

  mListeners.Invoke(
    [this, state, ec](std::shared_ptr<IListener> listener) { listener->OnConnectionStateChanged(this, state, ec); });
}

// Periodic tick: reports a pong timeout to listeners, sends a keep-alive PING when the ping timer fires,
// and fails any outstanding request older than kRequestTimeoutMilliseconds with TTV_EC_REQUEST_TIMEDOUT.
void ttv::PubSubClientConnection::Update() {
  // If the pong timer goes off then we need to reconnect to the server
  if (mPongTimer.Check(true)) {
    mListeners.Invoke([this](std::shared_ptr<IListener> listener) { listener->OnPongTimeout(this); });
  }

  // Make sure we keep the connection alive
  if (mPingTimer.Check(true)) {
    InitiatePing();
  }

  // Timeout pending requests.
  // requestTime is stamped with GetCurrentTimeAsUnixTimestamp(), but the timeout constant is in
  // milliseconds; comparing them directly would stretch a 5 s timeout by a factor of 1000.
  // NOTE(review): this assumes Unix timestamps are whole seconds (the usual convention) — confirm
  // against GetCurrentTimeAsUnixTimestamp(). The conversion rounds up to a whole second.
  const uint64_t timeoutSeconds = (kRequestTimeoutMilliseconds + 999) / 1000;
  Timestamp now = GetCurrentTimeAsUnixTimestamp();
  std::vector<InternalCallback> timeoutCallbacks;
  for (auto iter = mOutstandingRequests.begin(); iter != mOutstandingRequests.end();) {
    auto& request = iter->second;

    // Timeout
    if (now - request.requestTime >= timeoutSeconds) {
      if (request.callback != nullptr) {
        timeoutCallbacks.push_back(request.callback);
      }

      iter = mOutstandingRequests.erase(iter);
    } else {
      ++iter;
    }
  }

  // Call the callbacks outside the loop, since they can potentially mutate our
  // collection of outstanding requests.
  for (const auto& timeoutCallback : timeoutCallbacks) {
    timeoutCallback(TTV_EC_REQUEST_TIMEDOUT, "", json::Value(json::nullValue));
  }
}

// printf-style logging for this connection. Messages below the component's configured level are
// dropped, as is everything once the owning user has been released. Each line is prefixed with
// "[<user name>, <connection index>] ".
void ttv::PubSubClientConnection::Log(MessageLevel level, const char* format, ...) {
  MessageLevel activeLevel = MessageLevel::None;
  ttv::trace::GetComponentMessageLevel(kLogger, activeLevel);
  if (level < activeLevel) {
    return;
  }

  auto user = mUser.lock();
  if (user == nullptr) {
    return;
  }

  // The prefix becomes part of the format string handed to MessageVaList, so any '%' in the user name
  // must be doubled; otherwise it would be read as a conversion specifier and consume bogus varargs.
  const std::string userName = user->GetUserName();

  std::stringstream stream;
  stream << '[';
  for (const char c : userName) {
    if (c == '%') {
      stream << "%%";
    } else {
      stream << c;
    }
  }
  stream << ", " << mConnectionIndex << "] ";
  stream << format;

  va_list args;
  va_start(args, format);
  ttv::trace::MessageVaList(kLogger, level, stream.str().c_str(), args);
  va_end(args);
}
