#include "testhttprequest.h"

#include "twitchsdk/core/json/corejsonutil.h"
#include "twitchsdk/core/json/reader.h"
#include "twitchsdk/core/json/writer.h"
#include "twitchsdk/core/thread.h"
#include "twitchsdk/core/tracer.h"

#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <streambuf>

#include "gtest/gtest.h"

using namespace ttv;

// Id that the next constructed MockResponse will take; each constructor advances it by one.
int MockResponse::smNextMockRequestId = 0;

// Creates a mock for `url` with a fresh id (taken from smNextMockRequestId).
// Defaults: HTTP GET, no delay, zero request count. Query parameters are strictly compared; header, body,
// regex and JSON-value matching are all disabled until enabled through the Strict*/Set*/Add* builders.
MockResponse::MockResponse(const std::string& url)
    : mId(smNextMockRequestId),
      mRequestCount(0),
      mDelayMilliseconds(0),
      mUrl(url),
      mType(ttv::HTTP_GET_REQUEST),
      mStrictRequestParams(true),
      mStrictRequestRegex(false),
      mStrictHeaderParams(false),
      mStrictBody(false),
      mStrictJsonValues(false) {
  // Advance the shared counter so the next mock gets a distinct id.
  ++smNextMockRequestId;
}

// Replaces the URL this mock is matched against. Returns *this for chaining.
MockResponse& MockResponse::SetUrl(const std::string& url) {
  this->mUrl.SetUrl(url);
  return *this;
}

// Sets the HTTP verb a request must use to match this mock. Returns *this for chaining.
MockResponse& MockResponse::SetType(ttv::HttpRequestType type) {
  this->mType = type;
  return *this;
}

// Records the exact request body this mock expects.
// enableStrictBody == true also turns on strict body matching; false leaves the current strictness untouched
// (it never disables it). Returns *this for chaining.
MockResponse& MockResponse::SetRequestBody(const std::string& body, bool enableStrictBody) {
  if (enableStrictBody) StrictBody(true);

  this->mRequestBody = body;
  return *this;
}

// Appends a pattern that must be found (std::regex_search) in the request body.
// enableStrictRequestRegex == true also turns on regex matching; false leaves the current setting untouched.
// Returns *this for chaining.
MockResponse& MockResponse::AddRequestRegex(const std::regex& regex, bool enableStrictRequestRegex) {
  if (enableStrictRequestRegex) StrictRequestRegex(true);

  this->mRequestRegex.push_back(regex);
  return *this;
}

// Adds (or overwrites) a numeric query parameter on the expected URL. Returns *this for chaining.
MockResponse& MockResponse::AddRequestParam(const std::string& name, uint32_t value) {
  this->mUrl.SetParam(name, value);
  return *this;
}

// Adds (or overwrites) a string query parameter on the expected URL. Returns *this for chaining.
MockResponse& MockResponse::AddRequestParam(const std::string& name, const std::string& value) {
  this->mUrl.SetParam(name, value);
  return *this;
}

// Registers a header name/value on this mock, overwriting any earlier value for the same name.
// It is only consulted when strict header matching is enabled. Returns *this for chaining.
MockResponse& MockResponse::AddRequestHeader(const std::string& name, const std::string& value) {
  this->mHeaderParams[name] = value;
  return *this;
}

// Requires the JSON request body to contain `value` at the path `key` (a list of path components).
// enableStrictJsonValue == true also turns on JSON-value matching; false leaves the current setting untouched.
// Registering the same path again replaces the expected value. Returns *this for chaining.
MockResponse& MockResponse::AddJsonValue(
  const std::vector<std::string>& key, const ttv::json::Value& value, bool enableStrictJsonValue) {
  if (enableStrictJsonValue) StrictJsonValue(true);

  this->mJsonValues[key] = value;
  return *this;
}

// Toggles full URL equality (including query parameters) when matching. Returns *this for chaining.
MockResponse& MockResponse::StrictRequestParams(bool strict) {
  this->mStrictRequestParams = strict;
  return *this;
}

// Toggles header checking during matching. Returns *this for chaining.
MockResponse& MockResponse::StrictHeaderParams(bool strict) {
  this->mStrictHeaderParams = strict;
  return *this;
}

// Toggles exact request-body comparison during matching. Returns *this for chaining.
MockResponse& MockResponse::StrictBody(bool strict) {
  this->mStrictBody = strict;
  return *this;
}

// Toggles request-body regex checks during matching. Returns *this for chaining.
MockResponse& MockResponse::StrictRequestRegex(bool strict) {
  this->mStrictRequestRegex = strict;
  return *this;
}

// Toggles JSON path/value checks on the parsed request body during matching. Returns *this for chaining.
MockResponse& MockResponse::StrictJsonValue(bool strict) {
  this->mStrictJsonValues = strict;
  return *this;
}

// Sets the HTTP status code delivered to the headers and response callbacks. Returns *this for chaining.
MockResponse& MockResponse::SetStatusCode(uint statusCode) {
  this->mResponseContent.statusCode = statusCode;
  return *this;
}

// Installs a handler that rebuilds the response content from each incoming request
// (see MockResponse::SendHttpRequest). Returns *this for chaining.
MockResponse& MockResponse::SetResponseHandler(const MockResponseHandler& handler) {
  this->mResponseHandler = handler;
  return *this;
}

// Serializes `data` with ttv::json::FastWriter and uses the resulting text as the response body.
// Returns *this for chaining.
MockResponse& MockResponse::SetResponseBody(const ttv::json::Value& data) {
  const auto serialized = ttv::json::FastWriter().write(data);
  mResponseContent.body.assign(serialized.cbegin(), serialized.cend());
  return *this;
}

// Uses the NUL-terminated string `data` (terminator excluded) as the response body.
// A null pointer yields an empty body. Returns *this for chaining.
MockResponse& MockResponse::SetResponseBody(const char* data) {
  if (data == nullptr) {
    mResponseContent.body.clear();
    return *this;
  }

  const char* const end = data + strlen(data);
  mResponseContent.body.assign(data, end);
  return *this;
}

// Copies the characters of `data` into the response body. Returns *this for chaining.
MockResponse& MockResponse::SetResponseBody(const std::string& data) {
  mResponseContent.body.assign(data.cbegin(), data.cend());
  return *this;
}

// Uses a copy of the raw bytes in `data` as the response body. Returns *this for chaining.
MockResponse& MockResponse::SetResponseBody(const std::vector<char>& data) {
  this->mResponseContent.body = data;
  return *this;
}

// Loads the response body from the file at `path`, byte for byte.
// If the file cannot be opened, a message is printed, a non-fatal gtest failure is recorded and the body
// becomes empty. Returns *this for chaining.
MockResponse& MockResponse::SetResponseBodyFromFile(const std::string& path) {
  // Binary mode so the fixture reaches the body unchanged. Text mode would translate CRLF line endings
  // (and stop at a 0x1A byte) on Windows, so the mocked body would differ from the file.
  std::ifstream response(path, std::ios::in | std::ios::binary);

  const bool opened = response.is_open();
  if (!opened) {
    std::cout << "TestHttpRequest: Could not open file \"" << path << "\"" << std::endl;

    EXPECT_TRUE(opened);
  }

  // Read straight into the body; a stream that failed to open yields no characters, leaving the body empty.
  mResponseContent.body.assign(std::istreambuf_iterator<char>(response), std::istreambuf_iterator<char>());

  return *this;
}

// Sets (or overwrites) a header delivered to the headers callback. Returns *this for chaining.
MockResponse& MockResponse::AddResponseHeader(const std::string& name, const std::string& value) {
  this->mResponseContent.headers[name] = value;
  return *this;
}

// Stores the delay, in milliseconds, applied before this mock responds (0 = no delay). Returns *this for chaining.
MockResponse& MockResponse::SetDelay(uint64_t milliseconds) {
  this->mDelayMilliseconds = milliseconds;
  return *this;
}

std::shared_ptr<MockResponse> MockResponse::Done() {
  return shared_from_this();
}

bool MockResponse::IsMatch(const std::string& requestUrl, ttv::HttpRequestType type,
  const std::vector<ttv::HttpParam>& headerParams, const std::string& requestBody,
  const ttv::json::Value& jroot) const {
  ParamMap params;

  for (const auto& header : headerParams) {
    params[header.paramName] = header.paramValue;
  }

  return IsMatch(requestUrl, type, params, requestBody, jroot);
}

bool MockResponse::IsMatch(const std::string& requestUrl, ttv::HttpRequestType type, const ParamMap& headerParams,
  const std::string& requestBody, const ttv::json::Value& jroot) const {
  Uri url(requestUrl);

  if (mUrl.GetProtocol() != url.GetProtocol() || mUrl.GetHostName() != url.GetHostName() ||
      mUrl.GetPath() != url.GetPath() || mType != type) {
    return false;
  }

  if (mStrictRequestParams) {
    if (mUrl != url) {
      return false;
    }
  }

  // NOTE: We don't look for the exact set of headers, we just check to make sure this response is at least a superset
  // of the provided values
  if (mStrictHeaderParams) {
    for (const auto& header : headerParams) {
      auto piter = mHeaderParams.find(header.first);
      if (piter == mHeaderParams.end() || piter->second != header.second) {
        return false;
      }
    }
  }

  if (mStrictBody) {
    if (mRequestBody != requestBody) {
      return false;
    }
  }

  if (mStrictJsonValues) {
    for (const auto& kvp : mJsonValues) {
      auto lookup = FindValueByPath(jroot, kvp.first);

      if (!lookup.HasValue()) {
        return false;
      }

      if (lookup.Value() != kvp.second) {
        return false;
      }
    }
  }

  if (mStrictRequestRegex) {
    for (const auto& regex : mRequestRegex) {
      if (!std::regex_search(requestBody, regex)) {
        return false;
      }
    }
  }

  return true;
}

// Serves one request from this mock.
// Sequence:
//  1. Bump the request counter.
//  2. Sleep while GetDelayMilliseconds() is non-zero.
//  3. If a response handler is installed, let it regenerate the response content from the request.
//  4. Deliver status and headers; the body follows only if the headers callback returns true.
// Always returns TTV_EC_SUCCESS.
TTV_ErrorCode MockResponse::SendHttpRequest(const std::string& /*requestName*/, const std::string& url,
  const std::vector<ttv::HttpParam>& headerParams, const uint8_t* requestBody, size_t requestBodySize,
  ttv::HttpRequestType httpReqType, uint /*timeOutInSecs*/, ttv::HttpRequestHeadersCallback headersCallback,
  ttv::HttpRequestCallback responseCallback, void* userData) {
  ++mRequestCount;

  if (GetDelayMilliseconds() > 0) {
    ttv::Sleep(GetDelayMilliseconds());
  }

  if (mResponseHandler != nullptr) {
    const std::string requestText(reinterpret_cast<const char*>(requestBody), requestBodySize);
    mResponseContent = mResponseHandler(url, httpReqType, headerParams, requestText);
  }

  const bool proceed = headersCallback(mResponseContent.statusCode, mResponseContent.headers, userData);
  if (proceed) {
    responseCallback(mResponseContent.statusCode, mResponseContent.body, userData);
  }

  return TTV_EC_SUCCESS;
}

void MockResponse::ResetRequestCount() {
  mRequestCount = 0;
}

// gtest assertion: fails if this mock has not served any request since construction or the last
// ResetRequestCount().
// NOTE: ASSERT_* in a void helper returns from this helper only. The calling test keeps running after a
// failure unless it checks HasFatalFailure().
void MockResponse::AssertRequestsMade() {
  ASSERT_NE(mRequestCount, 0);
}

// gtest assertion: fails if this mock has served any request since construction or the last
// ResetRequestCount().
// NOTE: ASSERT_* in a void helper returns from this helper only. The calling test keeps running after a
// failure unless it checks HasFatalFailure().
void MockResponse::AssertNoRequestsMade() {
  ASSERT_EQ(mRequestCount, 0);
}

// Starts online, with unmatched-request validation off, and creates the mutex that guards mHttpResponses.
TestHttpRequest::TestHttpRequest() : mValidate(false) {
  mOffline = false;
  // The creation result is intentionally discarded.
  // NOTE(review): if creation can fail, later AutoMutex uses would receive a null mutex; consider checking it.
  static_cast<void>(CreateMutex(mResponseMutex, "TestHttpRequest"));
}

// No explicit cleanup; member destructors run as usual.
TestHttpRequest::~TestHttpRequest() = default;

// Creates a mock for `url`, registers it under its id, and returns it for builder-style configuration.
// The returned reference stays valid while the mock remains registered (this map holds a shared_ptr to it).
MockResponse& TestHttpRequest::AddResponse(const std::string& url) {
  ttv::trace::Message("TestHttpRequest", MessageLevel::Debug, "Registered response: %s", url.c_str());

  const auto created = std::make_shared<MockResponse>(url);
  const int id = created->GetId();

  ttv::AutoMutex lock(mResponseMutex.get());
  mHttpResponses[id] = created;
  return *created;
}

void TestHttpRequest::RemoveResponse(std::shared_ptr<MockResponse> mock) {
  ttv::AutoMutex lock(mResponseMutex.get());
  auto iter = mHttpResponses.find(mock->GetId());

  if (iter != mHttpResponses.end()) {
    mHttpResponses.erase(iter);
  }
}

// Returns an iterator to the registered mock that best matches the request, or mHttpResponses.end() if none
// matches. Takes no lock itself; the visible caller (SendHttpRequest) holds mResponseMutex around this call.
std::map<int, std::shared_ptr<MockResponse>>::iterator TestHttpRequest::FindResponse(const std::string& url,
  ttv::HttpRequestType type, const std::vector<ttv::HttpParam>& headerParams, const std::string& requestBody) {
  // Candidates are tried in priority order so that a mock with fewer restrictions never shadows a stricter one.
  // Example:
  //   mock1.SetUrl("gql.twitch.tv/gql")
  //   mock2.SetUrl("gql.twitch.tv/gql").SetRequestBody("...")
  // mock2 must be tried first because it is stricter than mock1.
  std::vector<std::shared_ptr<MockResponse>> prioritySorted;
  prioritySorted.reserve(mHttpResponses.size());
  for (const auto& kvp : mHttpResponses) {
    prioritySorted.push_back(kvp.second);
  }

  // std::stable_sort, not std::sort: among equally strict mocks keep the map's ascending-id order (the order
  // they were constructed). Otherwise which overlapping mock wins would be implementation-defined.
  std::stable_sort(prioritySorted.begin(), prioritySorted.end(),
    [](const std::shared_ptr<MockResponse>& a, const std::shared_ptr<MockResponse>& b) {
      // Lexicographic "stricter first" ordering: body, JSON values, regex, headers, query params.
      if (a->IsStrictBody() != b->IsStrictBody()) {
        return a->IsStrictBody();
      }

      if (a->IsStrictJsonValue() != b->IsStrictJsonValue()) {
        return a->IsStrictJsonValue();
      }

      if (a->IsStrictRequestRegex() != b->IsStrictRequestRegex()) {
        return a->IsStrictRequestRegex();
      }

      if (a->IsStrictHeaderParams() != b->IsStrictHeaderParams()) {
        return a->IsStrictHeaderParams();
      }

      if (a->IsStrictRequestParams() != b->IsStrictRequestParams()) {
        return a->IsStrictRequestParams();
      }

      return false;
    });

  // Parse the body once for every candidate. A non-JSON (or empty) body becomes a default json::Value.
  ttv::json::Value jsonVal;
  if (!ttv::ParseDocument(requestBody, jsonVal)) {
    jsonVal = ttv::json::Value();
  }

  auto find_itr = std::find_if(prioritySorted.begin(), prioritySorted.end(),
    [&url, &type, &requestBody, &headerParams, &jsonVal](const std::shared_ptr<MockResponse>& response) {
      return response->IsMatch(url, type, headerParams, requestBody, jsonVal);
    });

  if (find_itr == prioritySorted.end()) {
    return mHttpResponses.end();
  }

  return mHttpResponses.find((*find_itr)->GetId());
}

// Routes a request to the best-matching registered mock.
//  - Offline: returns TTV_EC_API_REQUEST_FAILED immediately and invokes no callbacks.
//  - Match found: the mock delivers its configured (or handler-generated) response, and its result is returned.
//  - No match: replies 404 with no headers and an empty body. When mValidate is set, a non-fatal gtest failure
//    naming the URL is also recorded.
TTV_ErrorCode TestHttpRequest::SendHttpRequest(const std::string& requestName, const std::string& url,
  const std::vector<ttv::HttpParam>& headerParams, const uint8_t* requestBody, size_t requestBodySize,
  ttv::HttpRequestType httpReqType, uint timeOutInSecs, ttv::HttpRequestHeadersCallback headersCallback,
  ttv::HttpRequestCallback responseCallback, void* userData) {
  if (mOffline) {
    return TTV_EC_API_REQUEST_FAILED;
  }

  const std::string body(reinterpret_cast<const char*>(requestBody), requestBodySize);
  std::shared_ptr<MockResponse> response;
  {
    // The lock covers only the lookup. The matched mock (which may sleep for its configured delay) is invoked
    // after the lock is released.
    ttv::AutoMutex lock(mResponseMutex.get());

    auto iter = FindResponse(url, httpReqType, headerParams, body);
    if (iter != mHttpResponses.end()) {
      response = iter->second;
    }
  }

  if (response != nullptr) {
    ttv::trace::Message("TestHttpRequest", MessageLevel::Debug, "Url requested: %s", url.c_str());

    // Propagate the mock's result instead of discarding it.
    return response->SendHttpRequest(requestName, url, headerParams, requestBody, requestBodySize, httpReqType,
      timeOutInSecs, headersCallback, responseCallback, userData);
  }

  ttv::trace::Message("TestHttpRequest", MessageLevel::Debug, "REQUESTED URL NOT FOUND: %s", url.c_str());

  if (mValidate) {
    // Attach the URL so the failure report says which request had no registered mock.
    EXPECT_FALSE(true) << "TestHttpRequest: no mock response registered for " << url;
  }

  ParamMap headers;
  if (headersCallback(404, headers, userData)) {
    // Distinct name so it does not shadow the request `body` string above.
    std::vector<char> emptyBody;
    responseCallback(404, emptyBody, userData);
  }

  return TTV_EC_SUCCESS;
}
