/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/winsocket.h"

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/httprequestutils.h"
#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/types/errortypes.h"

#define _WINSOCKAPI_
#include <windows.h>

#include <mstcpip.h>
#include <winsock2.h>
#include <ws2tcpip.h>

#pragma comment(lib, "ws2_32.lib")
#pragma comment(lib, "fwpuclnt.lib")

// This is here for symmetry since WinSock doesn't have a function for this like it does for all others
#define WSACloseSocket(SOCKET) closesocket(SOCKET)

namespace {
/**
 * Maps a Winsock (WSA) error value onto the corresponding core TTV_ErrorCode.
 *
 * @param ec A WSA error value, such as the result of WSAGetLastError().
 * @return The dedicated TTV_EC_SOCKET_* code for the WSA errors tested below. Every other value, including 0,
 *         is reported as the generic TTV_EC_SOCKET_ERR.
 */
TTV_ErrorCode ConvertToErrorCode(int ec) {
  if (ec == WSAEWOULDBLOCK) {
    return TTV_EC_SOCKET_EWOULDBLOCK;
  }
  if (ec == WSAENOTCONN) {
    return TTV_EC_SOCKET_ENOTCONN;
  }
  if (ec == WSATRY_AGAIN) {
    return TTV_EC_SOCKET_TRY_AGAIN;
  }
  if (ec == WSAECONNABORTED) {
    return TTV_EC_SOCKET_ECONNABORTED;
  }
  if (ec == WSAEALREADY) {
    return TTV_EC_SOCKET_EALREADY;
  }
  if (ec == WSAETIMEDOUT) {
    return TTV_EC_SOCKET_ETIMEDOUT;
  }
  if (ec == WSAECONNRESET) {
    return TTV_EC_SOCKET_ECONNRESET;
  }

  // No dedicated mapping exists for this value.
  return TTV_EC_SOCKET_ERR;
}
}  // namespace

namespace ttv {
// Filled in by WSAStartup() with details of the Winsock implementation.
WSADATA nWsaData;

// NOTE(review): neither constant is referenced in this file. 0x10000 == 65536; the units of both values
// (presumably bytes and milliseconds) are not established here - confirm before relying on them.
const int kSocketBufferSize = 0x10000;
const uint64_t kMaxCacheAge = 1000;

// Deleter that hands a getaddrinfo() result list back to freeaddrinfo(); a null list is ignored.
static auto deleter = [](addrinfo* p) {
  if (p) {
    freeaddrinfo(p);
  }
};

/**
 * Resolves a host/port pair into a list of TCP stream addresses via getaddrinfo().
 *
 * Any address family is accepted (AF_UNSPEC); the socket type and protocol are restricted to
 * SOCK_STREAM / IPPROTO_TCP.
 *
 * @param host Host name or address string passed to getaddrinfo().
 * @param port Service/port string passed to getaddrinfo().
 * @return An owning pointer to the head of the resolved list, or an empty pointer when getaddrinfo()
 *         reports a failure. The list is released with freeaddrinfo() when the pointer is destroyed.
 */
std::unique_ptr<addrinfo, decltype(deleter)> GetAddresInfo(const std::string& host, const std::string& port) {
  using AddrInfoPtr = std::unique_ptr<addrinfo, decltype(deleter)>;

  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_protocol = IPPROTO_TCP;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_family = AF_UNSPEC;

  addrinfo* resolved = nullptr;
  const int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &resolved);
  if (status != 0) {
    // Resolution failed: hand back an empty owner.
    return AddrInfoPtr(nullptr, deleter);
  }

  return AddrInfoPtr(resolved, deleter);
}
}  // namespace ttv

/**
 * Starts up Winsock, requesting version 2.2.
 *
 * @return TTV_EC_SUCCESS when WSAStartup() succeeds, otherwise the converted startup error.
 */
TTV_ErrorCode ttv::WinSocket::InitializeWinSock() {
  // WSAStartup() reports its failure reason through its return value. WSAGetLastError() must not be
  // consulted here: Winsock may not have been loaded, so the "last error" would be stale or unrelated.
  const int startupError = WSAStartup(MAKEWORD(2, 2), &nWsaData);
  if (startupError != NO_ERROR) {
    return ConvertToErrorCode(startupError);
  }
  return TTV_EC_SUCCESS;
}

/**
 * Calls WSACleanup().
 *
 * @return TTV_EC_SUCCESS when WSACleanup() succeeds, otherwise the converted WSAGetLastError() value.
 */
TTV_ErrorCode ttv::WinSocket::ShutdownWinSock() {
  if (WSACleanup() == NO_ERROR) {
    return TTV_EC_SUCCESS;
  }

  return ConvertToErrorCode(WSAGetLastError());
}

/**
 * Creates an unconnected socket object with zeroed counters and starts up Winsock.
 */
ttv::WinSocket::WinSocket()
    : mSocket(INVALID_SOCKET), mLastSocketError(0), mTotalSent(0), mTotalReceived(0) {
  // The startup result is intentionally discarded; a constructor has no way to return it.
  static_cast<void>(InitializeWinSock());
}

/**
 * Remembers the remote endpoint that a later Connect() call will use. No network activity happens here.
 *
 * @param host Host name or address to store.
 * @param port Port/service string to store.
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::WinSocket::Initialize(const std::string& host, const std::string& port) {
  const char* const hostText = host.c_str();
  const char* const portText = port.c_str();
  ttv::trace::Message("Socket", MessageLevel::Debug, "WinSocket: Creating WinSocket for %s:%s", hostText, portText);

  mHost = host;
  mPort = port;

  return TTV_EC_SUCCESS;
}

/**
 * Closes the socket (if any) and then calls ShutdownWinSock(), pairing with the InitializeWinSock()
 * call made by the constructor. Both results are ignored.
 */
ttv::WinSocket::~WinSocket() {
  static_cast<void>(Disconnect());
  static_cast<void>(ShutdownWinSock());
}

/**
 * Resolves the stored host/port and connects to the first address returned, leaving the socket in
 * blocking mode.
 *
 * The last-error value and the sent/received counters are reset at the start of every attempt.
 *
 * @return TTV_EC_SUCCESS on success; TTV_EC_SOCKET_EALREADY when a socket is already open; otherwise the
 *         converted Winsock error of the step that failed. On any failure the socket is closed, so
 *         Connected() reports false afterwards.
 */
TTV_ErrorCode ttv::WinSocket::Connect() {
  TTV_ASSERT(!Connected());
  if (Connected()) {
    return TTV_EC_SOCKET_EALREADY;
  }

  TTV_ErrorCode ret = TTV_EC_SUCCESS;

  mLastSocketError = 0;
  mTotalSent = 0;
  mTotalReceived = 0;

  auto remoteHost = GetAddresInfo(mHost, mPort);

  if (remoteHost.get() == nullptr) {
    mLastSocketError = WSAGetLastError();
    ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to getaddrinfo. Error = %d", mLastSocketError);
    ret = ConvertToErrorCode(mLastSocketError);
  }

  if (TTV_SUCCEEDED(ret)) {
    mSocket = WSASocketW(remoteHost->ai_family, remoteHost->ai_socktype, remoteHost->ai_protocol, nullptr, 0, 0);
    if (INVALID_SOCKET == mSocket) {
      mLastSocketError = WSAGetLastError();
      ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to WSASocket. Error = %d", mLastSocketError);
      ret = ConvertToErrorCode(mLastSocketError);
    }
  }

  if (TTV_SUCCEEDED(ret)) {
    const int res = WSAConnect(
      mSocket, remoteHost->ai_addr, static_cast<int>(remoteHost->ai_addrlen), nullptr, nullptr, nullptr, nullptr);
    if (res == SOCKET_ERROR) {
      mLastSocketError = WSAGetLastError();
      ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to WSAConnect. Error = %d", mLastSocketError);
      ret = ConvertToErrorCode(mLastSocketError);
    }
  }

  if (TTV_SUCCEEDED(ret)) {
    ret = SetBlockingMode(true);
  }

  // Clean up after ANY failed step, including SetBlockingMode(). Otherwise the call would report an
  // error while Connected() still returned true, and the next Connect() would fail with EALREADY.
  if (!TTV_SUCCEEDED(ret)) {
    (void)Disconnect();
  }

  return ret;
}

/**
 * Closes the socket if one is open and marks this object as disconnected.
 *
 * A closesocket() failure is recorded in mLastSocketError but is not reported to the caller.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::WinSocket::Disconnect() {
  // INVALID_SOCKET is the only "no socket" sentinel. A handle value of 0 is a legal Winsock handle, so
  // the handle must not be tested for truthiness or such a socket would never be closed.
  if (INVALID_SOCKET != mSocket) {
    if (WSACloseSocket(mSocket) != 0) {
      mLastSocketError = WSAGetLastError();
    }

    mSocket = INVALID_SOCKET;
  }

  return TTV_EC_SUCCESS;
}

/**
 * Creates a socket, binds it to the first address resolved for host/port and starts listening with a
 * backlog of SOMAXCONN.
 *
 * If this object already owns a socket, that socket is closed once the address has been resolved, so it
 * is not leaked when the handle is replaced.
 *
 * @param host Local host name or address to bind to.
 * @param port Local port/service string to bind to.
 * @return TTV_EC_SUCCESS on success, otherwise the converted Winsock error of the step that failed. After
 *         a socket/bind/listen failure no socket is left open.
 */
TTV_ErrorCode ttv::WinSocket::TCPListen(const std::string& host, const std::string& port) {
  auto hostInfo = GetAddresInfo(host, port);

  if (hostInfo.get() == nullptr) {
    mLastSocketError = WSAGetLastError();
    ttv::trace::Message(
      "Socket", MessageLevel::Error, "Failed in call to GetAddressInfo. Error = %d", mLastSocketError);
    return ConvertToErrorCode(mLastSocketError);
  }

  // Release a previously opened handle before overwriting it below.
  (void)Disconnect();

  mSocket = socket(hostInfo->ai_family, hostInfo->ai_socktype, hostInfo->ai_protocol);
  if (INVALID_SOCKET == mSocket) {
    mLastSocketError = WSAGetLastError();
    ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to socket. Error = %d", mLastSocketError);
    return ConvertToErrorCode(mLastSocketError);
  }

  if (bind(mSocket, hostInfo->ai_addr, static_cast<int>(hostInfo->ai_addrlen)) == SOCKET_ERROR) {
    // Keep the bind error in a local: Disconnect() may overwrite mLastSocketError if closing fails.
    const int bindError = WSAGetLastError();
    mLastSocketError = bindError;
    ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to bind. Error = %d", bindError);
    (void)Disconnect();
    return ConvertToErrorCode(bindError);
  }

  if (listen(mSocket, SOMAXCONN) == SOCKET_ERROR) {
    const int listenError = WSAGetLastError();
    mLastSocketError = listenError;
    ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to listen. Error = %d", listenError);
    (void)Disconnect();
    return ConvertToErrorCode(listenError);
  }

  return TTV_EC_SUCCESS;
}

/**
 * Accepts one pending connection on this socket and wraps it in a new WinSocket.
 *
 * @param newSocket Receives the accepted connection on success; it is reset to empty on entry and stays
 *                  empty on failure.
 * @return TTV_EC_SUCCESS on success; TTV_EC_NOT_INITIALIZED when this object holds no socket; otherwise
 *         the converted Winsock error reported for accept().
 */
TTV_ErrorCode ttv::WinSocket::AcceptConnection(std::shared_ptr<ISocket>& newSocket) {
  newSocket.reset();

  // "No socket" is represented by INVALID_SOCKET (a non-zero value), so testing the handle for
  // truthiness would never detect it.
  if (!Connected()) {
    return TTV_EC_NOT_INITIALIZED;
  }

  const SOCKET acceptResult = accept(mSocket, nullptr, nullptr);

  if (acceptResult == INVALID_SOCKET) {
    mLastSocketError = WSAGetLastError();
    ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to accept. Error = %d", mLastSocketError);
    return ConvertToErrorCode(mLastSocketError);
  }

  // Named to avoid shadowing the ::socket() function.
  std::shared_ptr<WinSocket> accepted = std::make_shared<WinSocket>();
  accepted->mSocket = acceptResult;
  newSocket = accepted;

  return TTV_EC_SUCCESS;
}

/**
 * Sends up to `length` bytes from `buffer` with a single WSASend() call.
 *
 * @param buffer Data to send; must not be null.
 * @param length Number of bytes available in `buffer`; expected to be greater than zero.
 * @param sent   Receives the number of bytes WSASend() reported as sent (0 on failure).
 * @return TTV_EC_SUCCESS on success; TTV_EC_SOCKET_ENOTCONN when no socket is open (matching Recv());
 *         otherwise the converted Winsock error. When SOCKET_FAILED(ec) holds, the socket is closed.
 */
TTV_ErrorCode ttv::WinSocket::Send(const uint8_t* buffer, size_t length, size_t& sent) {
  TTV_ASSERT(buffer);
  TTV_ASSERT(length > 0);

  sent = 0;

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // Same guard as Recv(): report "not connected" instead of handing INVALID_SOCKET to Winsock and
  // logging a generic failure.
  if (!Connected()) {
    ec = TTV_EC_SOCKET_ENOTCONN;
  }

  if (TTV_SUCCEEDED(ec)) {
    WSABUF wsaBuffer;
    // WSABUF::buf is not const-qualified even though WSASend() only reads from it.
    wsaBuffer.buf = const_cast<CHAR*>(reinterpret_cast<const CHAR*>(buffer));
    wsaBuffer.len = static_cast<ULONG>(length);
    DWORD dwSent = 0;

    const int res = WSASend(mSocket, &wsaBuffer, 1, &dwSent, 0, nullptr, nullptr);

    if (res == 0) {
      sent = static_cast<size_t>(dwSent);
    } else {
      mLastSocketError = WSAGetLastError();
      ttv::trace::Message("Socket", MessageLevel::Error, "Error Sending from a socket. Error = %d", mLastSocketError);
      ec = ConvertToErrorCode(mLastSocketError);
    }
  }

  mTotalSent += static_cast<uint64_t>(sent);

  if (SOCKET_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

/**
 * Reads whatever data is immediately available, without waiting for more to arrive.
 *
 * The socket is first polled with a zero timeout; WSARecv() is only called when readable data is
 * reported.
 *
 * @param buffer   Destination buffer; must not be null.
 * @param length   Capacity of `buffer` in bytes.
 * @param received Receives the number of bytes read (0 when nothing was read).
 * @return TTV_EC_SUCCESS when data was read; TTV_EC_SOCKET_EWOULDBLOCK when nothing is available yet;
 *         TTV_EC_SOCKET_ENOTCONN when no socket is open; TTV_EC_SOCKET_ECONNABORTED when the peer closed
 *         the connection or the poll reported an error/hang-up; otherwise the converted Winsock error.
 *         When SOCKET_FAILED(ec) holds, the socket is closed.
 */
TTV_ErrorCode ttv::WinSocket::Recv(uint8_t* buffer, size_t length, size_t& received) {
  TTV_ASSERT(buffer);

  received = 0;

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  if (!Connected()) {
    ec = TTV_EC_SOCKET_ENOTCONN;
  }

  if (TTV_SUCCEEDED(ec)) {
    // Check to see if any data is available before trying to read
    WSAPOLLFD status;
    memset(&status, 0, sizeof(status));
    status.fd = mSocket;
    status.events = POLLRDNORM;

    const int statusResult = WSAPoll(&status, 1, 0);

    if (statusResult == 0) {
      // No data available
      ec = TTV_EC_SOCKET_EWOULDBLOCK;
    } else if (statusResult < 0) {
      // The poll itself failed, abort
      mLastSocketError = WSAGetLastError();
      ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to WSAPoll. Error = %d", mLastSocketError);
      ec = TTV_EC_SOCKET_ECONNABORTED;
    } else if ((status.revents & POLLRDNORM) != 0) {
      // revents is a bit mask: POLLRDNORM may arrive together with other flags such as POLLHUP, so it is
      // tested with a mask rather than compared for equality. Pending data is drained first; a later
      // zero-byte read then reports the closed connection.
      WSABUF wsaBuffer;
      wsaBuffer.buf = reinterpret_cast<CHAR*>(buffer);
      wsaBuffer.len = static_cast<ULONG>(length);

      DWORD got = 0;
      DWORD flags = 0;
      const int res = WSARecv(mSocket, &wsaBuffer, 1, &got, &flags, nullptr, nullptr);
      if (res == 0) {
        // If the other end drops the connection we get 0 back as a result
        if (got == 0) {
          ec = TTV_EC_SOCKET_ECONNABORTED;
        }
        // Got some data
        else {
          received += static_cast<size_t>(got);
          mTotalReceived += static_cast<size_t>(got);
        }
      } else {
        mLastSocketError = WSAGetLastError();

        if (WSAEWOULDBLOCK == mLastSocketError) {
          ec = ConvertToErrorCode(mLastSocketError);
        } else {
          TTV_ASSERT(mLastSocketError);
          ttv::trace::Message(
            "Socket", MessageLevel::Error, "Error Receiving from a socket. Error = %d", mLastSocketError);
          ec = ConvertToErrorCode(mLastSocketError);
        }
      }
    } else if ((status.revents & (POLLERR | POLLHUP | POLLNVAL)) != 0) {
      // Nothing left to read and the socket is in an error / hang-up state: the connection is gone.
      // Reporting EWOULDBLOCK here would make callers wait forever on a dead connection.
      ec = TTV_EC_SOCKET_ECONNABORTED;
    } else {
      ec = TTV_EC_SOCKET_EWOULDBLOCK;
    }
  }

  if (SOCKET_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

/**
 * Switches the socket between blocking and non-blocking I/O via ioctlsocket(FIONBIO).
 *
 * @param block true selects blocking mode, false selects non-blocking mode.
 * @return TTV_EC_SUCCESS on success, otherwise the converted WSAGetLastError() value (also stored in
 *         mLastSocketError).
 */
TTV_ErrorCode ttv::WinSocket::SetBlockingMode(bool block) {
  TTV_ASSERT(mSocket != INVALID_SOCKET);

  // FIONBIO takes a "non-blocking enabled" flag, i.e. the inverse of `block`.
  u_long nonBlockingFlag = block ? 0 : 1;
  if (ioctlsocket(mSocket, FIONBIO, &nonBlockingFlag) == 0) {
    return TTV_EC_SUCCESS;
  }

  mLastSocketError = WSAGetLastError();
  ttv::trace::Message("Socket", MessageLevel::Error, "Error Changing blocking mode. Error = %d", mLastSocketError);
  return ConvertToErrorCode(mLastSocketError);
}

/**
 * Reports whether this object currently holds a socket handle.
 *
 * Only the handle is checked against INVALID_SOCKET; the peer is not probed, so a listening socket or a
 * connection already dropped by the remote end still counts as "connected" here.
 */
bool ttv::WinSocket::Connected() {
  return INVALID_SOCKET != mSocket;
}

/**
 * Returns the running total accumulated by Send() (the sum of the byte counts it reported as sent).
 * Connect() resets this counter to zero.
 */
uint64_t ttv::WinSocket::TotalSent() {
  return mTotalSent;
}

/**
 * Returns the running total accumulated by Recv() (the sum of the byte counts it read).
 * Connect() resets this counter to zero.
 */
uint64_t ttv::WinSocket::TotalReceived() {
  return mTotalReceived;
}

// Nothing to release; the destructor is defined out of line in this translation unit.
ttv::WinSocketFactory::~WinSocketFactory() = default;

/**
 * Tells whether this factory can create sockets for the given protocol name.
 *
 * @param protocol Protocol/scheme name to test.
 * @return true for an empty protocol or exactly "tcp" (case-sensitive); false otherwise.
 */
bool ttv::WinSocketFactory::IsProtocolSupported(const std::string& protocol) {
  if (protocol.empty()) {
    return true;
  }
  return protocol == "tcp";
}

/**
 * Creates an unconnected WinSocket for the endpoint named by `uri`.
 *
 * @param uri    URI whose protocol must be accepted by IsProtocolSupported(); its host name and port are
 *               stored on the new socket.
 * @param result Receives the new socket on success; it is reset to empty on entry and stays empty on
 *               failure.
 * @return TTV_EC_SUCCESS on success; TTV_EC_UNIMPLEMENTED for an unsupported protocol; otherwise the error
 *         returned by WinSocket::Initialize().
 */
TTV_ErrorCode ttv::WinSocketFactory::CreateSocket(const std::string& uri, std::shared_ptr<ttv::ISocket>& result) {
  result.reset();

  Uri url(uri);

  // Reuse the factory's own protocol test so the two checks cannot drift apart.
  if (!IsProtocolSupported(url.GetProtocol())) {
    return TTV_EC_UNIMPLEMENTED;
  }

  std::shared_ptr<WinSocket> socket = std::make_shared<WinSocket>();

  // Report the real initialization error instead of masking it as "unimplemented".
  const TTV_ErrorCode ec = socket->Initialize(url.GetHostName(), url.GetPort());
  if (TTV_SUCCEEDED(ec)) {
    result = socket;
  }

  return ec;
}
