/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "testutilities.h"
#include "twitchsdk/core/mutex.h"

using namespace ttv;
using namespace ttv::test;

namespace {
// Empty GoogleTest fixture used to group the mutex tests in this file. It adds
// no state, set-up or tear-down of its own on top of ::testing::Test.
class MutexTest : public ::testing::Test {};
}  // namespace

// Tests IMutex: many threads increment a shared, non-atomic counter while
// holding the mutex. If the mutex provides mutual exclusion no increment is
// lost, so the final count must equal numThreads * tasksPerThread.
TEST_F(MutexTest, MultipleThreads) {
  std::unique_ptr<IMutex> mutex;
  TTV_ErrorCode ec = ttv::CreateMutex(mutex, "MutexTest");
  ASSERT_TRUE(TTV_SUCCEEDED(ec));
  ASSERT_NE(mutex, nullptr);

  // Shared counter; workers only touch it while holding `mutex`.
  int tasksRun = 0;

  int numThreads = 100;
  int tasksPerThread = 10000;

  std::vector<std::shared_ptr<IThread>> threads;
  for (int threadIndex = 0; threadIndex < numThreads; threadIndex++) {
    auto incrementFunc = [&mutex, &tasksRun, tasksPerThread]() {
      for (int task = 0; task < tasksPerThread; task++) {
        // The lock is scoped to a single increment so that the threads
        // contend for the mutex on every iteration.
        AutoMutex lock(mutex.get());
        tasksRun++;
      }
    };

    std::shared_ptr<IThread> thread;
    CreateThread(incrementFunc, "thread " + std::to_string(threadIndex), thread);
    // Fail here rather than dereferencing a null pointer at Run() below.
    // Run() has not been called on any thread yet at this point.
    ASSERT_NE(thread, nullptr);
    threads.push_back(thread);
  }

  // Checked before any thread is started: a fatal assertion after Run() would
  // return from the test while workers still reference the locals above.
  ASSERT_EQ(static_cast<int>(threads.size()), numThreads);

  for (const auto& thread : threads) {
    thread->Run();
  }

  // Wait for every worker to finish, then release the thread objects.
  for (const auto& thread : threads) {
    thread->Join();
  }
  threads.clear();

  // All workers have been joined, so reading the counter without the lock is
  // safe here.
  ASSERT_EQ(tasksRun, numThreads * tasksPerThread);
}

// Tests IConditionMutex: every worker locks the mutex and blocks in Wait().
// The test then expects that no worker proceeds on its own, that Signal()
// lets exactly one worker proceed, and that Broadcast() lets all remaining
// workers proceed.
//
// NOTE: the test is timing based. The Sleep() calls are what give the workers
// time to reach Wait() and to react to a wake-up; there is no predicate around
// Wait(), so a worker that has not reached Wait() before the wake-up, or that
// is woken spuriously, would change the observed counts.
TEST_F(MutexTest, ConditionMutex) {
  std::unique_ptr<IConditionMutex> mutex;
  TTV_ErrorCode ec = ttv::CreateConditionMutex(mutex, "MutexTest");
  ASSERT_TRUE(TTV_SUCCEEDED(ec));
  ASSERT_NE(mutex, nullptr);

  // Number of workers that have returned from Wait(). Workers increment it
  // while holding `mutex`.
  int tasksRun = 0;

  // Reads the counter under the mutex so that the main thread does not race
  // with workers that are incrementing it.
  auto readTasksRun = [&mutex, &tasksRun]() {
    AutoMutex lock(mutex.get());
    return tasksRun;
  };

  int numThreads = 100;

  std::vector<std::shared_ptr<IThread>> threads;
  for (int threadIndex = 0; threadIndex < numThreads; threadIndex++) {
    auto waitFunc = [&mutex, &tasksRun]() {
      AutoMutex lock(mutex.get());
      mutex->Wait();

      tasksRun++;
    };

    std::shared_ptr<IThread> thread;
    CreateThread(waitFunc, "thread " + std::to_string(threadIndex), thread);
    // Fail here rather than dereferencing a null pointer at Run() below.
    // Run() has not been called on any thread yet at this point.
    ASSERT_NE(thread, nullptr);
    threads.push_back(thread);
  }

  // Checked before any thread is started: a fatal assertion after Run() would
  // return from the test while workers still reference the locals above.
  ASSERT_EQ(static_cast<int>(threads.size()), numThreads);

  for (const auto& thread : threads) {
    thread->Run();
  }

  // Give every worker time to reach Wait().
  ttv::Sleep(200);

  // The checks made while workers may still be blocked are non-fatal
  // (EXPECT_*): a fatal failure would leave the test body with workers still
  // waiting on `mutex`, which is destroyed on return. Continuing lets the
  // Broadcast() below release them so they can be joined.
  EXPECT_EQ(readTasksRun(), 0);

  mutex->Signal();
  ttv::Sleep(100);
  EXPECT_EQ(readTasksRun(), 1);

  mutex->Broadcast();
  ttv::Sleep(200);
  EXPECT_EQ(readTasksRun(), numThreads);

  // Wait for every worker to finish, then release the thread objects.
  for (const auto& thread : threads) {
    thread->Join();
  }
  threads.clear();

  // All workers have been joined, so reading the counter without the lock is
  // safe here.
  ASSERT_EQ(tasksRun, numThreads);
}
