/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/timer.h"

#include "twitchsdk/core/random.h"
#include "twitchsdk/core/systemclock.h"

#include <algorithm>
#include <limits>
#include <random>

namespace {
// Default upper bound for a single backoff interval: one minute, expressed in milliseconds.
constexpr uint64_t kDefaultMaxInterval = 60 * 1000;
// Default jitter width applied to retries: one second, expressed in milliseconds.
constexpr uint64_t kDefaultJitter = 1 * 1000;
}  // namespace

// Returns baseMs shifted by a uniformly random offset in [-widthMs, +widthMs].
// The width is first clamped to baseMs so the lower bound can never drop below zero; when the
// (clamped) width is zero, baseMs is returned unchanged without consulting the random generator.
uint64_t ttv::JitterTime(uint64_t baseMs, uint64_t widthMs) {
  widthMs = std::min(widthMs, baseMs);

  if (widthMs == 0) {
    return baseMs;
  }

  // Draw the result directly from [baseMs - widthMs, baseMs + widthMs] using unsigned 64-bit
  // arithmetic. A signed 32-bit offset cannot represent values above INT32_MAX (~24.8 days in ms):
  // narrowing to int32_t would corrupt the base and could even produce an invalid (min > max) range.
  const uint64_t lower = baseMs - widthMs;  // cannot underflow: widthMs <= baseMs
  // Saturate the upper bound so baseMs + widthMs cannot wrap past the maximum uint64_t value.
  const uint64_t headroom = std::numeric_limits<uint64_t>::max() - baseMs;
  const uint64_t upper = baseMs + std::min(widthMs, headroom);

  std::uniform_int_distribution<uint64_t> distribution(lower, upper);
  return distribution(ttv::random::GetGenerator());
}

// A freshly constructed wait has no deadline and reports TimedOut until Reset() arms it.
ttv::WaitForEventWithTimeout::WaitForEventWithTimeout()
    : mExpiryTime(0),
      mState(WFEWT_TimedOut) {}

// Arms the wait: the state becomes Waiting, and GetState() reports TimedOut once the system clock
// reaches `timeout` milliseconds from now (unless Complete() is called first).
void ttv::WaitForEventWithTimeout::Reset(uint64_t timeout) {
  const uint64_t now = GetSystemTimeMilliseconds();
  mExpiryTime = now + timeout;
  mState = WFEWT_Waiting;
}

// Marks the awaited event as having happened. Only a wait that is still pending can complete;
// a wait that already timed out (or completed) keeps its state until the next Reset().
void ttv::WaitForEventWithTimeout::Complete() {
  if (mState != WFEWT_Waiting) {
    return;
  }

  mState = WFEWT_Complete;
}

// Returns the current state, lazily transitioning Waiting -> TimedOut when the deadline has passed.
// The clock is only read while the wait is still pending.
ttv::WaitForEventWithTimeout::eWaitState ttv::WaitForEventWithTimeout::GetState() {
  const bool deadlinePassed = (mState == WFEWT_Waiting) && (GetSystemTimeMilliseconds() >= mExpiryTime);
  if (deadlinePassed) {
    mState = WFEWT_TimedOut;
  }

  return mState;
}

// Both timestamps start at zero; a zero start time means the timer is not set.
ttv::WaitForExpiry::WaitForExpiry()
    : mStartTime(0),
      mEndTime(0) {}

// Starts the timer now and sets it to expire `milliseconds` later.
void ttv::WaitForExpiry::Set(uint64_t milliseconds) {
  mStartTime = GetSystemTimeMilliseconds();
  mEndTime = milliseconds + mStartTime;
}

// Starts the timer now with a duration of `milliseconds` randomly shifted by up to
// +/- jitterWidthMs (see JitterTime).
void ttv::WaitForExpiry::SetWithJitter(uint64_t milliseconds, uint64_t jitterWidthMs) {
  mStartTime = GetSystemTimeMilliseconds();
  const uint64_t jitteredDuration = JitterTime(milliseconds, jitterWidthMs);
  mEndTime = mStartTime + jitteredDuration;
}

// Moves the expiry so it lies `milliseconds` after the original start time; the start is unchanged.
void ttv::WaitForExpiry::AdjustDuration(uint64_t milliseconds) {
  mEndTime = milliseconds + mStartTime;
}

// Returns true when the timer is set and the current time has reached its end time.
// When clearWhenExpired is true, an expired timer is also cleared. An unset timer never expires.
bool ttv::WaitForExpiry::Check(bool clearWhenExpired) {
  if (mStartTime > 0) {
    const bool hasExpired = GetSystemTimeMilliseconds() >= mEndTime;
    if (hasExpired && clearWhenExpired) {
      Clear();
    }
    return hasExpired;
  }

  return false;
}

// Returns the timer to the unset state (zero start and end time).
void ttv::WaitForExpiry::Clear() {
  mStartTime = mEndTime = 0;
}

// Returns the milliseconds left until the end time: 0 once the end time has been reached, or the
// maximum uint64_t value when the timer is not set.
uint64_t ttv::WaitForExpiry::GetRemainingTime() const {
  if (mStartTime > 0) {
    const uint64_t now = GetSystemTimeMilliseconds();
    // Both operands are unsigned, so `mEndTime - now` would wrap around to a huge value once the end
    // time has passed; std::max against 0 cannot catch that. Clamp explicitly instead.
    return now < mEndTime ? mEndTime - now : 0;
  } else {
    return std::numeric_limits<uint64_t>::max();
  }
}

// Returns the milliseconds elapsed since the timer was set, or 0 when the timer is not set.
uint64_t ttv::WaitForExpiry::GetElapsedTime() const {
  if (mStartTime > 0) {
    const uint64_t now = GetSystemTimeMilliseconds();
    // Unsigned subtraction would wrap to a huge value if the clock ever reports a time earlier than
    // mStartTime (std::max against 0 cannot catch that), so clamp explicitly.
    // NOTE(review): whether GetSystemTimeMilliseconds can go backwards depends on its clock source.
    return now > mStartTime ? now - mStartTime : 0;
  } else {
    return 0;
  }
}

// Default table: exponential backoff capped at kDefaultMaxInterval, jittered by kDefaultJitter.
ttv::RetryBackoffTable::RetryBackoffTable()
    : mJitterMilliseconds(kDefaultJitter),
      mNextAttemptNumber(0) {
  CreateTable(kDefaultMaxInterval);
}

// Uses a caller-supplied table of intervals (ms). The table must be non-empty, since GetInterval()
// indexes it directly; this matches the check in RetryTimer::SetBackoffTable.
ttv::RetryBackoffTable::RetryBackoffTable(const std::vector<uint64_t>& tableMilliseconds, uint64_t retryJitterWidthMs)
    : mJitterMilliseconds(retryJitterWidthMs), mNextAttemptNumber(0) {
  TTV_ASSERT(!tableMilliseconds.empty());

  // mJitterMilliseconds is already set in the initializer list; only the table remains.
  mBackOffTableMilliseconds = tableMilliseconds;
}

// Builds an exponential table capped at maxInterval (ms), with the given jitter width (ms).
ttv::RetryBackoffTable::RetryBackoffTable(uint64_t maxInterval, uint64_t retryJitterWidthMs)
    : mJitterMilliseconds(retryJitterWidthMs),
      mNextAttemptNumber(0) {
  CreateTable(maxInterval);
}

// Moves to the next table entry, staying on the last entry once it has been reached.
void ttv::RetryBackoffTable::Advance() {
  // Written as `n + 1 < size()` rather than `n < size() - 1`: with an empty table, size() - 1
  // underflows to the maximum size_t value and the index would grow without bound.
  if (mNextAttemptNumber + 1 < mBackOffTableMilliseconds.size()) {
    mNextAttemptNumber++;
  }
}

// Rewinds to the first table entry (which GetInterval returns without jitter).
void ttv::RetryBackoffTable::Reset() {
  mNextAttemptNumber = 0;
}

// Returns the interval (ms) for the current entry. The first entry is returned exactly; later
// entries are randomly shifted by up to +/- mJitterMilliseconds.
uint64_t ttv::RetryBackoffTable::GetInterval() const {
  const uint64_t baseInterval = mBackOffTableMilliseconds[mNextAttemptNumber];
  if (mNextAttemptNumber == 0) {
    return JitterTime(baseInterval, 0);
  }

  return JitterTime(baseInterval, mJitterMilliseconds);
}

// Rebuilds the table with exponentially growing intervals (ms). It starts at 1000 (or maxInterval
// if that is smaller), doubles while the value is below maxInterval, and always ends with maxInterval.
void ttv::RetryBackoffTable::CreateTable(uint64_t maxInterval) {
  // Don't let the table grow too large if someone put in a large or invalid max interval.
  constexpr std::size_t kMaxDoublingEntries = 32;

  mBackOffTableMilliseconds.clear();

  for (uint64_t interval = std::min(static_cast<uint64_t>(1000), maxInterval);
       interval < maxInterval && mBackOffTableMilliseconds.size() < kMaxDoublingEntries; interval *= 2) {
    mBackOffTableMilliseconds.push_back(interval);
  }

  // Every value pushed above is strictly below maxInterval, so the cap is always appended.
  // This also covers maxInterval <= 1000, where the loop adds nothing. The previous
  // `back() < maxInterval` test then called back() on an empty vector, which is undefined behavior.
  mBackOffTableMilliseconds.push_back(maxInterval);
}

// Default timer: exponential backoff capped at kDefaultMaxInterval, jittered by kDefaultJitter.
ttv::RetryTimer::RetryTimer()
    : mRetryJitter(0),
      mNextAttemptNumber(0) {
  SetBackoffTable(kDefaultMaxInterval, kDefaultJitter);
}

// Uses a caller-supplied backoff table (ms) and jitter width (ms).
ttv::RetryTimer::RetryTimer(const std::vector<uint64_t>& backOffTableMs, uint64_t retryJitterWidthMs)
    : mRetryJitter(0),
      mNextAttemptNumber(0) {
  SetBackoffTable(backOffTableMs, retryJitterWidthMs);
}

// Builds an exponential backoff table capped at maxInterval (ms), with the given jitter width (ms).
ttv::RetryTimer::RetryTimer(uint64_t maxInterval, uint64_t retryJitterWidthMs)
    : mRetryJitter(0),
      mNextAttemptNumber(0) {
  SetBackoffTable(maxInterval, retryJitterWidthMs);
}

// Builds and installs an exponential backoff table (ms). It starts with 0 (an immediate first
// attempt), then grows from 1000 (or maxInterval if smaller) by doubling while the value is below
// maxInterval, and finishes with maxInterval unless the table already ends there.
void ttv::RetryTimer::SetBackoffTable(uint64_t maxInterval, uint64_t retryJitterWidthMs) {
  std::vector<uint64_t> schedule(1, 0);

  // The 32-entry cap keeps the table small if someone put in a large or invalid max interval.
  for (uint64_t step = std::min(static_cast<uint64_t>(1000), maxInterval); step < maxInterval && schedule.size() < 32;
       step *= 2) {
    schedule.push_back(step);
  }

  const bool needsCapEntry = schedule.back() < maxInterval;
  if (needsCapEntry) {
    schedule.push_back(maxInterval);
  }

  SetBackoffTable(schedule, retryJitterWidthMs);
}

void ttv::RetryTimer::SetBackoffTable(const std::vector<uint64_t>& backOffTableMs, uint64_t retryJitterWidthMs) {
  TTV_ASSERT(!backOffTableMs.empty());

  mBackOffTable = backOffTableMs;
  mRetryJitter = retryJitterWidthMs;
}

// (Re)arms the global reset timer to fire `milliseconds` from now; CheckGlobalReset() then clears
// the retry state once it has elapsed.
void ttv::RetryTimer::StartGlobalReset(uint64_t milliseconds) {
  mGlobalResetTimer.Set(milliseconds);
}

// Polls the next-retry timer, asking it to clear itself when it reports expiry.
bool ttv::RetryTimer::CheckNextRetry() {
  const bool clearWhenExpired = true;
  return mNextRetry.Check(clearWhenExpired);
}

// Polls the global reset timer. When it has elapsed, all retry state is cleared (pending retry,
// global reset timer, attempt index) and true is returned.
bool ttv::RetryTimer::CheckGlobalReset() {
  if (!mGlobalResetTimer.Check(true)) {
    return false;
  }

  Clear();
  return true;
}

// Schedules the next retry from the backoff table, unless one is already pending.
// Only when a retry is actually scheduled does the attempt index advance (via GetNextAttempt).
void ttv::RetryTimer::ScheduleNextRetry() {
  if (mNextRetry.IsSet()) {
    return;
  }

  mNextRetry.Set(GetNextAttempt());
}

// Drops any pending retry and global reset, and restarts the backoff sequence at the first entry.
void ttv::RetryTimer::Clear() {
  mNextAttemptNumber = 0;

  mNextRetry.Clear();
  mGlobalResetTimer.Clear();
}

// Cancels only the global reset timer; the pending retry and attempt index are left untouched.
void ttv::RetryTimer::ClearGlobalReset() {
  mGlobalResetTimer.Clear();
}

// Returns the delay (ms) for the current attempt and advances the attempt index. Once the last
// table entry is reached, the index stays there. The first attempt is never jittered.
uint64_t ttv::RetryTimer::GetNextAttempt() {
  const uint64_t baseDelay = mBackOffTable[mNextAttemptNumber];
  const uint64_t delay = (mNextAttemptNumber == 0) ? JitterTime(baseDelay, 0) : JitterTime(baseDelay, mRetryJitter);

  const auto lastIndex = mBackOffTable.size() - 1;
  if (mNextAttemptNumber < lastIndex) {
    ++mNextAttemptNumber;
  }

  return delay;
}

// Default timer: uses the default RetryBackoffTable, with no task scheduled yet.
ttv::LambdaRetryTimer::LambdaRetryTimer()
    : mBackOffTable(RetryBackoffTable()),
      mTaskId(0),
      mTimerSet(false) {}

// Uses a caller-supplied backoff table (ms) and jitter width (ms), with no task scheduled yet.
ttv::LambdaRetryTimer::LambdaRetryTimer(const std::vector<uint64_t>& backOffTableMs, uint64_t retryJitterWidthMs)
    : mBackOffTable(RetryBackoffTable(backOffTableMs, retryJitterWidthMs)),
      mTaskId(0),
      mTimerSet(false) {}

// Uses an exponential backoff table capped at maxInterval (ms) with the given jitter width (ms).
ttv::LambdaRetryTimer::LambdaRetryTimer(uint64_t maxInterval, uint64_t retryJitterWidthMs)
    : mBackOffTable(RetryBackoffTable(maxInterval, retryJitterWidthMs)),
      mTaskId(0),
      mTimerSet(false) {}

// Cancels any pending scheduled task, since that task captures `this`.
ttv::LambdaRetryTimer::~LambdaRetryTimer() {
  // Stop() reports failure when nothing is pending, which is fine here, so its result is discarded.
  static_cast<void>(Stop());
}

// Takes the callback to run when a scheduled task fires. Tasks that are already scheduled keep the
// copy of the callback they captured in Start().
void ttv::LambdaRetryTimer::SetCallback(CallbackFunc&& func) {
  mCallback = std::move(func);
}

// Shares ownership of the scheduler that Start() and Stop() use to schedule and cancel tasks.
void ttv::LambdaRetryTimer::SetEventScheduler(const std::shared_ptr<IEventScheduler>& eventScheduler) {
  mEventScheduler = eventScheduler;
}

// Cancels any pending task and schedules the callback to run after `milliseconds`.
// Returns TTV_EC_NOT_INITIALIZED when no scheduler or callback has been set; otherwise returns the
// scheduler's result. IsSet() stays true only while a task is actually pending.
TTV_ErrorCode ttv::LambdaRetryTimer::Start(uint64_t milliseconds) {
  // First stop the current timer so at most one task is outstanding.
  Stop();

  if (mEventScheduler == nullptr || mCallback == nullptr) {
    return TTV_EC_NOT_INITIALIZED;
  }

  // Raise the flag before scheduling, because the task itself clears it when it runs. Setting it
  // afterwards could overwrite that clear if the task ran before ScheduleTask returned.
  // NOTE(review): whether that can happen depends on the IEventScheduler implementation.
  mTimerSet = true;

  // The callback is copied into the task, so a later SetCallback() does not affect this pending run.
  Result<TaskId> result = mEventScheduler->ScheduleTask({[this, callbackFunc = mCallback]() {
                                                           mTimerSet = false;
                                                           callbackFunc();
                                                         },
    milliseconds, "LambdaRetryTimer"});

  if (result.IsError()) {
    // Nothing was scheduled. Without this reset, IsSet() would keep reporting true, and the next
    // Stop() would ask the scheduler to cancel a task that does not exist.
    mTimerSet = false;
  } else {
    mTaskId = result.GetResult();
  }

  return result.GetErrorCode();
}

// Starts the timer with the current backoff interval. The backoff table advances to its next entry
// only when scheduling succeeded.
TTV_ErrorCode ttv::LambdaRetryTimer::StartBackoff() {
  const uint64_t interval = mBackOffTable.GetInterval();
  const TTV_ErrorCode result = Start(interval);

  if (TTV_SUCCEEDED(result)) {
    mBackOffTable.Advance();
  }

  return result;
}

// Cancels the pending task, if any. Returns TTV_EC_OPERATION_FAILED when no task was pending.
// Otherwise it returns the scheduler's cancel result (success if there is no scheduler).
TTV_ErrorCode ttv::LambdaRetryTimer::Stop() {
  // exchange() clears the flag and reports whether it was set, in one step.
  const bool wasSet = mTimerSet.exchange(false);
  if (!wasSet) {
    return TTV_EC_OPERATION_FAILED;
  }

  TTV_ErrorCode cancelResult = TTV_EC_SUCCESS;
  if (mEventScheduler != nullptr) {
    cancelResult = mEventScheduler->CancelTask(mTaskId);
  }
  mTaskId = 0;

  return cancelResult;
}

// Rewinds the backoff table, so the next StartBackoff() uses the first (un-jittered) interval.
void ttv::LambdaRetryTimer::ResetBackoff() {
  mBackOffTable.Reset();
}

// True while a scheduled task is pending: set by Start(), and cleared when the task runs or Stop() is called.
bool ttv::LambdaRetryTimer::IsSet() const {
  const bool pending = mTimerSet;
  return pending;
}
