/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/httprequestutils.h"

#include "twitchsdk/core/stringutilities.h"
#include "twitchsdk/core/task/simplejsonhttptask.h"
#include "twitchsdk/core/task/taskrunner.h"
#include "twitchsdk/core/thread.h"
#include "twitchsdk/core/user/user.h"

#include <algorithm>
#include <regex>

namespace {
// Returns true if byte `c` must be percent-encoded when placed in a URL component.
//
// Only the RFC 3986 "unreserved" characters pass through unchanged: ASCII letters, ASCII digits, and '-', '.', '_',
// '~'. Every other byte (including all bytes >= 0x80) is encoded. The ASCII ranges are tested explicitly instead of
// calling isalnum(), whose answer depends on the process's current C locale and may classify non-ASCII bytes as
// alphanumeric, which would let raw non-ASCII bytes leak into URLs.
bool ShouldEncode(unsigned char c) {
  const bool isAsciiLetter = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  const bool isAsciiDigit = (c >= '0' && c <= '9');
  if (isAsciiLetter || isAsciiDigit) {
    return false;
  }

  switch (c) {
    case '-':
    case '.':
    case '_':
    case '~':
      return false;
    default:
      return true;
  }
}
}  // namespace

// Percent-encodes `inputString` and appends the result to `outputStream`.
//
// Bytes for which ShouldEncode() is false are written unchanged. Every other byte is written as '%' followed by
// exactly two uppercase hex digits (RFC 3986 section 2.1), so e.g. '\n' becomes "%0A" and not "%A". The hex digits
// are produced by table lookup rather than via std::hex, so the caller's stream formatting flags are left untouched.
void ttv::UrlEncode(const std::string& inputString, std::stringstream& outputStream) {
  static const char kHexDigits[] = "0123456789ABCDEF";

  for (const char currentCharacter : inputString) {
    // Work on the unsigned byte value: a negative char (bytes >= 0x80 on signed-char platforms) must map to 0x80-0xFF.
    const auto byte = static_cast<unsigned char>(currentCharacter);
    if (ShouldEncode(byte)) {
      outputStream << '%' << kHexDigits[byte >> 4] << kHexDigits[byte & 0x0F];
    } else {
      outputStream << currentCharacter;
    }
  }
}

// Convenience overload: percent-encodes `inputString` and returns the encoded text as a new string.
std::string ttv::UrlEncode(const std::string& inputString) {
  std::stringstream encoded;
  ttv::UrlEncode(inputString, encoded);
  return encoded.str();
}

// Decodes percent-escapes in `inputString` and appends the decoded bytes to `outputStream`.
//
// Every "%XY" sequence, where X and Y are hex digits of either case, becomes the single byte 0xXY. All other
// characters are copied through unchanged; '+' is NOT translated to a space.
//
// Invalid encoding stops decoding: if a '%' is not followed by two characters, or those two characters are not
// both hex digits, the function returns and leaves only the output produced so far. Previously, non-hex digits
// were handed to sscanf and could inject a NUL or partially-parsed byte into the output. The caller's stream
// formatting flags are not modified.
void ttv::UrlDecode(const std::string& inputString, std::stringstream& outputStream) {
  // Value of an ASCII hex digit, or -1 if `c` is not one.
  const auto hexDigitValue = [](char c) -> int {
    if (c >= '0' && c <= '9') {
      return c - '0';
    }
    if (c >= 'A' && c <= 'F') {
      return c - 'A' + 10;
    }
    if (c >= 'a' && c <= 'f') {
      return c - 'a' + 10;
    }
    return -1;
  };

  size_t index = 0;
  while (index < inputString.size()) {
    const char current = inputString[index];
    if (current != '%') {
      outputStream << current;
      ++index;
      continue;
    }

    // Truncated escape: fewer than two characters follow the '%'.
    if (index + 2 >= inputString.size()) {
      return;
    }

    const int high = hexDigitValue(inputString[index + 1]);
    const int low = hexDigitValue(inputString[index + 2]);
    // Malformed escape: treated the same way as a truncated one.
    if (high < 0 || low < 0) {
      return;
    }

    outputStream << static_cast<char>((high << 4) | low);
    index += 3;
  }
}

// Out-parameter overload: decodes `input` via the stream overload and overwrites `result` with the decoded text.
void ttv::UrlDecode(const std::string& input, std::string& result) {
  std::stringstream decoded;
  ttv::UrlDecode(input, decoded);
  result.assign(decoded.str());
}

// Value-returning overload: decodes `inputString` via the stream overload and returns the decoded text.
std::string ttv::UrlDecode(const std::string& inputString) {
  std::stringstream decodedStream;
  ttv::UrlDecode(inputString, decodedStream);
  const std::string decoded = decodedStream.str();
  return decoded;
}

// Serializes ordered key/value pairs as "k1=v1&k2=v2...", percent-encoding every key and value.
// Pairs are emitted in vector order; duplicates are kept.
std::string ttv::BuildUrlEncodedRequestParams(const std::vector<std::pair<std::string, std::string>>& requestParams) {
  std::stringstream encoded;
  bool isFirst = true;
  for (const auto& entry : requestParams) {
    if (!isFirst) {
      encoded << "&";
    }
    isFirst = false;

    UrlEncode(entry.first, encoded);
    encoded << "=";
    UrlEncode(entry.second, encoded);
  }
  return encoded.str();
}

// Serializes HttpParam entries as "name1=value1&name2=value2...", percent-encoding every name and value.
// Entries are emitted in vector order.
std::string ttv::BuildUrlEncodedRequestParams(const std::vector<ttv::HttpParam>& requestParams) {
  std::stringstream encoded;
  const char* separator = "";
  for (const auto& param : requestParams) {
    encoded << separator;
    separator = "&";

    UrlEncode(param.paramName, encoded);
    encoded << "=";
    UrlEncode(param.paramValue, encoded);
  }
  return encoded.str();
}

// Serializes a map as "k1=v1&k2=v2...", percent-encoding every key and value.
// Entries appear in the map's key order.
std::string ttv::BuildUrlEncodedRequestParams(const std::map<std::string, std::string>& requestParams) {
  std::stringstream encoded;
  bool needsSeparator = false;
  for (const auto& entry : requestParams) {
    if (needsSeparator) {
      encoded << "&";
    }
    needsSeparator = true;

    UrlEncode(entry.first, encoded);
    encoded << "=";
    UrlEncode(entry.second, encoded);
  }
  return encoded.str();
}

// Builds a raw HTTP header block: one "name:value" line per entry, each terminated by CRLF.
// Names and values are written verbatim (no encoding or validation); because they are streamed via c_str(),
// any text after an embedded NUL character is dropped.
std::string ttv::BuildHttpHeader(const std::vector<ttv::HttpParam>& headerParams) {
  std::stringstream header;
  for (const auto& param : headerParams) {
    header << param.paramName.c_str() << ":" << param.paramValue.c_str() << "\r\n";
  }
  return header.str();
}

// Returns true if any entry in `headers` has a paramName equal to `name`, ignoring case.
//
// Case folding uses ::tolower byte by byte. Each char is converted to unsigned char first, because passing a negative
// char (a byte >= 0x80 on signed-char platforms) to tolower is undefined behavior. The comparison is done in place,
// without building lowercased copies of the strings.
bool ttv::ContainsHttpParameter(const std::vector<HttpParam>& headers, const std::string& name) {
  const auto equalsIgnoringCase = [&name](const std::string& candidate) -> bool {
    if (candidate.size() != name.size()) {
      return false;
    }
    return std::equal(candidate.begin(), candidate.end(), name.begin(), [](char a, char b) {
      return ::tolower(static_cast<unsigned char>(a)) == ::tolower(static_cast<unsigned char>(b));
    });
  };

  return std::any_of(headers.begin(), headers.end(),
    [&equalsIgnoringCase](const HttpParam& p) -> bool { return equalsIgnoringCase(p.paramName); });
}

// Parses an "application/x-www-form-urlencoded"-style string ("k1=v1&k2=v2...") and appends the URL-decoded
// (key, value) pairs to `result` in input order.
//
// The string is split on '&' first. Each segment is then split on its FIRST '=', so "a=b=c" yields ("a", "b=c").
// Segments that contain no '=' (e.g. "flag", or empty segments from "&&" or a trailing '&') are skipped. Because
// '=' is searched for only inside the current segment, a bare segment can no longer absorb the following one into
// its key (previously "flag&b=2" produced the key "flag&b").
// Always returns TTV_EC_SUCCESS.
TTV_ErrorCode ttv::SplitHttpParameters(
  const std::string& parameterString, std::vector<std::pair<std::string, std::string>>& result) {
  size_t segmentStart = 0;
  while (segmentStart <= parameterString.size()) {
    size_t segmentEnd = parameterString.find('&', segmentStart);
    if (segmentEnd == std::string::npos) {
      segmentEnd = parameterString.size();
    }

    const std::string segment = parameterString.substr(segmentStart, segmentEnd - segmentStart);
    const size_t equalsIndex = segment.find('=');

    if (equalsIndex != std::string::npos) {
      std::string key;
      std::string value;
      UrlDecode(segment.substr(0, equalsIndex), key);
      UrlDecode(segment.substr(equalsIndex + 1), value);
      result.emplace_back(std::move(key), std::move(value));
    }

    // Step past the '&' (or past the end of the string, which terminates the loop).
    segmentStart = segmentEnd + 1;
  }

  return TTV_EC_SUCCESS;
}

// Map overload: parses `parameterString` with the vector overload and stores each decoded pair into `result`.
// Existing entries in `result` are kept; on duplicate keys, the later pair overwrites the earlier one.
// `result` is only touched when parsing succeeds; the parse error code is returned either way.
TTV_ErrorCode ttv::SplitHttpParameters(const std::string& parameterString, std::map<std::string, std::string>& result) {
  std::vector<std::pair<std::string, std::string>> parsed;
  const TTV_ErrorCode ec = SplitHttpParameters(parameterString, parsed);
  if (!TTV_SUCCEEDED(ec)) {
    return ec;
  }

  for (const auto& entry : parsed) {
    result[entry.first] = entry.second;
  }
  return ec;
}

// Appends to `result` the host names that a certificate for `originalHost` could legitimately be issued for:
// the host itself, then for each parent domain a wildcard form followed by the bare parent.
// Example: "a.b.example.com" -> "a.b.example.com", "*.b.example.com", "b.example.com", "*.example.com",
// "example.com".
// Returns TTV_EC_INVALID_ARG (appending nothing) when `originalHost` looks like a dotted-decimal IPv4 address.
//
// Assumptions: the host name does not start with '.', and the TLD contains no periods (e.g. "com", "tv"); handling
// multi-label TLDs would require a TLD table.
TTV_ErrorCode ttv::GenerateSslVerificationHosts(const std::string& originalHost, std::vector<std::string>& result) {
  if (ttv::IsHostAnIpAddress(originalHost)) {
    return TTV_EC_INVALID_ARG;
  }

  std::string current = originalHost;
  // Stop once only "domain.tld" remains, i.e. a single period.
  for (auto dotsLeft = std::count(current.begin(), current.end(), '.'); dotsLeft > 1; --dotsLeft) {
    result.push_back(current);
    const std::string parent = current.substr(current.find('.') + 1);
    result.push_back("*." + parent);
    current = parent;
  }
  result.push_back(current);

  return TTV_EC_SUCCESS;
}

// Returns true if `hostName` has the shape of a dotted-decimal IPv4 address: four groups of 1-3 digits separated
// by '.'. Octet ranges are not validated (e.g. "999.1.1.1" matches), and IPv6 literals are not recognized.
bool ttv::IsHostAnIpAddress(const std::string& hostName) {
  // Compile the pattern once. Constructing a std::regex is expensive, function-local static initialization is
  // thread-safe, and matching against a const regex does not modify it.
  static const std::regex ipRegex("\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}", std::regex::ECMAScript);
  return std::regex_match(hostName, ipRegex);
}

// Default-constructed Uri: every component is left in its default-initialized state.
ttv::Uri::Uri() = default;

// Constructs a Uri and immediately populates it by parsing `url` via SetUrl().
ttv::Uri::Uri(const std::string& url) : Uri() {
  SetUrl(url);
}

void ttv::Uri::SetUrl(const std::string& url) {
  DisassembleUrl(url);
}

// Parses `url` of the form [protocol://]host[:port][/path][?query] into mProtocol, mHostName, mPort, mPath and
// mParams, replacing any previous contents.
//
// - The protocol is the text before "://". It is recognized only when that "://" occurs before the first '/' or '?',
//   so a scheme-less URL whose query contains "://" (e.g. "host/cb?r=http://x") is not mis-split.
// - The query starts at the first '?' after the protocol. The path starts at the first '/' after the protocol, but
//   only if that '/' precedes the query; a '/' inside the query belongs to a parameter value, not to the path.
// - Trailing '/' characters are trimmed from the path unless the path is exactly "/".
// - Query parameters are split on '&'. Both keys and values are URL-decoded, so that AssembleUrl(), which encodes
//   both, round-trips them. Parameters whose decoded key is empty are dropped; a parameter with no '=' gets an empty
//   value; on duplicate keys, the last one wins.
// NOTE(review): '#' fragments are not recognized and end up in the path or in the last parameter value.
void ttv::Uri::DisassembleUrl(const std::string& url) {
  mParams.clear();

  // Protocol
  size_t authorityStart = 0;
  const size_t protocolEnd = url.find("://");
  if (protocolEnd != std::string::npos && protocolEnd < url.find_first_of("/?")) {
    mProtocol = url.substr(0, protocolEnd);
    authorityStart = protocolEnd + 3;
  } else {
    mProtocol = "";
  }

  // Locate the query and path starts (either may be npos).
  const size_t queryStart = url.find('?', authorityStart);
  size_t pathStart = url.find('/', authorityStart);
  if (pathStart > queryStart) {
    // The first '/' is inside the query (or there is no '/'): there is no path.
    pathStart = std::string::npos;
  }

  // Host name (port still attached): up to whichever of path/query comes first, or to the end of the URL.
  const size_t hostEnd = std::min(pathStart, queryStart);
  mHostName = (hostEnd == std::string::npos) ? url.substr(authorityStart)
                                             : url.substr(authorityStart, hostEnd - authorityStart);

  // Path
  if (pathStart == std::string::npos) {
    mPath = "";
  } else {
    mPath = (queryStart == std::string::npos) ? url.substr(pathStart) : url.substr(pathStart, queryStart - pathStart);

    // Trim off trailing '/' unless the path is "/".
    while (mPath.size() > 1 && mPath.back() == '/') {
      mPath.pop_back();
    }
  }

  // Port: everything after the first ':' in the host portion.
  const size_t colonIndex = mHostName.find(':');
  if (colonIndex != std::string::npos) {
    mPort = mHostName.substr(colonIndex + 1);
    mHostName.erase(colonIndex);
  } else {
    mPort = "";
  }

  // Query parameters
  if (queryStart != std::string::npos) {
    std::vector<std::string> params;
    Split(url.substr(queryStart + 1), params, '&', true);

    for (const auto& str : params) {
      const size_t equalsIndex = str.find('=');
      const std::string key = UrlDecode(str.substr(0, equalsIndex));
      if (key.empty()) {
        continue;
      }
      mParams[key] = (equalsIndex == std::string::npos) ? std::string() : UrlDecode(str.substr(equalsIndex + 1));
    }
  }
}

// Rebuilds the URL string as [protocol://]host[:port]path[?k=v&...].
// Empty protocol/port components are omitted; the query is emitted only when parameters exist, with every key and
// value percent-encoded in map key order.
std::string ttv::Uri::AssembleUrl() const {
  std::stringstream out;

  if (!mProtocol.empty()) {
    out << mProtocol << "://";
  }
  out << mHostName;
  if (!mPort.empty()) {
    out << ':' << mPort;
  }
  out << mPath;
  if (!mParams.empty()) {
    out << "?" << BuildUrlEncodedRequestParams(mParams);
  }

  return out.str();
}

// Converts to the assembled URL string (same result as GetUrl()).
ttv::Uri::operator std::string() const {
  return this->GetUrl();
}

// Two Uris are equal when protocol, host, port and path match exactly (case-sensitive) and both hold the same set
// of query parameter key/value pairs.
bool ttv::Uri::operator==(const Uri& other) const {
  return mProtocol == other.mProtocol &&
         mHostName == other.mHostName &&
         mPort == other.mPort &&
         mPath == other.mPath &&
         mParams == other.mParams;
}

// Logical negation of operator==.
bool ttv::Uri::operator!=(const Uri& other) const {
  const bool same = (*this == other);
  return !same;
}

// Parses the port into `result`. `result` is reset to 0 first. Returns false when no port was specified;
// otherwise returns whatever ParseNum reports for the port text.
bool ttv::Uri::GetPort(uint32_t& result) const {
  result = 0;
  return !mPort.empty() && ParseNum(mPort, result);
}

// Splits the path on '/' into `result` via Split(). The final `false` argument is forwarded to Split unchanged
// (DisassembleUrl passes `true` for query splitting); presumably it controls empty-segment handling — confirm
// against Split's declaration.
void ttv::Uri::GetPathComponents(std::vector<std::string>& result) const {
  const char pathSeparator = '/';
  Split(mPath, result, pathSeparator, false);
}

// Removes every query parameter.
void ttv::Uri::ClearParams() {
  mParams.erase(mParams.begin(), mParams.end());
}

// Sets query parameter `param` to the C string `value`. A null `value` is ignored (the parameter is left untouched).
void ttv::Uri::SetParam(const std::string& param, const char* value) {
  if (value == nullptr) {
    return;
  }
  mParams[param] = value;
}

// Sets (or overwrites) query parameter `param` to `value`.
void ttv::Uri::SetParam(const std::string& param, const std::string& value) {
  mParams[param].assign(value);
}

// Sets query parameter `param` to the decimal text of `value`. std::to_string formats unsigned values as "%u".
void ttv::Uri::SetParam(const std::string& param, uint32_t value) {
  mParams[param] = std::to_string(value);
}

// Sets query parameter `param` to the decimal text of `value`. std::to_string formats ints as "%d".
void ttv::Uri::SetParam(const std::string& param, int32_t value) {
  mParams[param] = std::to_string(value);
}

// Sets query parameter `param` to the decimal text of `value`. std::to_string yields the same digits as "%llu".
void ttv::Uri::SetParam(const std::string& param, uint64_t value) {
  mParams[param] = std::to_string(static_cast<unsigned long long int>(value));
}

// Sets query parameter `param` to the decimal text of `value`. std::to_string yields the same digits as "%lld".
void ttv::Uri::SetParam(const std::string& param, int64_t value) {
  mParams[param] = std::to_string(static_cast<long long int>(value));
}

// Sets query parameter `param` to `value` formatted with printf's "%g" (shortest of fixed/exponent notation).
// NOTE(review): "%g" honors the current C locale's decimal separator.
void ttv::Uri::SetParam(const std::string& param, double value) {
  char formatted[64] = {};
  (void)snprintf(formatted, sizeof formatted, "%g", value);
  mParams[param] = std::string(formatted);
}

// Sets query parameter `param` to the literal text "true" or "false".
void ttv::Uri::SetParam(const std::string& param, bool value) {
  if (value) {
    mParams[param] = "true";
  } else {
    mParams[param] = "false";
  }
}

// Returns true if a query parameter with exactly this (case-sensitive) key is present.
bool ttv::Uri::ContainsParam(const std::string& param) const {
  return mParams.count(param) > 0;
}

std::string ttv::Uri::GetUrl() const {
  return AssembleUrl();
}

ttv::PagedRequestFetcher::PagedRequestFetcher()
    : mCreateTaskCallback(nullptr), mCompleteCallback(nullptr), mCancel(false) {}

// Begins fetching pages starting from `initialCursor`.
// `createTaskCallback` is asked to build the task for each page; `completeCallback` receives the final result.
// Asserts that no fetch is already in progress. Returns the result of requesting the first page.
TTV_ErrorCode ttv::PagedRequestFetcher::Start(
  const std::string& initialCursor, CreateTaskCallback createTaskCallback, CompleteCallback completeCallback) {
  TTV_ASSERT(!InProgress());

  // Reset all per-run state before issuing the first request.
  mCancel = false;
  mCursor = initialCursor;
  mCreateTaskCallback = createTaskCallback;
  mCompleteCallback = completeCallback;

  return FetchPage();
}

// Requests cancellation: flags the run as cancelled and aborts the in-flight task, if there is one.
// The completion callback is not invoked here; FetchComplete reports the outcome.
void ttv::PagedRequestFetcher::Cancel() {
  mCancel = true;
  if (mCurrentTask) {
    mCurrentTask->Abort();
  }
}

// Handles completion of the current page's task.
// A failure `ec` is reported to the completion callback as-is. Otherwise, a pending cancellation is reported as
// TTV_EC_REQUEST_ABORTED. Otherwise, `cursor` becomes the current cursor and the next page is requested.
void ttv::PagedRequestFetcher::FetchComplete(TTV_ErrorCode ec, const std::string& cursor) {
  mCurrentTask.reset();

  if (TTV_FAILED(ec)) {
    mCompleteCallback(ec);
    return;
  }

  if (mCancel) {
    mCompleteCallback(TTV_EC_REQUEST_ABORTED);
    return;
  }

  mCursor = cursor;
  // FetchPage reports its own terminal outcomes through mCompleteCallback, so its return value is not needed here.
  (void)FetchPage();
}

// Asks mCreateTaskCallback for a task that fetches the page at mCursor (stored into mCurrentTask).
// On TTV_EC_SHUT_DOWN any task the callback produced is discarded. If no task remains, the run is over and the
// completion callback receives the callback's error code. Returns that error code.
TTV_ErrorCode ttv::PagedRequestFetcher::FetchPage() {
  const TTV_ErrorCode ec = mCreateTaskCallback(mCursor, mCurrentTask);

  const bool shuttingDown = (ec == TTV_EC_SHUT_DOWN);
  if (shuttingDown) {
    mCurrentTask.reset();
  }

  const bool noTask = (mCurrentTask == nullptr);
  if (noTask) {
    mCompleteCallback(ec);
  }

  return ec;
}
