#include "testutilities.h"

#include "twitchsdk/core/thread.h"
#include "twitchsdk/core/timer.h"

#include <memory>

#include "gtest/gtest.h"

namespace {
// Poll interval used by the WaitUntilResultWithPollTask helpers in this file.
// NOTE(review): presumably milliseconds (the unit of ttv::Sleep is not visible here) — confirm.
constexpr uint64_t kSleepTime = 100;
}  // namespace

using namespace ttv;

bool ttv::test::WaitUntilResultWithPollTask(
  uint waitUntilMilliseconds, std::function<bool()> checkForResults, std::function<void(uint64_t)> pollFunction) {
  WaitForExpiry waitUntilTimer;
  waitUntilTimer.Set(waitUntilMilliseconds);

  bool gotResult = false;
  for (;;) {
    pollFunction(kSleepTime);
    gotResult = checkForResults();

    if (gotResult || waitUntilTimer.Check(true)) {
      break;
    }

    ttv::Sleep(kSleepTime);
  }

  waitUntilTimer.Clear();
  return gotResult;
}

/**
 * Overload of WaitUntilResultWithPollTask for poll callbacks that do not need the poll interval.
 *
 * This overload delegates to the std::function<void(uint64_t)> overload. Both variants
 * therefore share one polling loop and one poll interval (kSleepTime), so they cannot
 * drift apart.
 *
 * @param waitUntilMilliseconds  deadline handed to WaitForExpiry::Set.
 * @param checkForResults        predicate evaluated after every poll.
 * @param pollFunction           pump callback invoked once per iteration.
 * @return the last value returned by checkForResults.
 */
bool ttv::test::WaitUntilResultWithPollTask(
  uint waitUntilMilliseconds, std::function<bool()> checkForResults, std::function<void()> pollFunction) {
  // pollFunction only needs to live for the duration of this synchronous call,
  // so capturing it by reference is safe.
  std::function<void(uint64_t)> pollIgnoringInterval = [&pollFunction](uint64_t /*interval*/) { pollFunction(); };

  return ttv::test::WaitUntilResultWithPollTask(waitUntilMilliseconds, checkForResults, pollIgnoringInterval);
}

bool ttv::test::WaitUntilResultWithPollTask(const std::shared_ptr<TestEventScheduler>& eventScheduler,
  uint waitUntilMilliseconds, std::function<bool()> checkForResults, std::function<void()> pollFunction) {
  WaitForExpiry waitUntilTimer;
  waitUntilTimer.Set(waitUntilMilliseconds);

  bool gotResult = false;
  for (;;) {
    pollFunction();
    gotResult = checkForResults();

    if (gotResult || waitUntilTimer.Check(true)) {
      break;
    }

    if (eventScheduler != nullptr) {
      eventScheduler->WaitForEventWithTimeout(100);
    }
  }

  waitUntilTimer.Clear();
  return gotResult;
}

/**
 * Initializes an uninitialized module and pumps Update() until it reports Initialized.
 *
 * Asserts that:
 *   - the module starts Uninitialized;
 *   - Initialize() succeeds;
 *   - the completion callback reports success;
 *   - within 1000 (ms) the callback has fired and the state is Initialized.
 */
void ttv::test::InitializeModule(const std::shared_ptr<ttv::IModule>& module) {
  ASSERT_EQ(module->GetState(), IModule::State::Uninitialized);

  // Shared (not a stack local captured by reference): if the wait below times out and this
  // function returns, the module may still invoke the callback later, e.g. from a subsequent
  // Update() during teardown. Holding the flag in a shared_ptr keeps that write valid.
  auto callbackReceived = std::make_shared<bool>(false);
  TTV_ErrorCode ec = module->Initialize([callbackReceived](TTV_ErrorCode callbackEc) {
    ASSERT_TRUE(TTV_SUCCEEDED(callbackEc));
    *callbackReceived = true;
  });
  ASSERT_TRUE(TTV_SUCCEEDED(ec));

  std::function<void()> updateFunction = [module]() { module->Update(); };

  std::function<bool()> checkFunction = [module, callbackReceived]() {
    return *callbackReceived && (module->GetState() == IModule::State::Initialized);
  };

  ASSERT_TRUE(ttv::test::WaitUntilResultWithPollTask(1000, checkFunction, updateFunction));
}

/**
 * Shuts down the given modules one at a time, in the order supplied.
 *
 * Modules already in the Uninitialized state are skipped. For each other module, all
 * modules not yet processed (including the current one) are pumped with Update(). On
 * every poll where the current module reports Initialized, Shutdown() is requested.
 * The helper then waits up to 5000 (ms) for the shutdown callback and the Uninitialized
 * state. Failures are reported with EXPECT_*, so the remaining modules are still
 * processed.
 */
void ttv::test::ShutdownModules(const std::vector<std::shared_ptr<ttv::IModule>>& modules) {
  // Work on a copy so modules can be dropped from the update set once processed.
  auto remaining = modules;

  // Shutdown modules in the order they were given
  while (!remaining.empty()) {
    std::shared_ptr<IModule> module = remaining.front();

    if (module->GetState() != IModule::State::Uninitialized) {
      // Snapshot of the modules still pending, captured by value for the poll loop.
      std::function<void()> updateFunction = [remaining]() {
        for (const auto& m : remaining) {
          m->Update();
        }
      };

      // Shared (not a stack local captured by reference): if the wait times out, the module
      // may still invoke the shutdown callback after this iteration's scope has ended.
      auto callbackReceived = std::make_shared<bool>(false);
      std::function<bool()> checkForShutdown = [module, callbackReceived]() {
        if (module->GetState() == IModule::State::Initialized) {
          module->Shutdown([callbackReceived](TTV_ErrorCode ec) {
            EXPECT_TRUE(TTV_SUCCEEDED(ec));
            *callbackReceived = true;
          });
        }

        return *callbackReceived && (module->GetState() == IModule::State::Uninitialized);
      };

      EXPECT_TRUE(ttv::test::WaitUntilResultWithPollTask(5000, checkForShutdown, updateFunction));
      EXPECT_EQ(module->GetState(), IModule::State::Uninitialized);
    }

    remaining.erase(remaining.begin());
  }
}

void ttv::test::InitializeComponent(const std::shared_ptr<Component>& instance) {
  ttv::test::InitializeComponent(instance, 1000, TTV_EC_SUCCESS);
}

void ttv::test::InitializeComponent(
  const std::shared_ptr<Component>& instance, uint timeout, TTV_ErrorCode expectedError) {
  auto taskRunner = instance->GetTaskRunner();
  ASSERT_NE(taskRunner, nullptr);

  TTV_ErrorCode ec = instance->Initialize();
  ASSERT_EQ(ec, expectedError);

  if (TTV_FAILED(ec)) {
    return;
  }

  std::function<void()> updateFunc = [instance, taskRunner]() {
    taskRunner->PollTasks();
    instance->Update();
  };

  std::function<bool()> checkFunc = [instance]() { return instance->GetState() == Component::State::Initialized; };

  ASSERT_TRUE(ttv::test::WaitUntilResultWithPollTask(timeout, checkFunc, updateFunc));
}

void ttv::test::ShutdownComponent(const std::shared_ptr<Component>& instance) {
  ttv::test::ShutdownComponent(instance, 1000);
}

void ttv::test::ShutdownComponent(const std::shared_ptr<Component>& instance, uint timeout) {
  if (instance != nullptr) {
    auto taskRunner = instance->GetTaskRunner();

    TTV_ErrorCode ec = instance->Shutdown();
    if (TTV_FAILED(ec)) {
      ASSERT_TRUE(TTV_SUCCEEDED(ec));
    }

    std::function<void()> updateFunc = [instance, taskRunner]() {
      taskRunner->PollTasks();
      instance->Update();
    };

    std::function<bool()> checkFunc = [instance]() { return instance->GetState() == Component::State::Uninitialized; };

    bool finished = ttv::test::WaitUntilResultWithPollTask(timeout, checkFunc, updateFunc);
    if (!finished) {
      ASSERT_TRUE(finished);
    }
  }
}

void ttv::test::ShutdownTaskRunner(const std::shared_ptr<TaskRunner>& taskRunner) {
  if (taskRunner != nullptr) {
    std::function<void()> updateFunc = [taskRunner]() { taskRunner->Shutdown(); };

    ASSERT_TRUE(
      ttv::test::WaitUntilResultWithPollTask(5000, [taskRunner]() { return taskRunner->IsShutdown(); }, updateFunc));

    taskRunner->CompleteShutdown();
  }
}

/**
 * Asserts that two C strings are "equal" according to `copy`.
 *
 * When `copy` is true and both pointers are non-null, the contents are compared with
 * strcmp. In every other case (either pointer null, or `copy` false) the pointers
 * themselves must be identical.
 */
void ttv::test::EnsureEqual(const char* a, const char* b, bool copy) {
  const bool compareContents = copy && a != nullptr && b != nullptr;

  if (compareContents) {
    ASSERT_EQ(strcmp(a, b), 0);
  } else {
    ASSERT_EQ(a, b);
  }
}
