/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#pragma comment(lib, "msxml6.lib")

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/httprequestutils.h"
#include "twitchsdk/core/ixhr2httprequest.h"

#include <wrl.h>

#include <algorithm>
#include <codecvt>
#include <cstring>
#include <exception>
#include <locale>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#ifdef TTV_PLATFORM_STORE_APP
#include <MsXml6.h>
#elif defined(TTV_PLATFORM_X1_GAME)
#include <ixmlhttprequest2.h>
#endif

using namespace Microsoft::WRL;

namespace {
// Returns non-zero if `ch` is HTTP optional whitespace (RFC 7230 "OWS": space or horizontal tab),
// i.e. the characters that may separate a header field's ':' from its value.
// (The previous version compared against ' ' twice, so tabs were never recognized.)
int IsWhitespace(int ch) {
  return ch == ' ' || ch == '\t';
}
}  // namespace

//--------------------------------------------------------------------------------------
// XHRPostStream - read-only ISequentialStream over an in-memory request body, handed to
// IXMLHTTPRequest2::Send() as the upload stream.
//--------------------------------------------------------------------------------------
class XHRPostStream : public Microsoft::WRL::RuntimeClass<RuntimeClassFlags<ClassicCom>, ISequentialStream> {
 public:
  // ISequentialStream: copies the next chunk of the body into dstBuffer.
  STDMETHODIMP Read(void* dstBuffer, ULONG bytesToRead, ULONG* bytesRead) override;

  // ISequentialStream: the stream is read-only, so writes are always refused.
  STDMETHODIMP Write(const void* /*srcBuffer*/, ULONG /*bytesToWrite*/, ULONG* /*bytesWritten*/) override {
    return STG_E_ACCESSDENIED;
  }

  // Stores the body that Read() serves.
  void Init(const std::string& postStream);

  // Size of the stored body in bytes.
  size_t GetLength() const;

 private:
  // Only WRL's Make<XHRPostStream>() (befriended below) may construct instances.
  XHRPostStream() = default;
  ~XHRPostStream() {}

  std::string mPostStream;  // Body bytes.
  size_t mReadOffset = 0;   // Index of the next byte Read() will hand out.

  friend Microsoft::WRL::ComPtr<XHRPostStream> Microsoft::WRL::Details::Make<XHRPostStream>();
};

//--------------------------------------------------------------------------
void XHRPostStream::Init(const std::string& postStream) {
  mPostStream = postStream;
}

//--------------------------------------------------------------------------
// Total number of body bytes set by Init(), regardless of how many Read() has consumed.
size_t XHRPostStream::GetLength() const {
  const std::string& body = mPostStream;
  return body.length();
}

//--------------------------------------------------------------------------
// ISequentialStream::Read: copies up to bytesToRead bytes of the body, starting at the current
// read position, into dstBuffer and advances the position past them.
//
// Returns S_OK when bytesToRead bytes were copied, S_FALSE when fewer were available (end of
// stream), or STG_E_INVALIDPOINTER when dstBuffer is null but bytes were requested.
// *bytesRead (if non-null) always receives the number of bytes copied (0 on error).
STDMETHODIMP XHRPostStream::Read(void* dstBuffer, ULONG bytesToRead, ULONG* bytesRead) {
  if (bytesRead != NULL) {
    *bytesRead = 0;
  }
  if (dstBuffer == NULL && bytesToRead > 0) {
    return STG_E_INVALIDPOINTER;
  }

  // Guard the unsigned subtraction so an out-of-range offset means "nothing left" rather than
  // wrapping to a huge length. (A ternary is used instead of std::min because <windows.h> may
  // define a min() macro.)
  const size_t remaining = (mReadOffset < mPostStream.size()) ? mPostStream.size() - mReadOffset : 0;
  const size_t requested = static_cast<size_t>(bytesToRead);
  const size_t bytesToCopy = (requested < remaining) ? requested : remaining;

  if (bytesToCopy > 0) {
    memcpy(dstBuffer, mPostStream.data() + mReadOffset, bytesToCopy);
    mReadOffset += bytesToCopy;
  }

  if (bytesRead != NULL) {
    *bytesRead = static_cast<ULONG>(bytesToCopy);
  }

  // Fewer bytes than requested signals end of stream.
  return (bytesToCopy < requested) ? S_FALSE : S_OK;
}

//--------------------------------------------------------------------------------------
// XHRCallback - IXMLHTTPRequest2Callback implementation that buffers the response status,
// headers and body so the issuing thread can block on the outcome via WaitForComplete().
//--------------------------------------------------------------------------------------
class XHRCallback : public RuntimeClass<RuntimeClassFlags<ClassicCom>, IXMLHTTPRequest2Callback> {
 public:
  // IXMLHTTPRequest2Callback notifications.
  STDMETHODIMP OnRedirect(IXMLHTTPRequest2* xhr, const WCHAR* redirectUrl) override;
  STDMETHODIMP OnHeadersAvailable(IXMLHTTPRequest2* xhr, DWORD status, const WCHAR* statusText) override;
  STDMETHODIMP OnDataAvailable(IXMLHTTPRequest2* xhr, ISequentialStream* responseStream) override;
  STDMETHODIMP OnResponseReceived(IXMLHTTPRequest2* xhr, ISequentialStream* responseStream) override;
  STDMETHODIMP OnError(IXMLHTTPRequest2* xhr, HRESULT hr) override;

  // When disabled, OnHeadersAvailable records only the status code and skips header parsing.
  void EnableHeaderParsing(bool enable) { mParseHeaders = enable; }

  // Blocks until a completion callback fires or timeoutInSecs elapses, then hands back the
  // buffered results; status is 0 when the wait or the request failed.
  void WaitForComplete(
    uint timeoutInSecs, uint& status, std::vector<char>& response, std::map<std::string, std::string>& headers);

 private:
  // Instances are created through MakeAndInitialize<XHRCallback>(), which also invokes
  // RuntimeClassInitialize() to create mCompleteEvent.
  XHRCallback();
  ~XHRCallback();

  STDMETHODIMP RuntimeClassInitialize();
  friend HRESULT Details::MakeAndInitialize<XHRCallback, XHRCallback>(XHRCallback**);

  HANDLE mCompleteEvent;                        // Manual-reset event set by OnResponseReceived / OnError.
  HRESULT mHr;                                  // Outcome recorded by the completion callbacks.
  DWORD mStatus;                                // HTTP status reported to OnHeadersAvailable.
  std::vector<char> mResponse;                  // Buffered response body.
  std::map<std::string, std::string> mHeaders;  // Parsed response headers (name -> value).
  bool mParseHeaders;                           // Toggled by EnableHeaderParsing().
};

//--------------------------------------------------------------------------
// Starts with no event handle (created later in RuntimeClassInitialize), a success HRESULT,
// no status, and header parsing enabled.
XHRCallback::XHRCallback()
  : mCompleteEvent(NULL),
    mHr(S_OK),
    mStatus(0),
    mParseHeaders(true) {}

//--------------------------------------------------------------------------
// Closes the completion event if RuntimeClassInitialize managed to create it.
XHRCallback::~XHRCallback() {
  if (mCompleteEvent == NULL) {
    return;
  }
  CloseHandle(mCompleteEvent);
  mCompleteEvent = NULL;
}

//--------------------------------------------------------------------------
// Second-phase construction (run by MakeAndInitialize): creates the manual-reset,
// initially non-signaled event that the completion callbacks signal.
// Returns S_OK, or the Win32 error from CreateEventEx wrapped as an HRESULT.
STDMETHODIMP XHRCallback::RuntimeClassInitialize() {
  mCompleteEvent = CreateEventEx(NULL, NULL, CREATE_EVENT_MANUAL_RESET, EVENT_ALL_ACCESS);
  if (mCompleteEvent != NULL) {
    return S_OK;
  }
  return HRESULT_FROM_WIN32(GetLastError());
}

//--------------------------------------------------------------------------
// Redirect notification. Redirects are not tracked; this always reports success.
STDMETHODIMP XHRCallback::OnRedirect(IXMLHTTPRequest2* /*xhr*/, const WCHAR* /*redirectUrl*/) {
  const HRESULT result = S_OK;
  return result;
}

//--------------------------------------------------------------------------
STDMETHODIMP XHRCallback::OnHeadersAvailable(IXMLHTTPRequest2* xhr, DWORD status, const WCHAR* /*statusText*/) {
  mStatus = status;

  if (!mParseHeaders) {
    return S_OK;
  }

  WCHAR* buffer = nullptr;
  HRESULT hr = xhr->GetAllResponseHeaders(&buffer);
  if (SUCCEEDED(hr)) {
    const char* kLineDelimeter = "\r\n";
    static const size_t kLineDelimeterLength = strlen(kLineDelimeter);

    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
    std::string converted = converter.to_bytes(buffer);

    // Break into lines
    for (;;) {
      size_t index = converted.find(kLineDelimeter);

      // No more lines
      if (index == std::string::npos || index == 0) {
        break;
      }

      std::string line = converted.substr(0, index);
      converted = converted.substr(index + kLineDelimeterLength);

      // Break the line into key-value pair
      index = line.find(':');
      if (index != std::string::npos) {
        std::string name = line.substr(0, index);
        std::string value = line.substr(index + 1);

        // Trim whitespace from the start of the value
        (void)value.erase(
          value.begin(), std::find_if(value.begin(), value.end(), std::not1(std::ptr_fun<int, int>(IsWhitespace))));

        mHeaders[name] = value;
      }
    }
  }

  if (buffer != nullptr) {
    CoTaskMemFree(buffer);
    buffer = nullptr;
  }

  return S_OK;
}

//--------------------------------------------------------------------------
// Intermediate data notifications are ignored; the complete body is drained from the
// response stream in OnResponseReceived(). Always reports success.
STDMETHODIMP XHRCallback::OnDataAvailable(IXMLHTTPRequest2* /*xhr*/, ISequentialStream* /*responseStream*/) {
  const HRESULT result = S_OK;
  return result;
}

//--------------------------------------------------------------------------
// Final success notification: drains responseStream into mResponse, records the outcome in mHr,
// and always signals mCompleteEvent so WaitForComplete() can return.
//
// Returns (and stores in mHr) one of:
//  - E_INVALIDARG when no stream was supplied;
//  - the first failing Read() HRESULT;
//  - E_OUTOFMEMORY if buffering the body threw;
//  - otherwise the last Read() result.
STDMETHODIMP XHRCallback::OnResponseReceived(IXMLHTTPRequest2* /*xhr*/, ISequentialStream* responseStream) {
  if (responseStream == nullptr) {
    mHr = E_INVALIDARG;
  } else {
    mHr = S_OK;
    try {
      const ULONG kMaxBufferLength = 1024;
      uint8_t buffer[kMaxBufferLength];
      ULONG bytesRead = 0;
      do {
        bytesRead = 0;  // Never reuse a stale count if Read() leaves the out-param untouched.
        mHr = responseStream->Read(buffer, kMaxBufferLength, &bytesRead);
        if (FAILED(mHr)) {
          break;
        }
        mResponse.insert(mResponse.end(), buffer, buffer + bytesRead);
      } while (bytesRead != 0);
    } catch (const std::exception&) {
      // Growing mResponse can throw. The exception must not escape a COM callback, and skipping
      // the SetEvent below would leave WaitForComplete() blocked until its timeout.
      mHr = E_OUTOFMEMORY;
    }
  }

  SetEvent(mCompleteEvent);
  return mHr;
}

//--------------------------------------------------------------------------
// Failure notification: stores the error for WaitForComplete() before signaling the
// completion event. The notification itself is always acknowledged with S_OK.
STDMETHODIMP XHRCallback::OnError(IXMLHTTPRequest2* /*xhr*/, HRESULT hr) {
  const HRESULT failure = hr;
  mHr = failure;
  (void)SetEvent(mCompleteEvent);
  return S_OK;
}

//--------------------------------------------------------------------------
void XHRCallback::WaitForComplete(
  uint timeoutInSecs, uint& status, std::vector<char>& response, std::map<std::string, std::string>& headers) {
  DWORD timeoutInMs = (timeoutInSecs == 0 || timeoutInSecs == INFINITE) ? INFINITE : timeoutInSecs * 1000;
  DWORD waitRet = WaitForSingleObjectEx(mCompleteEvent, timeoutInMs, FALSE);

  HRESULT hr = S_OK;
  if (waitRet == WAIT_FAILED) {
    hr = HRESULT_FROM_WIN32(GetLastError());
  } else if (waitRet != WAIT_OBJECT_0) {
    hr = E_ABORT;
  }

  if (FAILED(mHr)) {
    hr = mHr;
  }

  status = 0;
  if (SUCCEEDED(hr)) {
    status = mStatus;
    response = mResponse;
    headers = mHeaders;
  }
}

//--------------------------------------------------------------------------
// Implementation of HTTP Request functions
//--------------------------------------------------------------------------

//--------------------------------------------------------------------
// Joins the calling thread to the COM multithreaded apartment. Any success code from
// CoInitializeEx (including "already initialized") counts as success.
TTV_ErrorCode ttv::IXhr2HttpRequest::ThreadInit() {
  const HRESULT initResult = CoInitializeEx(NULL, COINIT_MULTITHREADED);
  if (FAILED(initResult)) {
    return TTV_EC_COINITIALIZE_FAIED;
  }
  return TTV_EC_SUCCESS;
}

//--------------------------------------------------------------------------
// Issues a blocking HTTP request through IXMLHTTPRequest2 and reports the result via callbacks.
//
// origUrl          - request URL (UTF-8), converted to UTF-16 for the XHR object.
// requestHeaders   - extra request headers (UTF-8 name/value pairs).
// requestBody      - body bytes; streamed for every verb, including GET and DELETE.
// httpReqType      - GET/PUT/POST/DELETE; any other value asserts and fails the request.
// timeOutInSecs    - forwarded to XHRCallback::WaitForComplete (0 waits without a limit).
// headersCallback  - optional; receives the status and response headers. Its return value
//                    decides whether responseCallback is invoked. Header parsing is enabled
//                    only when it is set.
// responseCallback - receives the status and response body (skipped if null).
//
// Returns TTV_EC_SUCCESS once the request was sent and the callbacks ran.
// Returns TTV_EC_HTTPREQUEST_ERROR if creating, opening, configuring or sending the request
// failed; in that case the XHR object (if created) is aborted.
// NOTE(review): a timed-out or failed transfer still reaches the callbacks as status 0 and
// returns TTV_EC_SUCCESS, because WaitForComplete does not surface its HRESULT.
TTV_ErrorCode ttv::IXhr2HttpRequest::SendHttpRequest(const std::string& /*requestName*/, const std::string& origUrl,
  const std::vector<HttpParam>& requestHeaders, const std::string& requestBody, HttpRequestType httpReqType,
  uint timeOutInSecs, HttpRequestHeadersCallback headersCallback, HttpRequestCallback responseCallback,
  void* userData) {
  ComPtr<IXMLHTTPRequest2> xhr;
  HRESULT hr = CoCreateInstance(CLSID_FreeThreadedXMLHTTP60, NULL, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&xhr));

  ComPtr<XHRCallback> xhrCallback;
  if (SUCCEEDED(hr)) {
    hr = MakeAndInitialize<XHRCallback>(&xhrCallback);
  }

  const wchar_t* requestType = nullptr;
  if (SUCCEEDED(hr)) {
    xhrCallback->EnableHeaderParsing(headersCallback != nullptr);

    switch (httpReqType) {
      case HTTP_GET_REQUEST:
        requestType = L"GET";
        break;
      case HTTP_PUT_REQUEST:
        requestType = L"PUT";
        break;
      case HTTP_POST_REQUEST:
        requestType = L"POST";
        break;
      case HTTP_DELETE_REQUEST:
        // The body is sent unchanged, exactly as for the other verbs.
        requestType = L"DELETE";
        break;
      default:
        TTV_ASSERT(false);
        hr = E_INVALIDARG;
        break;
    }
  }

  if (SUCCEEDED(hr)) {
    try {
      std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
      hr = xhr->Open(requestType, converter.from_bytes(origUrl).c_str(), xhrCallback.Get(), NULL, NULL, NULL, NULL);
      for (const auto& header : requestHeaders) {
        if (FAILED(hr)) {
          break;
        }
        hr = xhr->SetRequestHeader(
          converter.from_bytes(header.paramName).c_str(), converter.from_bytes(header.paramValue).c_str());
      }
    } catch (const std::range_error&) {
      // from_bytes() throws when the URL or a header is not valid UTF-8.
      hr = E_INVALIDARG;
    }
  }

  // Only send if opening and every SetRequestHeader call succeeded.
  if (SUCCEEDED(hr)) {
    ComPtr<XHRPostStream> postStream = Make<XHRPostStream>();
    if (postStream.Get() == nullptr) {
      hr = E_OUTOFMEMORY;
    } else {
      postStream->Init(requestBody);
      hr = xhr->Send(postStream.Get(), postStream->GetLength());
    }
  }

  if (SUCCEEDED(hr)) {
    std::vector<char> response;
    std::map<std::string, std::string> responseHeaders;
    uint status = 0;
    xhrCallback->WaitForComplete(timeOutInSecs, status, response, responseHeaders);

    bool notifyResponse = true;
    if (headersCallback != nullptr) {
      notifyResponse = headersCallback(status, responseHeaders, userData);
    }

    if (notifyResponse && responseCallback != nullptr) {
      responseCallback(status, response, userData);
    }
  }

  // xhr is null when CoCreateInstance itself failed.
  if (FAILED(hr) && xhr.Get() != nullptr) {
    xhr->Abort();
  }
  return SUCCEEDED(hr) ? TTV_EC_SUCCESS : TTV_EC_HTTPREQUEST_ERROR;
}

//--------------------------------------------------------------------
// Undoes this thread's COM initialization. Always reports success.
// NOTE(review): CoUninitialize should balance a successful ThreadInit() on the same thread;
// confirm callers skip this when ThreadInit failed.
TTV_ErrorCode ttv::IXhr2HttpRequest::ThreadShutdown() {
  CoUninitialize();
  const TTV_ErrorCode result = TTV_EC_SUCCESS;
  return result;
}
