/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/standardsocket.h"

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/httprequestutils.h"
#include "twitchsdk/core/stringutilities.h"

#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/errno.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <unistd.h>

#define SOCKET_ERROR -1
#define INVALID_SOCKET -1

/**
 * Records the endpoint to connect to. No socket is created here; that happens in Connect().
 *
 * @param host Host name handed to getaddrinfo() when connecting.
 * @param port Port number; stored in decimal text form because getaddrinfo() takes the service as a string.
 */
ttv::StandardSocket::StandardSocket(const std::string& host, uint32_t port)
    : mHostName(host),
      mSocket(0),
      mLastSocketError(0),
      mTotalSent(0),
      mTotalRecieved(0) {
  // 64 bytes is far more than the longest decimal rendering of a 32-bit value needs.
  char portText[64];
  (void)snprintf(portText, sizeof(portText), "%u", port);
  mPort = portText;
}

/**
 * Tears down any open socket. The result of Disconnect() is deliberately discarded
 * because a destructor has no way to report it.
 */
ttv::StandardSocket::~StandardSocket() {
  static_cast<void>(Disconnect());
}

/**
 * Resolves mHostName / mPort (IPv4, TCP) and connects to the first address returned by getaddrinfo().
 *
 * The last-error value and the sent/received byte counters are reset before the attempt.
 *
 * @return TTV_EC_SUCCESS when the socket is connected;
 *         TTV_EC_SOCKET_EALREADY when a socket is already open;
 *         TTV_EC_SOCKET_GETADDRINFO_FAILED, TTV_EC_SOCKET_CREATE_FAILED or TTV_EC_SOCKET_CONNECT_FAILED
 *         for the step that failed. On any failure mSocket is left at 0, so Connected() reports false.
 */
TTV_ErrorCode ttv::StandardSocket::Connect() {
  TTV_ASSERT(!Connected());
  if (Connected()) {
    return TTV_EC_SOCKET_EALREADY;
  }

  mLastSocketError = 0;
  mTotalSent = 0;
  mTotalRecieved = 0;

  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;

  addrinfo* rawRemoteHost = nullptr;
  const int addrInfoRet = getaddrinfo(mHostName.c_str(), mPort.c_str(), &hints, &rawRemoteHost);
  // Capture errno immediately: it is only meaningful when getaddrinfo() returns EAI_SYSTEM,
  // and later calls could overwrite it.
  const int resolveErrno = errno;

  // Owns the result list so it is freed on every exit path.
  auto deleter = [](addrinfo* p) {
    if (p) {
      freeaddrinfo(p);
    }
  };
  std::unique_ptr<addrinfo, decltype(deleter)> remoteHost(rawRemoteHost, deleter);

  TTV_ErrorCode ret = TTV_EC_SUCCESS;

  if (addrInfoRet != 0 || remoteHost.get() == nullptr) {
    // getaddrinfo() reports failures through its return value (an EAI_* code), not through errno.
    mLastSocketError = (addrInfoRet == EAI_SYSTEM) ? resolveErrno : addrInfoRet;
    ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to getaddrinfo. Error = %d (%s)", addrInfoRet,
      gai_strerror(addrInfoRet));
    ret = TTV_EC_SOCKET_GETADDRINFO_FAILED;
  }

  if (TTV_SUCCEEDED(ret)) {
    mSocket = socket(remoteHost->ai_family, remoteHost->ai_socktype, remoteHost->ai_protocol);
    if (INVALID_SOCKET == mSocket) {
      mLastSocketError = errno;
      ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to socket. Error = %d", mLastSocketError);
      ret = TTV_EC_SOCKET_CREATE_FAILED;

      // No descriptor was created; restore the "not connected" sentinel used by Connected().
      mSocket = 0;
    }
  }

  if (TTV_SUCCEEDED(ret)) {
    auto res = connect(mSocket, remoteHost->ai_addr, static_cast<socklen_t>(remoteHost->ai_addrlen));
    if (SOCKET_ERROR == res) {
      mLastSocketError = errno;
      ttv::trace::Message("Socket", MessageLevel::Error, "Failed in call to connect. Error = %d", mLastSocketError);
      ret = TTV_EC_SOCKET_CONNECT_FAILED;

      // The descriptor was never connected, so there is nothing to shut down; release it so it does not leak.
      (void)::close(mSocket);
      mSocket = 0;
    }
  }

  if (TTV_SUCCEEDED(ret)) {
    // Best effort: SetBlockingMode() records and logs its own failure, and a new socket is blocking by default.
    (void)SetBlockingMode(true);
  }

  // Report the real outcome so callers can tell a failed connection from a successful one.
  return ret;
}

/**
 * Shuts down and releases the socket if one is open, then resets mSocket to the
 * "not connected" value (0). Calling it when no socket is open is a no-op.
 *
 * @return Always TTV_EC_SUCCESS; failures of shutdown()/close() are ignored because the
 *         descriptor is abandoned either way.
 */
TTV_ErrorCode ttv::StandardSocket::Disconnect() {
  if (mSocket) {
    // shutdown() only stops traffic in both directions; it does not free the descriptor.
    (void)shutdown(mSocket, SHUT_RDWR);

    // close() is what actually returns the descriptor to the OS. Without it every
    // connect/disconnect cycle leaks one file descriptor.
    (void)::close(mSocket);

    mSocket = 0;
  }

  return TTV_EC_SUCCESS;
}

/**
 * Writes up to `length` bytes from `buffer` to the connected socket.
 *
 * @param buffer Data to send; must not be null.
 * @param length Number of bytes available in `buffer`; expected to be greater than zero.
 * @param sent   Set to the number of bytes actually accepted by send() (0 on failure). This may be
 *               less than `length`.
 * @return TTV_EC_SUCCESS on success; TTV_EC_SOCKET_ENOTCONN when no socket is open;
 *         TTV_EC_SOCKET_SEND_ERROR when send() fails, in which case the socket is disconnected.
 */
TTV_ErrorCode ttv::StandardSocket::Send(const uint8_t* buffer, size_t length, size_t& sent) {
  TTV_ASSERT(buffer);
  TTV_ASSERT(length > 0);

  sent = 0;

  // Same guard as Recv(): without it a disconnected socket (mSocket == 0) would write to descriptor 0.
  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // NOTE(review): with flags == 0, send() on a connection closed by the peer can raise SIGPIPE;
  // confirm the host process ignores or handles that signal.
  auto socketSent = send(mSocket, reinterpret_cast<const char*>(buffer), length, 0);
  if (socketSent < 0) {
    mLastSocketError = errno;
    TTV_ASSERT(mLastSocketError);
    ttv::trace::Message("Socket", MessageLevel::Error, "Error Sending from a socket. Error = %d", mLastSocketError);
    ec = TTV_EC_SOCKET_SEND_ERROR;
  } else {
    sent = static_cast<size_t>(socketSent);
    mTotalSent += static_cast<uint64_t>(socketSent);
  }

  if (TTV_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

/**
 * Reads whatever data is currently available on the socket without waiting for more.
 *
 * @param buffer   Destination buffer; must not be null.
 * @param length   Capacity of `buffer` in bytes.
 * @param received Set to the number of bytes read (0 when nothing was read).
 * @return TTV_EC_SUCCESS when data was read (or `length` is 0);
 *         TTV_EC_SOCKET_ENOTCONN when no socket is open;
 *         TTV_EC_SOCKET_EWOULDBLOCK when no data is available right now;
 *         TTV_EC_SOCKET_ECONNABORTED when the peer closed the connection;
 *         TTV_EC_SOCKET_RECV_ERROR for any other recv() failure.
 *         The socket is disconnected whenever SOCKET_FAILED(ec) holds for the returned code.
 */
TTV_ErrorCode ttv::StandardSocket::Recv(uint8_t* buffer, size_t length, size_t& received) {
  TTV_ASSERT(buffer);

  received = 0;

  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  // A zero-length request makes recv() return 0, which would be indistinguishable from the
  // peer closing the connection below. Nothing was asked for, so nothing is read.
  if (length == 0) {
    return TTV_EC_SUCCESS;
  }

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // NOTE: We use MSG_DONTWAIT to make the recv non-blocking even though the socket is configured to be blocking
  auto result = recv(mSocket, reinterpret_cast<char*>(buffer), length, MSG_DONTWAIT);
  if (result > 0) {
    received = static_cast<size_t>(result);
    mTotalRecieved += static_cast<uint64_t>(result);
  }
  // if the other end drops the connection we get 0 back as a result
  else if (result == 0) {
    ec = TTV_EC_SOCKET_ECONNABORTED;
  } else {
    mLastSocketError = errno;

    // POSIX allows either EAGAIN or EWOULDBLOCK for "no data yet"; they are the same value on many platforms.
    bool wouldBlock = (EWOULDBLOCK == mLastSocketError);
#if EAGAIN != EWOULDBLOCK
    wouldBlock = wouldBlock || (EAGAIN == mLastSocketError);
#endif

    if (wouldBlock) {
      ec = TTV_EC_SOCKET_EWOULDBLOCK;
    } else {
      TTV_ASSERT(mLastSocketError);

      ttv::trace::Message("Socket", MessageLevel::Error, "Error Receiving from a socket. Error = %d", mLastSocketError);
      ec = TTV_EC_SOCKET_RECV_ERROR;
    }
  }

  if (SOCKET_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

/**
 * Switches the socket between blocking and non-blocking I/O using the FIONBIO ioctl.
 *
 * @param block true for blocking mode, false for non-blocking mode.
 * @return TTV_EC_SUCCESS on success; TTV_EC_SOCKET_IOCTL_ERROR when ioctl() fails, in which case
 *         errno is stored in mLastSocketError and logged.
 */
TTV_ErrorCode ttv::StandardSocket::SetBlockingMode(bool block) {
  // FIONBIO reads an int through the pointer (non-zero enables non-blocking mode). The Winsock-style
  // u_long has a different size on LP64 targets and would be misread on big-endian ones.
  int mode = block ? 0 : 1;
  int res = ioctl(mSocket, FIONBIO, &mode);
  if (res != 0) {
    mLastSocketError = errno;
    ttv::trace::Message("Socket", MessageLevel::Error, "Error Changing blocking mode. Error = %d", mLastSocketError);
    return TTV_EC_SOCKET_IOCTL_ERROR;
  }

  return TTV_EC_SUCCESS;
}

bool ttv::StandardSocket::Connected() {
  return mSocket != 0;
}

/**
 * @return Number of bytes accepted by send() since the counter was last reset by Connect().
 */
uint64_t ttv::StandardSocket::TotalSent() {
  const uint64_t bytesSent = mTotalSent;
  return bytesSent;
}

/**
 * @return Number of bytes read by recv() since the counter was last reset by Connect().
 */
uint64_t ttv::StandardSocket::TotalReceived() {
  const uint64_t bytesReceived = mTotalRecieved;
  return bytesReceived;
}

// The factory has nothing to clean up in this file; the destructor is defined out of line only.
ttv::StandardSocketFactory::~StandardSocketFactory() = default;

/**
 * Tells whether this factory can create a socket for the given URI scheme.
 *
 * @param protocol Scheme to check.
 * @return true for an empty scheme or exactly "tcp"; false otherwise.
 */
bool ttv::StandardSocketFactory::IsProtocolSupported(const std::string& protocol) {
  if (protocol.empty()) {
    return true;
  }

  return protocol.compare("tcp") == 0;
}

/**
 * Creates an unconnected StandardSocket for the host and port named in `uri`.
 *
 * @param uri    Target address; its scheme must be accepted by IsProtocolSupported().
 * @param result Receives the new socket on success; reset to empty on entry and left empty on failure.
 * @return TTV_EC_SUCCESS when the socket was created;
 *         TTV_EC_INVALID_ARG when the port is not a number or is above 65535;
 *         TTV_EC_UNIMPLEMENTED when the scheme is not supported.
 *         When the URI carries no port, 0 is passed to the socket.
 */
TTV_ErrorCode ttv::StandardSocketFactory::CreateSocket(const std::string& uri, std::shared_ptr<ttv::ISocket>& result) {
  result.reset();

  Uri url(uri);

  // Keep a single definition of which schemes are supported.
  if (!IsProtocolSupported(url.GetProtocol())) {
    return TTV_EC_UNIMPLEMENTED;
  }

  uint32_t port = 0;

  const auto& portText = url.GetPort();
  if (!portText.empty()) {
    if (!ParseNum(portText, port)) {
      return TTV_EC_INVALID_ARG;
    }

    // TCP ports are 16-bit; reject larger values here instead of letting them fail later at connect time.
    constexpr uint32_t kMaxTcpPort = 65535;
    if (port > kMaxTcpPort) {
      return TTV_EC_INVALID_ARG;
    }
  }

  result = std::make_shared<StandardSocket>(url.GetHostName(), port);

  return TTV_EC_SUCCESS;
}
