/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#pragma comment(lib, "wininet.lib")

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/httprequestutils.h"
#include "twitchsdk/core/wininethttprequest.h"

#include <windows.h>

#include <WinInet.h>

#include <algorithm>
#include <codecvt>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

namespace {
bool ConvertUtf8ToUtf16(const std::string& str, std::wstring& result) {
  int length = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), -1, nullptr, 0);
  if (length == 0) {
    result = L"";
    return true;
  }

  std::vector<wchar_t> buffer;
  buffer.resize(length);
  length = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), -1, buffer.data(), length);
  result = buffer.data();
  return true;
}

int IsWhitespace(int ch) {
  return ch == ' ' || ch == ' ';
}

void ParseHeaders(const std::vector<wchar_t>& raw, std::map<std::string, std::string>& headers) {
  const char* kLineDelimeter = "\r\n";
  static const size_t kLineDelimeterLength = strlen(kLineDelimeter);

  std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
  std::string converted = converter.to_bytes(raw.data());

  // Break into lines
  for (;;) {
    size_t index = converted.find(kLineDelimeter);

    // No more lines
    if (index == std::string::npos || index == 0) {
      break;
    }

    std::string line = converted.substr(0, index);
    converted = converted.substr(index + kLineDelimeterLength);

    // Break the line into key-value pair
    index = line.find(':');
    if (index != std::string::npos) {
      std::string name = line.substr(0, index);
      std::string value = line.substr(index + 1);

      // Trim whitespace from the start of the value
      (void)value.erase(
        value.begin(), std::find_if(value.begin(), value.end(), std::not1(std::ptr_fun<int, int>(IsWhitespace))));

      headers[name] = value;
    }
  }
}
}  // namespace

//--------------------------------------------------------------------------
TTV_ErrorCode ttv::WinInetHttpRequest::SendHttpRequest(const std::string& /*requestName*/, const std::string& origUrl,
  const std::vector<HttpParam>& requestHeaders, const uint8_t* requestBody, size_t requestBodySize,
  HttpRequestType httpReqType, uint timeOutInSecs, HttpRequestHeadersCallback headersCallback,
  HttpRequestCallback responseCallback, void* userData) {
  if (origUrl.size() == 0) {
    return TTV_EC_INVALID_HTTP_REQUEST_PARAMS;
  }

  std::string url = origUrl;

  // Extract the protocol name from the URL and determine if we need to use SSL
  bool useSsl = false;
  if (url.find("http://") == 0) {
    useSsl = false;
    url.assign(url.substr(sizeof("http://") - 1));
  } else if (url.find("https://") == 0) {
    useSsl = true;
    url.assign(url.substr(sizeof("https://") - 1));
  } else {
    return TTV_EC_INVALID_HTTP_REQUEST_PARAMS;
  }

  // Break up the URL into server URL and resource URL
  std::string serverUrl;
  std::string resourceUrl;
  size_t serverUrlLen = url.find_first_of('/');
  if (serverUrlLen == std::string::npos) {
    serverUrl = url;
  } else {
    serverUrl = url.substr(0, serverUrlLen);
    resourceUrl = url.substr(serverUrlLen + 1);
  }

  if (!responseCallback) {
    return TTV_EC_INVALID_HTTP_REQUEST_PARAMS;
  }

  TTV_ErrorCode ret = TTV_EC_HTTPREQUEST_ERROR;
  HINTERNET hInternet = InternetOpenW(L"Twitch", INTERNET_OPEN_TYPE_DIRECT, NULL, NULL, 0);

  if (hInternet) {
    BOOL enableDecoding = TRUE;
    (void)InternetSetOption(hInternet, INTERNET_OPTION_HTTP_DECODING, (LPVOID)&enableDecoding, sizeof(BOOL));

    DWORD timeoutMs = 1000 * timeOutInSecs;
    BOOL timeoutRet =
      InternetSetOption(hInternet, INTERNET_OPTION_CONNECT_TIMEOUT, (LPVOID)&timeoutMs, sizeof(timeoutMs));
    if (timeoutRet) {
      timeoutRet = InternetSetOption(hInternet, INTERNET_OPTION_SEND_TIMEOUT, (LPVOID)&timeoutMs, sizeof(timeoutMs));
      if (timeoutRet) {
        timeoutRet =
          InternetSetOption(hInternet, INTERNET_OPTION_RECEIVE_TIMEOUT, (LPVOID)&timeoutMs, sizeof(timeoutMs));
      }
    }
    if (timeoutRet) {
      INTERNET_PORT port = INTERNET_DEFAULT_HTTP_PORT;
      if (useSsl) {
        port = INTERNET_DEFAULT_HTTPS_PORT;
      }

      std::wstring wstr;
      ConvertUtf8ToUtf16(serverUrl, wstr);

      HINTERNET hConnect = InternetConnectW(hInternet, wstr.data(), port, NULL, NULL, INTERNET_SERVICE_HTTP, 0, NULL);

      if (hConnect) {
        const wchar_t* requestType = L"";

        switch (httpReqType) {
          case HTTP_GET_REQUEST:
            requestType = L"GET";
            break;
          case HTTP_PUT_REQUEST:
            requestType = L"PUT";
            break;
          case HTTP_POST_REQUEST:
            requestType = L"POST";
            break;
          case HTTP_DELETE_REQUEST:
            requestType = L"DELETE";
            break;
          default:
            TTV_ASSERT(false);
        }

        DWORD flags = 0;
        if (useSsl) {
          flags |= INTERNET_FLAG_SECURE;
        }

        ConvertUtf8ToUtf16(resourceUrl, wstr);

        HINTERNET hRequest = HttpOpenRequestW(hConnect, requestType, wstr.data(), HTTP_VERSIONW, L"", NULL, flags, 0);

        if (hRequest) {
          ConvertUtf8ToUtf16(BuildHttpHeader(requestHeaders), wstr);

          LPVOID requestBodyPtr = requestBodySize > 0 ? (LPVOID)requestBody : NULL;
          BOOL requestSent = HttpSendRequestW(hRequest, wstr.data(), static_cast<DWORD>(wstr.size()), requestBodyPtr,
            static_cast<DWORD>(requestBodySize));
          if (requestSent) {
            // Read the status code
            DWORD statusCode = 0;
            DWORD statusCodeSize = sizeof(statusCode);
            DWORD headerIndex = 0;
            HttpQueryInfoW(
              hRequest, HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &statusCodeSize, &headerIndex);

            bool notifyResponse = true;

            // Parse the headers
            if (headersCallback != nullptr) {
              // Read the size of the headers buffer
              DWORD headersSize = 0;
              HttpQueryInfoW(hRequest, HTTP_QUERY_RAW_HEADERS_CRLF, NULL, &headersSize, NULL);

              // Grab the raw header data
              std::vector<wchar_t> headerBuffer;
              headerBuffer.resize(headersSize / sizeof(wchar_t));
              HttpQueryInfoW(hRequest, HTTP_QUERY_RAW_HEADERS_CRLF, headerBuffer.data(), &headersSize, NULL);

              // Parse headers
              std::map<std::string, std::string> responseHeaders;
              ParseHeaders(headerBuffer, responseHeaders);
              notifyResponse = headersCallback(static_cast<uint>(statusCode), responseHeaders, userData);
            }

            // Parse the response
            if (notifyResponse) {
              const uint kBufferSize = 1024;
              char buffer[kBufferSize];

              BOOL keepReading = true;
              DWORD bytesRead = 0;

              std::vector<char> response;

              do {
                keepReading = InternetReadFile(hRequest, buffer, kBufferSize, &bytesRead);
                response.insert(response.cend(), buffer, buffer + bytesRead);
              } while (keepReading && bytesRead != 0);

              responseCallback(static_cast<uint>(statusCode), response, userData);
            }

            ret = TTV_EC_SUCCESS;
          }

          InternetCloseHandle(hRequest);
        }

        InternetCloseHandle(hConnect);
      }
    }

    InternetCloseHandle(hInternet);
  }

  return ret;
}
