/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/internal/framewriter.h"

#include "twitchsdk/core/mutex.h"
#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/thread.h"

namespace {
// At shutdown, if more than kMaxQueueDelayBeforeFlushMilliseconds of video is still queued, the waiting packets
// are dropped instead of being flushed to the muxers. Unsigned to match GetQueueDelayInMilliseconds().
constexpr uint64_t kMaxQueueDelayBeforeFlushMilliseconds = 2000;

// Measurement windows passed to FlvMuxer::GetAverageSendBitRate() and FlvMuxer::GetCongestionLevel().
constexpr uint64_t kAverageBitRateMeasurementWindowMilliseconds = 10000;
constexpr uint64_t kAverageCongestionMeasurementWindowMilliseconds = 4000;

// TODO: These constants should be centralized somewhere
constexpr uint32_t kVideoStreamIndex = 0;
constexpr uint32_t kAudioStreamIndex = 1;

/**
 * The congestion level at which our algorithm reduces the recommended bit rate. The scale of this value is 0.0-1.0.
 */
constexpr double kCongestionThreshold = 0.1;

/**
 * How often, in milliseconds, we update the recommended bit rate.
 */
constexpr uint64_t kBitRateUpdateIntervalMilliseconds = 1000;

/**
 * How often, in milliseconds, we can increase the recommended bit rate. This is based off recommendations of the
 * backend teams. The edge server distribution algorithm updates its model every 30 seconds, so we shouldn't be
 * increasing more rapidly than that or we can overload the server with a sudden spike in traffic.
 */
constexpr uint64_t kBitRateIncreaseIntervalMilliseconds = 30000;

/**
 * The ratio between our recommended bit rate and the actual measured bit rate that our algorithm applies when the
 * stream is in an uncongested state. Should be greater than 1.0. This is based off of the video team's recommendations,
 * so that none of the edge servers that serve video get overloaded.
 */
constexpr double kBitRateIncreaseRatio = 1.0875;

/**
 * The ratio between our recommended bit rate and the actual measured bit rate that our algorithm applies when the
 * stream is in a congested state. Should be less than 1.0.
 */
constexpr double kBitRateDecreaseRatio = 0.95;

/**
 * The maximum amount of queued video, expressed in milliseconds of frames at the target frame rate, before the
 * stream is aborted. Only applicable while the FlvMuxer is NOT ready. Unsigned so that it compares cleanly with
 * GetQueueDelayInMilliseconds().
 */
constexpr uint64_t kMaxFrameQueueOnMuxerUnreadyMilliseconds = 7000;  // 7 seconds
}  // namespace

ttv::broadcast::FrameWriter::FrameWriter(bool audioEnabled)
    : mWriteFrameThreadProceed(true),
      mWriteFrameThread(nullptr),
      mLastSentPacketTimestamp(0),
      mLastReceivedPacketTimestamp(0),
      mWarningDelayThreshold(0),
      mErrorDelayThreshold(0),
      mNumQueuedBytes(0),
      mEncodedBitCounter(0),
      mLastStatTime(0),
      mRecommendedBitRate(0),
      mLastError(TTV_EC_SUCCESS),
      mDelayState(DelayState::Okay),
      mAudioEnabled(audioEnabled) {
  ttv::trace::Message("FrameWriter", MessageLevel::Info, "FrameWriter created");
}

/**
 * Tears the writer down: Shutdown() joins the writer thread and stops/releases the muxers before the
 * destruction is traced.
 */
ttv::broadcast::FrameWriter::~FrameWriter() {
  this->Shutdown();
  ttv::trace::Message("FrameWriter", MessageLevel::Info, "FrameWriter destroyed");
}

/**
 * Stops the writer thread, letting it finish its final flush, and then stops and releases both muxers.
 *
 * Each resource is null-checked and reset, so calling this more than once is harmless.
 */
void ttv::broadcast::FrameWriter::Shutdown() {
  AutoTracer tracer("FrameWriter", MessageLevel::Info, "FrameWriter::Shutdown()");

  if (mWriteFrameThread) {
    // Clear the run flag under the mutex the thread checks it with, then wake it in case it is
    // blocked waiting for packets.
    {
      std::lock_guard<std::mutex> guard(mMutex);
      mWriteFrameThreadProceed = false;
    }
    mCondition.notify_all();

    // The thread still writes to the muxers while flushing, so it must be gone before they are stopped.
    mWriteFrameThread->Join();
    mWriteFrameThread = nullptr;
  }

  if (mFlvMuxer) {
    mFlvMuxer->Stop();
    mFlvMuxer = nullptr;
  }

  if (mCustomMuxer) {
    mCustomMuxer->Stop();
    mCustomMuxer = nullptr;
  }
}

/**
 * Attaches the FLV muxer that receives the stream; the writer shares ownership and Shutdown() stops and
 * releases it. Always succeeds.
 *
 * NOTE(review): the assignment is not synchronized with the writer thread; presumably this is only called
 * before Start() — confirm.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::SetFlvMuxer(const std::shared_ptr<FlvMuxer>& flvMuxer) {
  const TTV_ErrorCode result = TTV_EC_SUCCESS;
  mFlvMuxer = flvMuxer;
  return result;
}

/**
 * Attaches an optional secondary muxer. It receives every packet the FLV muxer does, but its write errors are
 * ignored by SendDataToMuxers(). Shutdown() stops and releases it. Always succeeds.
 *
 * NOTE(review): the assignment is not synchronized with the writer thread; presumably this is only called
 * before Start() — confirm.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::SetCustomMuxer(const std::shared_ptr<IMuxer>& muxer) {
  const TTV_ErrorCode result = TTV_EC_SUCCESS;
  mCustomMuxer = muxer;
  return result;
}

/**
 * Starts the writer thread for a stream described by videoParams.
 *
 * The thread moves packets handed over by WritePacket() into the per-stream output queues, sends them to the
 * muxers, periodically updates the recommended bit rate (when automatic adjustment is enabled) and the delay
 * state. On exit it flushes what is left, or drops it if the backlog is too long or the FLV muxer never became
 * ready, and reports a real failure through the stream abort callback.
 *
 * @return The result of creating the writer thread.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::Start(const VideoParams& videoParams) {
  AutoTracer tracer("FrameWriter", MessageLevel::Info, "FrameWriter::Start()");

  auto threadProc = [this]() {
    ttv::trace::Message("FrameWriter", MessageLevel::Info, "FrameWriter thread started");
    mBitrateUpdateTimer.Set(kBitRateUpdateIntervalMilliseconds);
    mBitrateIncreaseTimer.Set(kBitRateIncreaseIntervalMilliseconds);

    mStreamStartSystemTime = GetSystemClockTime();

    TTV_ErrorCode ec = TTV_EC_SUCCESS;
    for (;;) {
      {
        std::unique_lock<std::mutex> lock(mMutex);
        if (!mWriteFrameThreadProceed) {
          break;
        }

        if (mPacketQueue.empty()) {
          // Sleep until WritePacket() or Shutdown() notifies. New packets are picked up on the next iteration;
          // a spurious wakeup only results in an extra SendDataToMuxers() call.
          mCondition.wait(lock);
        } else {
          // Move everything handed over by WritePacket() into the per-stream output queues.
          do {
            std::unique_ptr<Packet> packet = std::move(mPacketQueue.front());
            mPacketQueue.pop();

            TransferPacketToOutputQueues(std::move(packet));
          } while (mWriteFrameThreadProceed && !mPacketQueue.empty());

          if (!mWriteFrameThreadProceed) {
            break;
          }
        }
      }

      ec = SendDataToMuxers();

      if (mVideoParams.automaticBitRateAdjustmentEnabled && mBitrateUpdateTimer.Check(true)) {
        UpdateRecommendedBitRate();
        mBitrateUpdateTimer.Set(kBitRateUpdateIntervalMilliseconds);
      }

      UpdateDelayState();

      // NOMOREDATA only means nothing could be sent right now; any other failure aborts the stream.
      if (TTV_FAILED(ec) && (ec != TTV_EC_BROADCAST_NOMOREDATA)) {
        mLastError = ec;

        {
          std::unique_lock<std::mutex> lock(mMutex);
          mWriteFrameThreadProceed = false;
        }

        ttv::trace::Message("FrameWriter", MessageLevel::Debug,
          "FrameWriter thread received error from muxer, aborting: %s", ErrorToString(ec));
      }
    }

    // Flush remaining data to the muxer before stopping, unless so much is queued that it should just be dropped.
    if (GetQueueDelayInMilliseconds() < kMaxQueueDelayBeforeFlushMilliseconds) {
      while (TTV_SUCCEEDED(ec)) {
        // SendDataToMuxers() returns success without sending anything while the FLV muxer is not ready, so
        // without this check the flush would spin forever and Shutdown() would hang in Join().
        if (mFlvMuxer != nullptr && !mFlvMuxer->IsReady()) {
          break;
        }
        ec = SendDataToMuxers();
      }
    }

    // Running out of queued data is the normal end of a flush, not a stream failure (the main loop above ignores
    // it for the same reason), so it must not be reported as the last error or trigger the abort callback.
    if (ec == TTV_EC_BROADCAST_NOMOREDATA) {
      ec = TTV_EC_SUCCESS;
    }

    // Ensure the client can get the last error
    if (TTV_SUCCEEDED(mLastError)) {
      mLastError = ec;
    }

    // Notify if the stream failed
    if (TTV_FAILED(mLastError)) {
      if (mStreamAbortCallback != nullptr) {
        mStreamAbortCallback(this, mLastError);
      }
    }

    // Discard anything that was not flushed.
    {
      std::unique_lock<std::mutex> lock(mMutex);
      mPacketQueue = {};
    }

    mVideoPacketQueue = {};
    mAudioPacketQueue = {};

    ttv::trace::Message("FrameWriter", MessageLevel::Info, "FrameWriter thread exiting");
  };

  // Set before the thread exists so it never observes stale parameters.
  mVideoParams = videoParams;
  mRecommendedBitRate = videoParams.initialKbps * 1000;

  // Start the frame reading thread
  TTV_ErrorCode ec = ttv::CreateThread(threadProc, "ttv::broadcast::FrameWriter", mWriteFrameThread);
  TTV_ASSERT(TTV_SUCCEEDED(ec) && mWriteFrameThread != nullptr);

  if (TTV_SUCCEEDED(ec)) {
    mWriteFrameThread->Run();
  }

  return ec;
}

/**
 * Hands an encoded packet over to the writer thread.
 *
 * Records the packet's timestamp as the most recently received one. For video packets it also adds the payload
 * size, in bits, to the encoder output counter consumed by UpdateRecommendedBitRate(). The packet is then queued
 * under mMutex and the writer thread is woken.
 *
 * @return TTV_EC_INVALID_ARG for a null packet, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::WritePacket(std::unique_ptr<Packet>&& packet) {
  TTV_ASSERT(packet != nullptr);
  if (!packet) {
    return TTV_EC_INVALID_ARG;
  }

  const bool isVideoPacket = (packet->streamIndex == kVideoStreamIndex);
  mLastReceivedPacketTimestamp = packet->timestamp;
  if (isVideoPacket) {
    mEncodedBitCounter += packet->data.size() * 8;
  }

  {
    std::lock_guard<std::mutex> guard(mMutex);
    mPacketQueue.push(std::move(packet));
  }
  // Notify after releasing the lock so the woken thread can take it immediately.
  mCondition.notify_all();

  return TTV_EC_SUCCESS;
}

/**
 * Returns the last error recorded by the writer thread and resets it to TTV_EC_SUCCESS, so each error is
 * reported only once.
 *
 * NOTE(review): the read and the reset are two separate steps; unless mLastError is atomic this races with the
 * writer thread — confirm against the header.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::GetLastError() {
  const TTV_ErrorCode lastError = mLastError;
  mLastError = TTV_EC_SUCCESS;
  return lastError;
}

/**
 * Moves a packet received via WritePacket() into the video or audio output queue.
 *
 * Called on the writer thread while mMutex is held (from the drain loop in Start()). If the FLV muxer is still not
 * ready and at least kMaxFrameQueueOnMuxerUnreadyMilliseconds of video is already queued, the packet is dropped and
 * the stream abort callback is invoked with TTV_EC_BROADCAST_FRAME_QUEUE_FULL.
 *
 * mNumQueuedBytes is increased only for packets that are actually queued. This keeps it balanced with the
 * decrements made in SendDataToMuxers() when packets are sent.
 */
void ttv::broadcast::FrameWriter::TransferPacketToOutputQueues(std::unique_ptr<Packet>&& packet) {
  if (GetQueueDelayInMilliseconds() >= static_cast<uint64_t>(kMaxFrameQueueOnMuxerUnreadyMilliseconds) &&
      mFlvMuxer != nullptr && !mFlvMuxer->IsReady()) {
    // NOTE(review): the callback runs with mMutex held, so it must not call WritePacket()/Shutdown() on this
    // thread (both lock mMutex). It also fires again for every further packet dropped here.
    if (mStreamAbortCallback != nullptr) {
      mStreamAbortCallback(this, TTV_EC_BROADCAST_FRAME_QUEUE_FULL);
    }
    return;
  }

  const uint32_t packetBytes = static_cast<uint32_t>(packet->data.size());

  switch (packet->streamIndex) {
    case kVideoStreamIndex:
      mVideoPacketQueue.push(std::move(packet));
      break;

    case kAudioStreamIndex:
      TTV_ASSERT(mAudioEnabled);
      if (!mAudioEnabled) {
        // SendDataToMuxers() never sends audio when audio is disabled, so a queued packet would never be released.
        return;
      }
      mAudioPacketQueue.push(std::move(packet));
      break;

    default:
      TTV_ASSERT(false && "Invalid stream index");
      return;
  }

  mNumQueuedBytes += packetBytes;
}

/**
 * Sends at most one packet from the output queues to the muxers, then pumps the FLV muxer via Update().
 *
 * When audio is enabled, a packet is only sent once both queues are non-empty. The packet with the lower
 * timestamp goes first, and video wins ties. Errors from the custom muxer are intentionally ignored; only the
 * FLV muxer's result is reported.
 *
 * @return TTV_EC_SUCCESS if the FLV muxer is not ready yet (packets stay queued) or a packet was dispatched
 *         successfully; TTV_EC_BROADCAST_NOMOREDATA if nothing could be sent; otherwise the FLV muxer's error.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::SendDataToMuxers() {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  if (mFlvMuxer != nullptr) {
    // Note: not returning error, so we can keep the frames queued up while waiting
    // for RTMP to connect
    if (!mFlvMuxer->IsReady()) {
      return TTV_EC_SUCCESS;
    }

    // Check stream health before sending more data
    ec = mFlvMuxer->GetError();
  }

  if (TTV_SUCCEEDED(ec)) {
    bool sendVideo = false;
    bool sendAudio = false;

    if (mAudioEnabled) {
      // Wait until there is both audio and video available
      if (!mAudioPacketQueue.empty() && !mVideoPacketQueue.empty()) {
        // Specifically prefer video first if we have packets with the same timestamp
        sendVideo = mVideoPacketQueue.front()->timestamp <= mAudioPacketQueue.front()->timestamp;
        sendAudio = !sendVideo;
      }
    } else {
      TTV_ASSERT(mAudioPacketQueue.empty());
      sendVideo = !mVideoPacketQueue.empty();
    }

    // When a packet is dispatched and there is no FLV muxer, ec stays TTV_EC_SUCCESS. Previously NOMOREDATA was
    // returned even after sending, which stopped the shutdown flush after one packet with a custom-only setup.
    size_t packetSize = 0;

    if (sendVideo) {
      auto& videoPacket = *mVideoPacketQueue.front();
      packetSize = videoPacket.data.size();
      mLastSentPacketTimestamp = videoPacket.timestamp;

      if (mFlvMuxer != nullptr) {
        ec = mFlvMuxer->WriteVideoPacket(videoPacket);
      }

      if (mCustomMuxer != nullptr) {
        // NOTE: Intentionally ignoring errors
        mCustomMuxer->WriteVideoPacket(videoPacket);
      }

      mVideoPacketQueue.pop();
    } else if (sendAudio) {
      auto& audioPacket = *mAudioPacketQueue.front();
      packetSize = audioPacket.data.size();
      mLastSentPacketTimestamp = audioPacket.timestamp;

      if (mFlvMuxer != nullptr) {
        ec = mFlvMuxer->WriteAudioPacket(audioPacket);
      }

      if (mCustomMuxer != nullptr) {
        // NOTE: Intentionally ignoring errors
        mCustomMuxer->WriteAudioPacket(audioPacket);
      }

      mAudioPacketQueue.pop();
    } else {
      ec = TTV_EC_BROADCAST_NOMOREDATA;
    }

    TTV_ASSERT(packetSize <= mNumQueuedBytes);
    mNumQueuedBytes -= static_cast<uint32_t>(packetSize);
  }

  if (mFlvMuxer != nullptr) {
    mFlvMuxer->Update();
  }

  return ec;
}

/**
 * Estimates how much video is queued for the muxers, in milliseconds, as the number of pending video packets
 * divided by the target frame rate. Audio packets are not counted.
 *
 * @return The estimated delay, or 0 when the target frame rate is 0 (e.g. not configured yet). Without that guard
 *         the integer division below would be undefined behavior.
 */
uint64_t ttv::broadcast::FrameWriter::GetQueueDelayInMilliseconds() const {
  if (mVideoParams.targetFramesPerSecond == 0) {
    return 0;
  }

  const uint64_t numPendingVideoFrames = static_cast<uint64_t>(mVideoPacketQueue.size());
  return 1000 * numPendingVideoFrames / mVideoParams.targetFramesPerSecond;
}

/**
 * Returns the current recommended bit rate in bits per second. Start() seeds it from VideoParams::initialKbps * 1000,
 * and UpdateRecommendedBitRate() adjusts it afterwards.
 */
uint32_t ttv::broadcast::FrameWriter::GetRecommendedBitRate() const {
  const uint32_t bitsPerSecond = mRecommendedBitRate;
  return bitsPerSecond;
}

/**
 * Recomputes mRecommendedBitRate (bits per second) from the FLV muxer's measured send rate and congestion level.
 *
 * - Congested (level above kCongestionThreshold): lower to kBitRateDecreaseRatio of the measured rate (never
 *   raise), and restart the increase timer.
 * - Otherwise, at most once per kBitRateIncreaseIntervalMilliseconds: raise by kBitRateIncreaseRatio.
 * The result is clamped to [minimumKbps, maximumKbps] * 1000 from the video params.
 *
 * Does nothing if there is no ready FLV muxer or it cannot provide measurements. When a bandwidth stat callback is
 * installed, this also reports a BandwidthStat. It then accumulates the running totals consumed by
 * GatherTrackingStats() over the interval since the previous report.
 */
void ttv::broadcast::FrameWriter::UpdateRecommendedBitRate() {
  if (mFlvMuxer == nullptr || !mFlvMuxer->IsReady()) {
    return;
  }

  // If we can't get an estimate for bps or congestion, bail out.
  uint64_t actualBitsPerSecond = 0;
  if (TTV_FAILED(mFlvMuxer->GetAverageSendBitRate(kAverageBitRateMeasurementWindowMilliseconds, actualBitsPerSecond))) {
    return;
  }

  double congestionLevel = 0.0;
  if (TTV_FAILED(mFlvMuxer->GetCongestionLevel(kAverageCongestionMeasurementWindowMilliseconds, congestionLevel))) {
    return;
  }

  // The rate that was in effect during the interval accounted for in the running totals below.
  const uint32_t previousRecommendedBitRate = mRecommendedBitRate;
  if (congestionLevel > kCongestionThreshold) {
    mRecommendedBitRate = std::min(
      mRecommendedBitRate, static_cast<uint32_t>(static_cast<double>(actualBitsPerSecond) * kBitRateDecreaseRatio));
    mBitrateIncreaseTimer.Set(kBitRateIncreaseIntervalMilliseconds);
  } else if (mBitrateIncreaseTimer.Check(true)) {
    mRecommendedBitRate = static_cast<uint32_t>(static_cast<double>(mRecommendedBitRate) * kBitRateIncreaseRatio);
    mBitrateIncreaseTimer.Set(kBitRateIncreaseIntervalMilliseconds);
  }

  // Clamp to the bounds specified
  mRecommendedBitRate = std::min(mRecommendedBitRate, 1000 * mVideoParams.maximumKbps);
  mRecommendedBitRate = std::max(mRecommendedBitRate, 1000 * mVideoParams.minimumKbps);

  if (mBandwidthStatCallback != nullptr) {
    uint64_t encodedBitCount = 0;
    encodedBitCount = mEncodedBitCounter.exchange(encodedBitCount);

    const uint64_t lastStatTime = mLastStatTime;
    const uint64_t now = GetSystemClockTime();
    mLastStatTime = now;

    // Clock ticks covered by this report. Zero would make the rate computations below divide by zero.
    const uint64_t ticksSinceLastStat = now - lastStatTime;
    if (lastStatTime != 0 && ticksSinceLastStat != 0) {
      const uint64_t clockFrequency = GetSystemClockFrequency();
      const double intervalSeconds = static_cast<double>(ticksSinceLastStat) / static_cast<double>(clockFrequency);

      // Floating point avoids overflowing bits * frequency when the interval is long.
      const uint64_t encodedBps = static_cast<uint64_t>(static_cast<double>(encodedBitCount) / intervalSeconds);
      const uint64_t ticksSinceStreamStart = now - mStreamStartSystemTime;
      const double streamTime =
        static_cast<double>(ticksSinceStreamStart) / static_cast<double>(clockFrequency);
      const double queueDelay = static_cast<double>(GetQueueDelayInMilliseconds()) / 1000.0;

      BandwidthStat stat;
      stat.recordedTime = streamTime;
      stat.recommendedBitsPerSecond = mRecommendedBitRate;
      stat.measuredBitsPerSecond = actualBitsPerSecond;
      stat.encoderOutputBitsPerSecond = encodedBps;
      stat.backBufferSeconds = queueDelay;
      stat.congestionLevel = congestionLevel;

      mBandwidthStatCallback(this, stat);

      // Bits that would have been produced at the previously recommended rate over this interval:
      // rate [bits/s] * interval [s]. The totals must use the interval since the last report, not the time since
      // the stream started, so that GatherTrackingStats() can divide bits by accumulated time.
      const uint64_t recommendedBits =
        static_cast<uint64_t>(static_cast<double>(previousRecommendedBitRate) * intervalSeconds);

      std::unique_lock<std::mutex> lock(mRunningTotalsMutex);
      mRunningTotals.recommendedBits += recommendedBits;
      mRunningTotals.encoderOutputBits += encodedBitCount;
      mRunningTotals.elapsedSystemTime += ticksSinceLastStat;
    }
  }

  ttv::trace::Message("FrameWriter", MessageLevel::Debug,
    "Setting bit rate to %u, based on measured BPS %llu and congestion level %.4f", mRecommendedBitRate,
    static_cast<unsigned long long>(actualBitsPerSecond), congestionLevel);
}

/**
 * Reports the average recommended and encoder-output bit rates, in kbps, over the time accumulated in
 * mRunningTotals since the previous call, then resets the totals.
 *
 * The totals hold bit counts and elapsed system-clock ticks (accumulated by UpdateRecommendedBitRate()), so the
 * average is bits / (ticks / frequency) / 1000. The computation is done in floating point because
 * bits * frequency overflows 64 bits for long-running streams.
 *
 * @return TTV_EC_NOT_AVAILABLE if no time has been accumulated yet, otherwise TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::broadcast::FrameWriter::GatherTrackingStats(
  uint64_t& averageRecommendedKbps, uint64_t& averageEncodedKbps) {
  std::unique_lock<std::mutex> lock(mRunningTotalsMutex);
  if (mRunningTotals.elapsedSystemTime == 0) {
    return TTV_EC_NOT_AVAILABLE;
  }

  const double elapsedSeconds =
    static_cast<double>(mRunningTotals.elapsedSystemTime) / static_cast<double>(GetSystemClockFrequency());
  // Dividing bits by this yields kilobits per second.
  const double kilobitSeconds = 1000.0 * elapsedSeconds;

  averageRecommendedKbps =
    static_cast<uint64_t>(static_cast<double>(mRunningTotals.recommendedBits) / kilobitSeconds);
  averageEncodedKbps =
    static_cast<uint64_t>(static_cast<double>(mRunningTotals.encoderOutputBits) / kilobitSeconds);

  mRunningTotals = {};

  return TTV_EC_SUCCESS;
}

/**
 * Classifies the current queue delay against mErrorDelayThreshold / mWarningDelayThreshold (milliseconds). It
 * invokes mDelayStateChangedCallback when the state changes.
 *
 * Skipped while the FLV muxer exists but is not ready, because SendDataToMuxers() holds packets back until then.
 * A missing FLV muxer is treated as ready, consistent with SendDataToMuxers(); previously this dereferenced a null
 * mFlvMuxer.
 */
void ttv::broadcast::FrameWriter::UpdateDelayState() {
  if (mFlvMuxer != nullptr && !mFlvMuxer->IsReady()) {
    return;
  }

  DelayState newState = DelayState::Okay;

  const uint64_t queueDelay = GetQueueDelayInMilliseconds();
  if (queueDelay > mErrorDelayThreshold) {
    newState = DelayState::Error;
  } else if (queueDelay > mWarningDelayThreshold) {
    newState = DelayState::Warning;
  }

  if (mDelayState != newState) {
    mDelayState = newState;
    if (mDelayStateChangedCallback != nullptr) {
      mDelayStateChangedCallback(this, newState);
    }
  }
}
