/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "fixtures/corebasetest.h"
#include "testutilities.h"
#include "twitchsdk/core/task/task.h"
#include "twitchsdk/core/task/taskrunner.h"

#include <functional>

#include "gtest/gtest.h"

using namespace ttv;
using namespace ttv::test;

namespace {
class TestTask : public Task {
 public:
  using Callback = std::function<void(TestTask* source, TTV_ErrorCode ec)>;

 public:
  TestTask(TTV_ErrorCode result, Callback callback) : Task(nullptr, nullptr), mCallback(callback), mResult(result) {}

  virtual void Run() override {
    // NOOP
  }

  virtual void OnComplete() override {
    if (mCallback != nullptr) {
      if (mAborted) {
        mResult = TTV_EC_REQUEST_ABORTED;
      }

      mCallback(this, mResult);
    }
  }

 protected:
  virtual const char* GetTaskName() const override { return "TestTask"; }

 private:
  Callback mCallback;
  TTV_ErrorCode mResult;
};

/**
 * Drives a PagedRequestFetcher through a scripted sequence of page results and checks the codes it reports.
 *
 * Each time the fetcher asks for a task, the next code is removed from the front of taskResults and a TestTask
 * that completes with that code is added to taskRunner. A successful page hands the fetcher a non-empty cursor
 * ("next") only while scripted results remain. The helper fails the calling test if the completion callback has
 * not run by the time the wait returns.
 *
 * @param taskRunner Runner that receives the TestTask created for each page.
 * @param updateFunc Poll function forwarded to WaitUntilResultWithPollTask.
 * @param taskResults Result codes for the pages, consumed front to back.
 */
void TestPagedRequestFetcher(
  std::shared_ptr<TaskRunner> taskRunner, std::function<void()> updateFunc, std::vector<TTV_ErrorCode> taskResults) {
  auto fetcher = std::make_shared<PagedRequestFetcher>();

  // Result code of the most recently created task; the callbacks below are checked against it.
  TTV_ErrorCode expectedErrorCode = TTV_EC_SUCCESS;
  bool allDone = false;

  // NOTE: the callbacks capture locals of this function by reference, so they must not be invoked after it returns.

  TestTask::Callback taskCallback = [&expectedErrorCode, fetcher, &taskResults](
                                      TestTask* /*source*/, TTV_ErrorCode ec) {
    ASSERT_EQ(ec, expectedErrorCode);

    // The cursor stays empty unless the page succeeded and more scripted results are pending.
    std::string cursor;

    if (TTV_SUCCEEDED(ec)) {
      if (!taskResults.empty()) {
        cursor = "next";
      }
    }

    fetcher->FetchComplete(ec, cursor);
  };

  PagedRequestFetcher::CreateTaskCallback createCallback = [&expectedErrorCode, &taskResults, taskCallback,
                                                             &taskRunner](const std::string& /*cursor*/,
                                                             std::shared_ptr<Task>& task) -> TTV_ErrorCode {
    // Leave the out-parameter empty when the script has been exhausted.
    task.reset();

    if (taskResults.empty()) {
      return TTV_EC_SUCCESS;
    }

    // Pop the next scripted result and schedule a task that completes with it.
    expectedErrorCode = taskResults.front();
    taskResults.erase(taskResults.begin());

    std::shared_ptr<TestTask> fetchTask = std::make_shared<TestTask>(expectedErrorCode, taskCallback);
    taskRunner->AddTask(fetchTask);
    task = fetchTask;

    return TTV_EC_SUCCESS;
  };

  PagedRequestFetcher::CompleteCallback completeCallback = [&expectedErrorCode, &allDone](TTV_ErrorCode ec) {
    ASSERT_EQ(ec, expectedErrorCode);

    allDone = true;
  };

  TTV_ErrorCode ec = fetcher->Start("", createCallback, completeCallback);
  ASSERT_TRUE(TTV_SUCCEEDED(ec));

  auto checkFunc = [&allDone]() { return allDone; };

  WaitUntilResultWithPollTask(1000, checkFunc, updateFunc);

  // Without this check a fetcher that never invokes the completion callback would pass silently on timeout.
  ASSERT_TRUE(allDone);
}
}  // namespace

// Scripts a single page that succeeds.
TEST_F(CoreBaseTest, PagedRequestFetcher_SinglePage_Success) {
  auto runner = CreateTaskRunner();

  const std::vector<TTV_ErrorCode> pageResults{TTV_EC_SUCCESS};
  TestPagedRequestFetcher(runner, GetDefaultUpdateFunc(), pageResults);
}

// Scripts a single page that fails with TTV_EC_INVALID_JSON.
TEST_F(CoreBaseTest, PagedRequestFetcher_SinglePage_Fail) {
  auto runner = CreateTaskRunner();

  const std::vector<TTV_ErrorCode> pageResults{TTV_EC_INVALID_JSON};
  TestPagedRequestFetcher(runner, GetDefaultUpdateFunc(), pageResults);
}

// Scripts three consecutive pages that all succeed.
TEST_F(CoreBaseTest, PagedRequestFetcher_MultiePage_Success) {
  auto runner = CreateTaskRunner();

  const std::vector<TTV_ErrorCode> pageResults{TTV_EC_SUCCESS, TTV_EC_SUCCESS, TTV_EC_SUCCESS};
  TestPagedRequestFetcher(runner, GetDefaultUpdateFunc(), pageResults);
}

// Scripts a successful first page followed by a page that fails with TTV_EC_INVALID_JSON.
// A third (successful) result is scripted as well; presumably the fetcher stops at the failure and never asks for it.
TEST_F(CoreBaseTest, PagedRequestFetcher_MultiPage_Fail) {
  auto runner = CreateTaskRunner();

  const std::vector<TTV_ErrorCode> pageResults{TTV_EC_SUCCESS, TTV_EC_INVALID_JSON, TTV_EC_SUCCESS};
  TestPagedRequestFetcher(runner, GetDefaultUpdateFunc(), pageResults);
}

// Checks how Uri splits a string into protocol, host name, path and query parameters, including inputs
// that omit some of those parts. Equality checks use ASSERT_EQ so a failure prints both values.
TEST_F(CoreBaseTest, Uri) {
  // Protocol, host, path and two query parameters.
  {
    Uri uri("http://api.twitch.tv/hello/blah?p1=1&p2=2");

    ASSERT_EQ(uri.GetProtocol(), "http");
    ASSERT_EQ(uri.GetHostName(), "api.twitch.tv");
    ASSERT_EQ(uri.GetPath(), "/hello/blah");

    auto&& params = uri.GetParams();
    ASSERT_EQ(params.size(), 2u);
    ASSERT_TRUE(params.find("p1") != params.end());
    ASSERT_EQ(params["p1"], "1");
    ASSERT_TRUE(params.find("p2") != params.end());
    ASSERT_EQ(params["p2"], "2");
  }

  // Trailing '?' with nothing after it yields no parameters.
  {
    Uri uri("http://myhost/hello/blah?");

    ASSERT_EQ(uri.GetProtocol(), "http");
    ASSERT_EQ(uri.GetHostName(), "myhost");
    ASSERT_EQ(uri.GetPath(), "/hello/blah");
    ASSERT_EQ(uri.GetParams().size(), 0u);
  }

  // Bare host name only.
  {
    Uri uri("myhost");

    ASSERT_EQ(uri.GetProtocol(), "");
    ASSERT_EQ(uri.GetHostName(), "myhost");
    ASSERT_EQ(uri.GetPath(), "");
    ASSERT_EQ(uri.GetParams().size(), 0u);
  }

  // Host name and path without a protocol.
  {
    Uri uri("myhost/hello/blah");

    ASSERT_EQ(uri.GetProtocol(), "");
    ASSERT_EQ(uri.GetHostName(), "myhost");
    ASSERT_EQ(uri.GetPath(), "/hello/blah");
    ASSERT_EQ(uri.GetParams().size(), 0u);
  }

  // Host name followed directly by query parameters.
  {
    Uri uri("myhost?p1=1&p2=2");

    ASSERT_EQ(uri.GetProtocol(), "");
    ASSERT_EQ(uri.GetHostName(), "myhost");
    ASSERT_EQ(uri.GetPath(), "");

    auto&& params = uri.GetParams();
    ASSERT_EQ(params.size(), 2u);
    ASSERT_TRUE(params.find("p1") != params.end());
    ASSERT_EQ(params["p1"], "1");
    ASSERT_TRUE(params.find("p2") != params.end());
    ASSERT_EQ(params["p2"], "2");
  }
}
