/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/socket.h"

#include "twitchsdk/core/resourcefactorychain.h"
#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/thread.h"

#include <atomic>
#include <memory>
#include <sstream>

namespace {
using namespace ttv;

// Age limit, in milliseconds, for bytes held in a BufferedSocket cache: once more than this much
// time has passed since the last flush, a cached Send flushes the cache.
constexpr uint64_t kMaxCacheAge = 1000;

// Process-wide factory chains, created by InitializeSocketLibrary() and destroyed by
// ShutdownSocketLibrary(). Null while the library is not initialized.
std::unique_ptr<ResourceFactoryChain<ISocket, ISocketFactory>> gSocketFactoryChain;
std::unique_ptr<ResourceFactoryChain<IWebSocket, IWebSocketFactory>> gWebSocketFactoryChain;
}  // namespace

// TODO: Consider protecting the socket factory chains below with a reader-writer lock.

// Sends the whole of `buffer` by repeatedly calling the three-argument Send overload, advancing past
// whatever each call reports as written, until every byte is accepted or a call fails.
//
// buffer: data to send (asserted non-null).
// length: number of bytes to send (asserted > 0).
// Returns the error code of the failing call, or the code of the last successful call
// (TTV_EC_SUCCESS if no call was made).
//
// NOTE(review): an overload that reports success while writing zero bytes would make this loop spin.
TTV_ErrorCode ttv::ISocket::Send(const uint8_t* buffer, size_t length) {
  TTV_ASSERT(buffer);
  TTV_ASSERT(length > 0);

  TTV_ErrorCode ec = TTV_EC_SUCCESS;
  const uint8_t* cursor = buffer;
  size_t remaining = length;

  while (remaining > 0) {
    size_t written = 0;
    ec = Send(cursor, remaining, written);

    const bool ok = TTV_SUCCEEDED(ec);
    if (!ok) {
      return ec;
    }

    cursor += written;
    remaining -= written;
  }

  return ec;
}

// Creates the process-wide socket and WebSocket factory chains.
// If the library is already initialized, logs an error, leaves the existing chains untouched and
// returns TTV_EC_ALREADY_INITIALIZED.
TTV_ErrorCode ttv::InitializeSocketLibrary() {
  ttv::trace::Message("Core", MessageLevel::Info, "InitializeSocketLibrary()");

  const bool alreadyInitialized = (gSocketFactoryChain != nullptr);
  if (alreadyInitialized) {
    ttv::trace::Message("Core", MessageLevel::Error, "InitializeSocketLibrary() already initialized");
    return TTV_EC_ALREADY_INITIALIZED;
  }

  using SocketChain = ResourceFactoryChain<ISocket, ISocketFactory>;
  using WebSocketChain = ResourceFactoryChain<IWebSocket, IWebSocketFactory>;

  gSocketFactoryChain = std::make_unique<SocketChain>("ISocketFactory");
  gWebSocketFactoryChain = std::make_unique<WebSocketChain>("IWebSocketFactory");

  return TTV_EC_SUCCESS;
}

// Tears down the socket library by destroying both factory chains created by InitializeSocketLibrary().
// If the library is not initialized, logs an error and returns TTV_EC_NOT_INITIALIZED.
TTV_ErrorCode ttv::ShutdownSocketLibrary() {
  ttv::trace::Message("Core", MessageLevel::Info, "ShutdownSocketLibrary()");

  // Only the socket chain is checked; InitializeSocketLibrary() creates both chains together.
  if (gSocketFactoryChain == nullptr) {
    ttv::trace::Message("Core", MessageLevel::Error, "ShutdownSocketLibrary() not initialized");

    return TTV_EC_NOT_INITIALIZED;
  }

  gSocketFactoryChain.reset();
  gWebSocketFactoryChain.reset();

  // NOTE(review): gOpenSslInitialized / gWinsockInitialized and the platform shutdown helpers below are
  // not declared in this file's visible code and are never set by InitializeSocketLibrary(); they are
  // presumably provided elsewhere when these build flags are enabled — confirm.
#if TTV_USE_OPENSSL
  if (gOpenSslInitialized) {
    gOpenSslInitialized = false;
    ttv::OpenSslSocket::ShutdownOpenSslSockets();
  }
#endif

#if TTV_USE_WINSOCK
  if (gWinsockInitialized) {
    gWinsockInitialized = false;
    ttv::WinSocket::ShutdownWinSock();
  }
#endif

  return TTV_EC_SUCCESS;
}

// Registers `factory` with the global socket factory chain and returns the chain's result.
// Returns TTV_EC_NOT_INITIALIZED if InitializeSocketLibrary() has not been called.
TTV_ErrorCode ttv::RegisterSocketFactory(const std::shared_ptr<ISocketFactory>& factory) {
  ttv::trace::Message("Core", MessageLevel::Debug, "RegisterSocketFactory()");

  TTV_ASSERT(gSocketFactoryChain != nullptr);
  if (gSocketFactoryChain != nullptr) {
    return gSocketFactoryChain->Register(factory);
  }

  ttv::trace::Message(
    "Core", MessageLevel::Error, "ttv::RegisterSocketFactory(): gSocketFactoryChain not initialized");
  return TTV_EC_NOT_INITIALIZED;
}

// Removes `factory` from the global socket factory chain and returns the chain's result.
// Returns TTV_EC_NOT_INITIALIZED if InitializeSocketLibrary() has not been called.
TTV_ErrorCode ttv::UnregisterSocketFactory(const std::shared_ptr<ISocketFactory>& factory) {
  ttv::trace::Message("Core", MessageLevel::Debug, "UnregisterSocketFactory()");

  TTV_ASSERT(gSocketFactoryChain != nullptr);
  if (gSocketFactoryChain != nullptr) {
    return gSocketFactoryChain->Unregister(factory);
  }

  ttv::trace::Message(
    "Core", MessageLevel::Error, "ttv::UnregisterSocketFactory(): gSocketFactoryChain not initialized");
  return TTV_EC_NOT_INITIALIZED;
}

// Queries the socket factory chain (via BoolQuery over IsProtocolSupported) for `protocol`.
// Returns TTV_EC_SUCCESS if the query is true and TTV_EC_UNIMPLEMENTED if it is false.
// Returns TTV_EC_NO_FACTORIES_REGISTERED if the chain is empty, and TTV_EC_NOT_INITIALIZED
// before InitializeSocketLibrary().
TTV_ErrorCode ttv::IsSocketProtocolSupported(const std::string& protocol) {
  TTV_ASSERT(gSocketFactoryChain != nullptr);
  if (gSocketFactoryChain == nullptr) {
    ttv::trace::Message(
      "Core", MessageLevel::Error, "ttv::IsSocketProtocolSupported(): gSocketFactoryChain not initialized");
    return TTV_EC_NOT_INITIALIZED;
  }

  if (gSocketFactoryChain->Empty()) {
    return TTV_EC_NO_FACTORIES_REGISTERED;
  }

  const auto supportsProtocol = [&protocol](const std::shared_ptr<ISocketFactory>& factory) {
    return factory->IsProtocolSupported(protocol);
  };

  return gSocketFactoryChain->BoolQuery(supportsProtocol) ? TTV_EC_SUCCESS : TTV_EC_UNIMPLEMENTED;
}

// Asks the socket factory chain to create a socket for `uri`; each factory is tried through
// ISocketFactory::CreateSocket, and `result` is passed to the chain as the output slot.
// Returns TTV_EC_NOT_INITIALIZED if InitializeSocketLibrary() has not been called.
TTV_ErrorCode ttv::CreateSocket(const std::string& uri, std::shared_ptr<ISocket>& result) {
  ttv::trace::Message("Core", MessageLevel::Debug, "ttv::CreateSocket(): %s", uri.c_str());

  TTV_ASSERT(gSocketFactoryChain != nullptr);
  if (gSocketFactoryChain != nullptr) {
    auto createFromFactory = [&uri](const std::shared_ptr<ISocketFactory>& factory,
                               std::shared_ptr<ISocket>& socket) -> TTV_ErrorCode {
      return factory->CreateSocket(uri, socket);
    };
    return gSocketFactoryChain->Create(createFromFactory, result);
  }

  ttv::trace::Message("Core", MessageLevel::Error, "ttv::CreateSocket(): gSocketFactoryChain not initialized");
  return TTV_EC_NOT_INITIALIZED;
}

// Registers `factory` with the global WebSocket factory chain and returns the chain's result.
// Returns TTV_EC_NOT_INITIALIZED if InitializeSocketLibrary() has not been called.
TTV_ErrorCode ttv::RegisterWebSocketFactory(const std::shared_ptr<IWebSocketFactory>& factory) {
  ttv::trace::Message("Core", MessageLevel::Debug, "ttv::RegisterWebSocketFactory()");

  TTV_ASSERT(gWebSocketFactoryChain != nullptr);
  if (gWebSocketFactoryChain != nullptr) {
    return gWebSocketFactoryChain->Register(factory);
  }

  ttv::trace::Message(
    "Core", MessageLevel::Error, "ttv::RegisterWebSocketFactory(): gWebSocketFactoryChain not initialized");
  return TTV_EC_NOT_INITIALIZED;
}

// Removes `factory` from the global WebSocket factory chain and returns the chain's result.
// Returns TTV_EC_NOT_INITIALIZED if InitializeSocketLibrary() has not been called.
TTV_ErrorCode ttv::UnregisterWebSocketFactory(const std::shared_ptr<IWebSocketFactory>& factory) {
  ttv::trace::Message("Core", MessageLevel::Debug, "ttv::UnregisterWebSocketFactory()");

  TTV_ASSERT(gWebSocketFactoryChain != nullptr);
  if (gWebSocketFactoryChain != nullptr) {
    return gWebSocketFactoryChain->Unregister(factory);
  }

  ttv::trace::Message(
    "Core", MessageLevel::Error, "ttv::UnregisterWebSocketFactory(): gWebSocketFactoryChain not initialized");
  return TTV_EC_NOT_INITIALIZED;
}

// Queries the WebSocket factory chain (via BoolQuery over IsProtocolSupported) for `protocol`.
// Returns TTV_EC_SUCCESS if the query is true and TTV_EC_UNIMPLEMENTED if it is false.
// Returns TTV_EC_NO_FACTORIES_REGISTERED if the chain is empty, and TTV_EC_NOT_INITIALIZED
// before InitializeSocketLibrary().
TTV_ErrorCode ttv::IsWebSocketProtocolSupported(const std::string& protocol) {
  TTV_ASSERT(gWebSocketFactoryChain != nullptr);
  if (gWebSocketFactoryChain == nullptr) {
    ttv::trace::Message(
      "Core", MessageLevel::Error, "ttv::IsWebSocketProtocolSupported(): gWebSocketFactoryChain not initialized");
    return TTV_EC_NOT_INITIALIZED;
  }

  if (gWebSocketFactoryChain->Empty()) {
    return TTV_EC_NO_FACTORIES_REGISTERED;
  }

  const auto supportsProtocol = [&protocol](const std::shared_ptr<IWebSocketFactory>& factory) -> bool {
    return factory->IsProtocolSupported(protocol);
  };

  return gWebSocketFactoryChain->BoolQuery(supportsProtocol) ? TTV_EC_SUCCESS : TTV_EC_UNIMPLEMENTED;
}

// Asks the WebSocket factory chain to create a WebSocket for `uri`; each factory is tried through
// IWebSocketFactory::CreateWebSocket, and `result` is passed to the chain as the output slot.
// Returns TTV_EC_NOT_INITIALIZED if InitializeSocketLibrary() has not been called.
TTV_ErrorCode ttv::CreateWebSocket(const std::string& uri, std::shared_ptr<IWebSocket>& result) {
  ttv::trace::Message("Core", MessageLevel::Debug, "ttv::CreateWebSocket(): %s", uri.c_str());

  TTV_ASSERT(gWebSocketFactoryChain != nullptr);
  if (gWebSocketFactoryChain != nullptr) {
    auto createFromFactory = [&uri](const std::shared_ptr<IWebSocketFactory>& factory,
                               std::shared_ptr<IWebSocket>& socket) -> TTV_ErrorCode {
      return factory->CreateWebSocket(uri, socket);
    };
    return gWebSocketFactoryChain->Create(createFromFactory, result);
  }

  ttv::trace::Message("Core", MessageLevel::Error, "ttv::CreateWebSocket(): gWebSocketFactoryChain not initialized");
  return TTV_EC_NOT_INITIALIZED;
}

ttv::BufferedSocket::BufferedSocket() : mLastFlushTime(0), mCachePos(0), mBlocking(false) {}

ttv::BufferedSocket::~BufferedSocket() {}

// Attaches `socket` as the underlying transport. If a socket was already bound, any cached bytes
// are first flushed to it; the flush result is ignored.
// NOTE(review): if that flush fails, the cached bytes remain and will later be sent on the new socket.
void ttv::BufferedSocket::Bind(const std::shared_ptr<ISocket>& socket) {
  const bool hadSocket = (mSocket != nullptr);
  if (hadSocket) {
    FlushCache();
  }

  mSocket = socket;
}

// Connects the bound socket and, if that succeeds, resets the send-statistics tracker.
// Returns TTV_EC_NOT_INITIALIZED when no socket has been bound (see Bind()). This mirrors
// Disconnect(), which also tolerates an unbound socket.
TTV_ErrorCode ttv::BufferedSocket::Connect() {
  if (mSocket == nullptr) {
    return TTV_EC_NOT_INITIALIZED;
  }

  TTV_ErrorCode ec = mSocket->Connect();

  if (TTV_SUCCEEDED(ec)) {
    mTracker.Reset();
  }

  return ec;
}

// Disconnects the bound socket and returns its result; trivially succeeds when no socket is bound.
TTV_ErrorCode ttv::BufferedSocket::Disconnect() {
  return (mSocket != nullptr) ? mSocket->Disconnect() : TTV_EC_SUCCESS;
}

// Sends `length` bytes from `buffer`, optionally coalescing small writes in mCache.
//
// - If `cache` is false, or the payload is larger than the whole cache, any cached bytes are
//   flushed and the payload is sent immediately.
// - Otherwise the payload is appended to the cache. If it does not fit, the cache is filled to
//   capacity and flushed, and the remainder starts a new cache.
// - Finally, if more than kMaxCacheAge milliseconds have passed since the last successful flush,
//   the cache is flushed.
//
// Returns the first failure encountered. A later age-based flush never masks an earlier failure
// (which would otherwise report dropped bytes as sent).
TTV_ErrorCode ttv::BufferedSocket::Send(const uint8_t* buffer, size_t length, bool cache) {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // Don't cache, or the payload is larger than the cache can hold: flush and send immediately.
  if (length > mCache.size() || !cache) {
    ec = FlushCache();

    if (TTV_SUCCEEDED(ec)) {
      ec = DoSend(buffer, length);
    }

    return ec;
  }

  // Fill the cache as much as possible.
  auto maxCache = std::min(length, mCache.size() - mCachePos);
  if (maxCache > 0) {
    memcpy(&mCache[mCachePos], buffer, maxCache);
    mCachePos += maxCache;
    buffer += maxCache;
    length -= static_cast<size_t>(maxCache);
  }

  // Still more to send: the cache is full, so flush it and start a new cache with the remainder.
  // The remainder fits because the payload as a whole is no larger than the cache.
  if (length > 0) {
    ec = FlushCache();

    const bool flushed = TTV_SUCCEEDED(ec);
    if (!flushed) {
      // The remainder was not cached, so report the failure now instead of letting the age-based
      // flush below overwrite it. Note that the first part of the payload is already in the cache.
      return ec;
    }

    memcpy(&mCache[0], buffer, length);
    mCachePos = length;
  }

  // Flush the cache if bytes have been held locally for long enough.
  if (mLastFlushTime + MsToSystemTime(kMaxCacheAge) < GetSystemClockTime()) {
    ec = FlushCache();
  }

  return ec;
}

// Sends every byte currently held in the cache. On success the cache is emptied and mLastFlushTime
// is set to the current system clock time. On failure the cached bytes are kept and the error is
// returned. An empty cache is a successful no-op.
TTV_ErrorCode ttv::BufferedSocket::FlushCache() {
  if (mCachePos <= 0) {
    return TTV_EC_SUCCESS;
  }

  const TTV_ErrorCode ec = DoSend(&mCache[0], static_cast<size_t>(mCachePos));

  const bool sent = TTV_SUCCEEDED(ec);
  if (sent) {
    mCachePos = 0;
    mLastFlushTime = GetSystemClockTime();
  }

  return ec;
}

// Writes `length` bytes straight to the bound socket. On success, records the byte count, start
// time and elapsed system-clock time with the send tracker. Returns the socket's result.
TTV_ErrorCode ttv::BufferedSocket::DoSend(const uint8_t* buffer, size_t length) {
  const uint64_t sendStart = GetSystemClockTime();
  const TTV_ErrorCode ec = mSocket->Send(buffer, length);

  const bool ok = TTV_SUCCEEDED(ec);
  if (!ok) {
    return ec;
  }

  const uint64_t sendDuration = GetSystemClockTime() - sendStart;
  mTracker.AddSendInfo(static_cast<uint32_t>(length), sendStart, sendDuration);

  return ec;
}

// Reports the send tracker's average outgoing rate over the last `measurementWindowMilliseconds`
// in `bitsPerSecond`; returns the tracker's result.
TTV_ErrorCode ttv::BufferedSocket::GetAverageSendBitRate(
  uint64_t measurementWindowMilliseconds, uint64_t& bitsPerSecond) const {
  const TTV_ErrorCode ec = mTracker.GetAverageOutgoingRate(measurementWindowMilliseconds, bitsPerSecond);
  return ec;
}

// Reports the send tracker's estimated congestion level over the last
// `measurementWindowMilliseconds` in `congestionLevel`; returns the tracker's result.
TTV_ErrorCode ttv::BufferedSocket::GetCongestionLevel(
  uint64_t measurementWindowMilliseconds, double& congestionLevel) const {
  const TTV_ErrorCode ec = mTracker.GetEstimatedCongestionLevel(measurementWindowMilliseconds, congestionLevel);
  return ec;
}

// Receives up to `length` bytes into `buffer`; `received` reports how many bytes were stored.
//
// Non-blocking mode (the default): performs a single read of the bound socket.
// Blocking mode (SetBlockingMode(true)): keeps reading until `length` bytes have arrived. Reads that
// make progress are retried immediately. Reads that make no progress (TTV_EC_SOCKET_EWOULDBLOCK, or
// a successful zero-byte read) yield with Sleep(0) and are retried only while no more than
// `maxWaitForBufferFill` milliseconds have elapsed since the call started.
//
// Returns the result of the last read, so a call that runs out of time can return with
// `received` < `length`. Returns TTV_EC_SOCKET_ENOTCONN if no socket is bound or it is not connected.
TTV_ErrorCode ttv::BufferedSocket::Recv(
  uint8_t* buffer, size_t length, size_t& received, uint64_t maxWaitForBufferFill) {
  received = 0;

  if (mSocket == nullptr || !Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  const uint64_t startTime = GetSystemTimeMilliseconds();
  const bool block = mBlocking;
  bool done = false;

  do {
    done = true;
    bool madeNoProgress = false;

    size_t got = 0;
    ec = mSocket->Recv(buffer, length, got);

    if (TTV_SUCCEEDED(ec)) {
      buffer += got;
      received += got;
      length -= got;

      if (block && length > 0) {
        if (got > 0) {
          // Progress was made; read again immediately.
          done = false;
        } else {
          // A successful zero-byte read would otherwise spin forever without consulting the wait
          // budget, so treat it like EWOULDBLOCK.
          madeNoProgress = true;
        }
      }
    } else if (ec == TTV_EC_SOCKET_EWOULDBLOCK && block) {
      madeNoProgress = true;
    }

    if (madeNoProgress) {
      uint64_t elapsed = GetSystemTimeMilliseconds() - startTime;

      if (elapsed <= maxWaitForBufferFill) {
        done = false;
        Sleep(0);
      }
    }
  } while (!done);

  return ec;
}

// Chooses how Recv behaves: true keeps reading until the buffer is full (bounded by the caller's
// wait budget), false performs a single read. Always returns TTV_EC_SUCCESS.
TTV_ErrorCode ttv::BufferedSocket::SetBlockingMode(bool blockingMode) {
  mBlocking = blockingMode;
  return TTV_EC_SUCCESS;
}

// Total bytes sent, as reported by the bound socket's TotalSent(). Bytes still waiting in the
// cache are not included. Returns 0 when no socket is bound, instead of dereferencing null
// (consistent with Disconnect()).
uint64_t ttv::BufferedSocket::TotalSent() {
  return (mSocket != nullptr) ? mSocket->TotalSent() : 0;
}

// Total bytes received, as reported by the bound socket's TotalReceived(). Returns 0 when no
// socket is bound, instead of dereferencing null (consistent with Disconnect()).
uint64_t ttv::BufferedSocket::TotalReceived() {
  return (mSocket != nullptr) ? mSocket->TotalReceived() : 0;
}

// True if a socket is bound and it reports itself connected. An unbound BufferedSocket is simply
// "not connected" rather than a null dereference, which also lets Recv() report
// TTV_EC_SOCKET_ENOTCONN safely.
bool ttv::BufferedSocket::Connected() {
  return mSocket != nullptr && mSocket->Connected();
}
