/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/
#include "fixtures/sdkbasetest.h"
#include "testsystemclock.h"
#include "testutilities.h"
#include "twitchsdk/core/cache.h"
#include "twitchsdk/core/lrucache.h"
#include "twitchsdk/core/thread.h"

using namespace ttv;

namespace {
// Failure message streamed into assertions that expect a call on an absent key to return false.
// constexpr makes the pointer itself immutable (previously only the pointee was const).
constexpr const char* EXPECTED_FALSE_FOR_KEY_NOT_FOUND = "should return false if key not found";

// Test fixture for the Cache / LruCache tests. It adds nothing of its own; members used by the tests
// (e.g. mTestSystemClock) come from SdkBaseTest.
class CacheTest : public ttv::test::SdkBaseTest {};
}  // namespace

TEST_F(CacheTest, AddRemoveClear) {
  Cache<int, int> cache;
  int result = 0;

  // Add: a stored entry is reported present and its value can be read back.
  cache.SetEntry(1, 1);
  EXPECT_TRUE(cache.ContainsEntry(1));
  EXPECT_TRUE(cache.GetEntry(1, result));
  EXPECT_EQ(result, 1);

  // Lookup of a key that was never stored fails.
  EXPECT_FALSE(cache.GetEntry(2, result));

  // Remove: the entry is present before removal and gone afterwards.
  cache.SetEntry(2, 2);
  EXPECT_TRUE(cache.ContainsEntry(2));
  cache.RemoveEntry(2);
  EXPECT_FALSE(cache.ContainsEntry(2));

  // Clear: no entries remain. EXPECT_EQ reports the actual size on failure.
  cache.Clear();
  EXPECT_EQ(cache.GetSize(), 0);
}

TEST_F(CacheTest, ExpireAndPurgeExpired) {
  Cache<int, int> cache;

  // Expiring one entry lets PurgeExpired remove it. The setup call is asserted so that a failure to
  // expire is reported here rather than as a confusing ContainsEntry failure below.
  cache.SetEntry(1, 1);
  EXPECT_TRUE(cache.ExpireEntry(1));
  EXPECT_FALSE(cache.ExpireEntry(2)) << EXPECTED_FALSE_FOR_KEY_NOT_FOUND;
  cache.PurgeExpired();
  EXPECT_FALSE(cache.ContainsEntry(1));

  // ExpireAll followed by PurgeExpired removes every entry that was in the cache.
  cache.SetEntry(1, 1);
  cache.SetEntry(2, 2);
  cache.ExpireAll();
  cache.PurgeExpired();
  EXPECT_FALSE(cache.ContainsEntry(1));
  EXPECT_FALSE(cache.ContainsEntry(2));
}

TEST_F(CacheTest, SetExpireMarkNeverExpireAndPurge) {
  uint64_t expiryAgeMillis = 200;
  Cache<int, int> cache;
  cache.SetExpiryAge(expiryAgeMillis);

  // Entry 1 is marked never-expiring: repeated PurgeExpired calls must keep it.
  // Setup results are asserted so a failed setup is reported at its source.
  cache.SetEntry(1, 1);
  EXPECT_TRUE(cache.MarkEntryNeverExpires(1));
  EXPECT_FALSE(cache.MarkEntryNeverExpires(999)) << EXPECTED_FALSE_FOR_KEY_NOT_FOUND;

  // Entry 2 gets an explicit expiry time expiryAgeMillis from now: PurgeExpired must eventually drop it.
  cache.SetEntry(2, 2);
  EXPECT_TRUE(cache.SetEntryExpiryTime(2, static_cast<Timestamp>(GetSystemTimeMilliseconds() + expiryAgeMillis)));
  EXPECT_FALSE(cache.SetEntryExpiryTime(999, 0)) << EXPECTED_FALSE_FOR_KEY_NOT_FOUND;

  std::function<void()> updateFunc = [&cache]() { cache.PurgeExpired(); };

  std::function<bool()> checkNotContainsNeverExpiredKeyFunc = [&cache]() { return !cache.ContainsEntry(1); };

  std::function<bool()> checkNotContainsManuallyExpiredKeyFunc = [&cache]() { return !cache.ContainsEntry(2); };

  EXPECT_FALSE(ttv::test::WaitUntilResultWithPollTask(
    static_cast<uint>(expiryAgeMillis), checkNotContainsNeverExpiredKeyFunc, updateFunc));
  EXPECT_TRUE(ttv::test::WaitUntilResultWithPollTask(
    static_cast<uint>(expiryAgeMillis), checkNotContainsManuallyExpiredKeyFunc, updateFunc));
}

TEST_F(CacheTest, PurgeUnused) {
  // WaitUntilResultWithPollTask polls every 100ms.
  uint64_t pollIntervalMillis = 100;
  uint64_t purgeTimeoutMillis = 200;

  Cache<int, int> cache;
  cache.SetEntry(1, 1);

  std::function<bool()> entryIsGone = [&cache]() { return !cache.ContainsEntry(1); };
  std::function<void()> purgeStale = [&cache, purgeTimeoutMillis]() { cache.PurgeUnused(purgeTimeoutMillis); };

  // Waiting one poll interval less than the purge timeout: the entry must still be there.
  uint shortWaitMillis = static_cast<uint>(purgeTimeoutMillis - pollIntervalMillis);
  EXPECT_FALSE(ttv::test::WaitUntilResultWithPollTask(shortWaitMillis, entryIsGone, purgeStale));

  // Waiting a further full timeout: the never-used entry must have been purged.
  EXPECT_TRUE(ttv::test::WaitUntilResultWithPollTask(static_cast<uint>(purgeTimeoutMillis), entryIsGone, purgeStale));
}

TEST_F(CacheTest, MarkUsedAndPurge) {
  uint64_t purgeTimeoutMillis = 200;

  Cache<int, int> cache;
  cache.SetEntry(1, 1);
  EXPECT_FALSE(cache.MarkEntryUsed(2)) << EXPECTED_FALSE_FOR_KEY_NOT_FOUND;

  // Each poll marks entry 1 as used immediately before purging, so PurgeUnused must never evict it.
  std::function<void()> touchThenPurge = [&cache, purgeTimeoutMillis]() {
    cache.MarkEntryUsed(1);
    cache.PurgeUnused(purgeTimeoutMillis);
  };
  std::function<bool()> entryWasPurged = [&cache]() { return !cache.ContainsEntry(1); };

  const bool purged =
    ttv::test::WaitUntilResultWithPollTask(static_cast<uint>(purgeTimeoutMillis), entryWasPurged, touchThenPurge);
  EXPECT_FALSE(purged);
}

TEST_F(CacheTest, MarkNeverUnusedAndPurge) {
  uint64_t purgeTimeoutMillis = 200;
  Cache<int, int> cache;
  cache.SetEntry(1, 1);
  // Assert the setup call: if marking failed, the wait below would fail for an unrelated-looking reason.
  EXPECT_TRUE(cache.MarkEntryNeverUnused(1));
  EXPECT_FALSE(cache.MarkEntryNeverUnused(2)) << EXPECTED_FALSE_FOR_KEY_NOT_FOUND;

  // Entry 1 is never marked used again, yet it is exempt from PurgeUnused, so it must survive every poll.
  std::function<void()> updateFunc = [&cache, purgeTimeoutMillis]() { cache.PurgeUnused(purgeTimeoutMillis); };

  std::function<bool()> checkNotContainsFunc = [&cache]() { return !cache.ContainsEntry(1); };

  EXPECT_FALSE(
    ttv::test::WaitUntilResultWithPollTask(static_cast<uint>(purgeTimeoutMillis), checkNotContainsFunc, updateFunc));
}

TEST_F(CacheTest, ForEachVisitorFunc) {
  Cache<int, int> cache;
  for (int key = 1; key <= 4; ++key) {
    cache.SetEntry(key, key);
  }

  // Summing the visited values (1 + 2 + 3 + 4) shows ForEach reached every stored entry.
  int total = 0;
  cache.ForEach([&total](Cache<int, int>::CacheEntry& entry) { total += entry.data; });

  ASSERT_EQ(total, 10);
}

// Demonstrates using ForEach on a cache with pointer value types to modify the data each CacheEntry points to.
TEST_F(CacheTest, ForEachExpiredVisitorFunc) {
  using PtrCache = Cache<int, std::shared_ptr<int>>;

  PtrCache cache;
  for (int key = 1; key <= 4; ++key) {
    cache.SetEntry(key, std::make_shared<int>(key));
  }

  // Double each value through its pointer: the stored ints become 2, 4, 6 and 8.
  cache.ForEach([](PtrCache::CacheEntry& entry) { *entry.data *= 2; });

  // Expire the four doubled entries, then add a fresh entry that ForEachExpired must skip.
  cache.ExpireAll();
  cache.SetEntry(10, std::make_shared<int>(10));

  int expiredTotal = 0;
  cache.ForEachExpired([&expiredTotal](PtrCache::CacheEntry& entry) { expiredTotal += *entry.data; });

  // 2 + 4 + 6 + 8; the unexpired entry 10 is not included.
  ASSERT_EQ(expiredTotal, 20);
}

TEST_F(CacheTest, Test32BitOverflow) {
  // Move the test clock so that "now" sits half an hour before a 32-bit millisecond rollover. With a one-hour
  // expiry age, the new entry's expiry time then lies beyond that rollover, and the entry must still not be
  // reported as expired.
  //
  // A 32-bit millisecond counter wraps every 2^32 ms, so the position within the current period is
  // time % 2^32 (the low 32 bits). Reducing modulo UINT32_MAX (2^32 - 1) instead does NOT give the low 32 bits
  // and would put the clock at the wrong spot.
  const uint64_t kRolloverPeriodMillis = uint64_t{1} << 32;
  const uint64_t kHalfHourMillis = 1000 * 60 * 30;
  const uint64_t targetTimeModulo = kRolloverPeriodMillis - kHalfHourMillis;
  const uint64_t currentTimeModulo = GetSystemTimeMilliseconds() % kRolloverPeriodMillis;

  // If we are already inside the final half hour of the period, a one-hour expiry crosses the rollover anyway.
  if (currentTimeModulo < targetTimeModulo) {
    mTestSystemClock->SetOffset(targetTimeModulo - currentTimeModulo);
  }

  Cache<int, int> cache;

  // Set expiration to an hour
  cache.SetExpiryAge(1000 * 60 * 60);

  cache.SetEntry(4, 4);

  // No entry should be expired: any entry visited by ForEachExpired records a failure.
  cache.ForEachExpired([](Cache<int, int>::CacheEntry& /*entry*/) { ADD_FAILURE(); });
}

TEST_F(CacheTest, LRU) {
  // Capacity 3: inserting a fourth entry evicts the least recently used one.
  LruCache<int, int> cache(3);
  cache.SetEntry(1, 1);
  cache.SetEntry(2, 2);
  cache.SetEntry(3, 3);
  EXPECT_TRUE(cache.ContainsEntry(1));
  EXPECT_TRUE(cache.ContainsEntry(2));
  EXPECT_TRUE(cache.ContainsEntry(3));

  cache.SetEntry(4, 4);
  EXPECT_TRUE(cache.ContainsEntry(2));
  EXPECT_TRUE(cache.ContainsEntry(3));
  EXPECT_TRUE(cache.ContainsEntry(4));
  EXPECT_FALSE(cache.ContainsEntry(1));

  int result = 0;
  // Entry 2 should be evicted next, but with the GetEntry call, 3 will become the least recently used instead.
  EXPECT_TRUE(cache.GetEntry(2, result));
  EXPECT_EQ(result, 2);
  cache.SetEntry(1, 1);
  EXPECT_TRUE(cache.ContainsEntry(1));
  EXPECT_TRUE(cache.ContainsEntry(2));
  EXPECT_TRUE(cache.ContainsEntry(4));
  EXPECT_FALSE(cache.ContainsEntry(3));

  // Explicit removal frees a slot without evicting anything else.
  cache.RemoveEntry(4);
  EXPECT_TRUE(cache.ContainsEntry(1));
  EXPECT_TRUE(cache.ContainsEntry(2));
  EXPECT_FALSE(cache.ContainsEntry(4));
  EXPECT_EQ(cache.GetSize(), 2);

  // Re-adding fills the freed slot, so no eviction happens.
  cache.SetEntry(3, 3);
  EXPECT_TRUE(cache.ContainsEntry(1));
  EXPECT_TRUE(cache.ContainsEntry(2));
  EXPECT_TRUE(cache.ContainsEntry(3));

  EXPECT_EQ(cache.GetSize(), 3);
  cache.Clear();
  EXPECT_EQ(cache.GetSize(), 0);
}
