/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/orbissocket.h"

#include "twitchsdk/core/httprequestutils.h"

#include <kernel.h>
#include <net.h>
#include <rtc.h>
#include <sys/errno.h>
#include <sys/socket.h>
#include <sys/types.h>

// Sentinel value (-1).
// NOTE(review): SOCKET_ERROR is not referenced in the visible part of this file - confirm before removing.
#define SOCKET_ERROR -1
// Sentinel handle value (-1); OrbisSocket::SetBlockingMode() compares mSocket against it before
// touching the socket.
#define INVALID_SOCKET -1

namespace {
// Resolves `hostname` into `addr` using a temporary libnet resolver.
//
// A 4 KiB net pool and a resolver are created for the duration of the call and destroyed
// again before returning, whether or not the lookup succeeded.
//
// Returns TTV_EC_SUCCESS when every step (including cleanup) succeeded, otherwise
// TTV_EC_SOCKET_GETADDRINFO_FAILED. Each failing step is traced with its return code and
// sce_net_errno.
TTV_ErrorCode doResolverNtoa(const char* hostname, SceNetInAddr* addr) {
  const int poolId = sceNetPoolCreate(__FUNCTION__, 4 * 1024, 0);
  if (poolId < 0) {
    // Nothing has been allocated yet, so there is nothing to clean up.
    ttv::trace::Message(
      "Socket", ttv::MessageLevel::Error, "sceNetPoolCreate() failed (0x%x errno=%d)\n", poolId, sce_net_errno);
    return TTV_EC_SOCKET_GETADDRINFO_FAILED;
  }

  TTV_ErrorCode result = TTV_EC_SUCCESS;

  const SceNetId resolverId = sceNetResolverCreate("resolver", poolId, 0);
  if (resolverId < 0) {
    ttv::trace::Message("Socket", ttv::MessageLevel::Error, "sceNetResolverCreate() failed (0x%x errno=%d)\n",
      resolverId, sce_net_errno);
    result = TTV_EC_SOCKET_GETADDRINFO_FAILED;
  } else {
    const int lookupStatus = sceNetResolverStartNtoa(resolverId, hostname, addr, 0, 0, 0);
    if (lookupStatus < 0) {
      ttv::trace::Message("Socket", ttv::MessageLevel::Error, "sceNetResolverStartNtoa() failed (0x%x errno=%d)\n",
        lookupStatus, sce_net_errno);
      result = TTV_EC_SOCKET_GETADDRINFO_FAILED;
    }

    // The resolver exists only on this path, so it is torn down here.
    const int destroyStatus = sceNetResolverDestroy(resolverId);
    if (destroyStatus < 0) {
      ttv::trace::Message("Socket", ttv::MessageLevel::Error, "sceNetResolverDestroy() failed (0x%x errno=%d)\n",
        destroyStatus, sce_net_errno);
      result = TTV_EC_SOCKET_GETADDRINFO_FAILED;
    }
  }

  // The pool was created successfully above, so it is always released.
  const int poolStatus = sceNetPoolDestroy(poolId);
  if (poolStatus < 0) {
    ttv::trace::Message(
      "Socket", ttv::MessageLevel::Error, "sceNetPoolDestroy() failed (0x%x errno=%d)\n", poolStatus, sce_net_errno);
    result = TTV_EC_SOCKET_GETADDRINFO_FAILED;
  }

  return result;
}
}  // namespace

// Stores the target host and port; no network activity happens until Connect() is called.
// The port is kept as its decimal text form in mPort and parsed again by Connect().
ttv::OrbisSocket::OrbisSocket(const std::string& host, uint32_t port)
    : mHostName(host), mTotalRecieved(0), mTotalSent(0), mSocket(0), mCachePos(0), mLastSocketError(0) {
  mPort = std::to_string(port);
}

// Releases the socket handle, if one is still open, when the object is destroyed.
ttv::OrbisSocket::~OrbisSocket() {
  // The result is deliberately discarded; Disconnect() only ever reports success.
  (void)Disconnect();
}

// Resolves the configured host, creates a TCP socket and connects it.
//
// Per-connection state (last error, cache position, byte counters) is reset first. The host
// is treated as a literal IPv4 address when sceNetInetPton() accepts it; otherwise it is
// looked up through doResolverNtoa(). On success the socket is switched to blocking mode.
// On any failure after validation, Disconnect() is called so no handle is left behind.
//
// Returns:
//   TTV_EC_SUCCESS                    - connected.
//   TTV_EC_SOCKET_EALREADY            - a socket is already open.
//   TTV_EC_INVALID_ARG                - the stored port is not a number or exceeds 65535.
//   TTV_EC_SOCKET_GETADDRINFO_FAILED  - the host name could not be resolved.
//   TTV_EC_SOCKET_CONNECT_FAILED      - socket creation, option setup or connect failed.
TTV_ErrorCode ttv::OrbisSocket::Connect() {
  if (Connected()) {
    return TTV_EC_SOCKET_EALREADY;
  }

  mLastSocketError = 0;
  mCachePos = 0;
  mTotalSent = 0;
  mTotalRecieved = 0;

  uint32_t portNum = 0;
  int numScanned = sscanf(mPort.c_str(), "%u", &portNum);
  // TCP ports are 16 bits wide; reject larger values instead of letting the cast below
  // silently truncate them to a different port.
  if (numScanned != 1 || portNum > 0xFFFFu) {
    return TTV_EC_INVALID_ARG;
  }

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // DNS lookup
  SceNetSockaddrIn socketAddress;
  memset(&socketAddress, 0, sizeof(socketAddress));
  socketAddress.sin_len = sizeof(socketAddress);
  socketAddress.sin_family = SCE_NET_AF_INET;
  socketAddress.sin_port = sceNetHtons(static_cast<SceNetInPort_t>(portNum));

  // Only fall back to the resolver when the host is not a literal IPv4 address.
  if (sceNetInetPton(SCE_NET_AF_INET, mHostName.c_str(), &socketAddress.sin_addr) == 0) {
    ec = doResolverNtoa(mHostName.c_str(), &socketAddress.sin_addr);
    // A lookup that reports success but leaves the address zeroed is still a failure;
    // do not go on to connect to 0.0.0.0.
    if (TTV_SUCCEEDED(ec) && socketAddress.sin_addr.s_addr == 0) {
      ec = TTV_EC_SOCKET_GETADDRINFO_FAILED;
    }
    if (TTV_FAILED(ec)) {
      ttv::trace::Message("Socket", MessageLevel::Error, "Failed to resolve hostname at %s\n", mHostName.c_str());
    }
  }

  if (TTV_SUCCEEDED(ec)) {
    // Keep the new handle in a local until it is known to be valid, so a failed creation
    // never leaves a negative value in mSocket (which Connected() would report as open).
    const int socketId = sceNetSocket("TwitchSocket", SCE_NET_AF_INET, SCE_NET_SOCK_STREAM, SCE_NET_IPPROTO_TCP);
    if (socketId < 0) {
      ec = TTV_EC_SOCKET_CONNECT_FAILED;
      ttv::trace::Message(
        "Socket", MessageLevel::Error, "sceNetSocket() failed (0x%08x, errno=%d)\n", socketId, sce_net_errno);
    } else {
      mSocket = socketId;

      // Allow multiple sockets to bind to the same port
      int optval = 1;
      int ret = sceNetSetsockopt(mSocket, SCE_NET_SOL_SOCKET, SCE_NET_SO_REUSEADDR, &optval, sizeof(optval));
      if (ret < 0) {
        ec = TTV_EC_SOCKET_CONNECT_FAILED;
        ttv::trace::Message("Socket", MessageLevel::Error, "sceNetSetsockopt(SO_REUSEADR) failed( 0x%08x, errno%d)\n",
          ret, sce_net_errno);
      } else {
        ret = sceNetConnect(mSocket, reinterpret_cast<SceNetSockaddr*>(&socketAddress), sizeof(socketAddress));
        if (ret < 0) {
          ec = TTV_EC_SOCKET_CONNECT_FAILED;
          ttv::trace::Message("Socket", MessageLevel::Error, "sceNetConnect() failed (errno=%d)\n", sce_net_errno);
        }
      }
    }
  }

  if (TTV_SUCCEEDED(ec)) {
    SetBlockingMode(true);
  } else {
    // Closes the handle if one was opened; harmless otherwise.
    Disconnect();
  }

  return ec;
}

// Shuts down and closes the socket if one is open, then marks the object as disconnected.
//
// Failures from sceNetShutdown()/sceNetSocketClose() are traced but do not change the
// result: the handle is dropped regardless and TTV_EC_SUCCESS is always returned.
TTV_ErrorCode ttv::OrbisSocket::Disconnect() {
  if (mSocket > 0) {
    int ret = sceNetShutdown(mSocket, SCE_NET_SHUT_RDWR);
    if (ret < 0) {
      ttv::trace::Message("Socket", MessageLevel::Error, "sceNetShutdown() failed (errno=%d)\n", sce_net_errno);
    }

    ret = sceNetSocketClose(mSocket);
    if (ret < 0) {
      ttv::trace::Message("Socket", MessageLevel::Error, "sceNetSocketClose() failed (errno=%d)\n", sce_net_errno);
    }
  }

  // Reset unconditionally: a negative value (e.g. an error code stored from a failed socket
  // creation) has nothing to close, but must not linger, because Connected() treats any
  // non-zero mSocket as an open connection.
  mSocket = 0;

  return TTV_EC_SUCCESS;
}

// Writes up to `length` bytes from `buffer` to the socket with a single sceNetSend() call.
//
// `sent` receives the number of bytes actually written (0 on any failure). The running
// total reported by TotalSent() is updated on success.
//
// Returns TTV_EC_SUCCESS, TTV_EC_SOCKET_ENOTCONN when no socket is open, or
// TTV_EC_SOCKET_SEND_ERROR when the send fails - in which case the error number is stored
// in mLastSocketError and the socket is disconnected.
TTV_ErrorCode ttv::OrbisSocket::Send(const uint8_t* buffer, size_t length, size_t& sent) {
  assert(buffer);
  assert(length > 0);

  sent = 0;

  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  const auto bytesWritten = sceNetSend(mSocket, reinterpret_cast<const char*>(buffer), length, 0);
  if (bytesWritten >= 0) {
    sent += static_cast<size_t>(bytesWritten);
    mTotalSent += bytesWritten;
    return TTV_EC_SUCCESS;
  }

  // Failure path: remember the error, report it, and drop the connection.
  mLastSocketError = sce_net_errno;
  ttv::trace::Message("Socket", MessageLevel::Error, "Error Sending from a socket. Error = %d", mLastSocketError);
  Disconnect();

  return TTV_EC_SOCKET_SEND_ERROR;
}

// Reads up to `length` bytes from the socket into `buffer` without waiting for data.
//
// `received` receives the number of bytes read (0 unless the call succeeds). The running
// total reported by TotalReceived() is updated on success.
//
// Returns:
//   TTV_EC_SUCCESS              - at least one byte was read.
//   TTV_EC_SOCKET_ENOTCONN      - no socket is open.
//   TTV_EC_SOCKET_EWOULDBLOCK   - no data is available right now.
//   TTV_EC_SOCKET_ECONNABORTED  - sceNetRecv() returned 0 (peer closed the connection).
//   TTV_EC_SOCKET_RECV_ERROR    - any other receive failure; the error number is stored in
//                                 mLastSocketError.
// When SOCKET_FAILED() classifies the result as a failure, the socket is disconnected.
TTV_ErrorCode ttv::OrbisSocket::Recv(uint8_t* buffer, size_t length, size_t& received) {
  assert(buffer);

  received = 0;

  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  TTV_ErrorCode ret = TTV_EC_SUCCESS;

  // NOTE: We use MSG_DONTWAIT to make the recv non-blocking even though the socket is configured to be blocking
  auto result = sceNetRecv(mSocket, reinterpret_cast<char*>(buffer), static_cast<int>(length), SCE_NET_MSG_DONTWAIT);
  if (result > 0) {
    received = static_cast<size_t>(result);
    mTotalRecieved += static_cast<uint64_t>(result);
  }
  // If the other end drops the connection we get 0 back as a result
  else if (result == 0) {
    ret = TTV_EC_SOCKET_ECONNABORTED;
  } else {
    mLastSocketError = sce_net_errno;
    if (EWOULDBLOCK == mLastSocketError) {
      ret = TTV_EC_SOCKET_EWOULDBLOCK;
    } else {
      // This is the receive path; the message previously said "Sending" (copied from Send()).
      ttv::trace::Message("Socket", MessageLevel::Error, "Error Receiving from a socket. Error = %d", mLastSocketError);
      ret = TTV_EC_SOCKET_RECV_ERROR;
    }
  }

  if (SOCKET_FAILED(ret)) {
    Disconnect();
  }

  return ret;
}

// Switches the open socket between blocking (`block == true`) and non-blocking mode by
// setting the SCE_NET_SO_NBIO option (0 = blocking, 1 = non-blocking).
//
// Does nothing and returns TTV_EC_SUCCESS when no valid socket is open. Returns
// TTV_EC_SOCKET_IOCTL_ERROR when the option cannot be set; the error number is then stored
// in mLastSocketError.
TTV_ErrorCode ttv::OrbisSocket::SetBlockingMode(bool block) {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  if (mSocket != 0 && mSocket != INVALID_SOCKET) {
    int value = block ? 0 : 1;
    int ret = sceNetSetsockopt(mSocket, SCE_NET_SOL_SOCKET, SCE_NET_SO_NBIO, &value, sizeof(value));
    if (ret < 0) {
      // Read sce_net_errno, as every other libnet call site in this file does, not errno.
      mLastSocketError = sce_net_errno;
      ec = TTV_EC_SOCKET_IOCTL_ERROR;
      ttv::trace::Message("Socket", MessageLevel::Error, "sceNetSetsockopt(SO_NBIO) failed( 0x%08x, errno%d)\n", ret,
        mLastSocketError);
    }
  }

  return ec;
}

// Reports whether a socket handle is currently held.
//
// Uses the same "mSocket > 0" test as Disconnect(), so a negative value left in mSocket
// (for example an error code from a failed socket creation) is not mistaken for an open
// connection.
bool ttv::OrbisSocket::Connected() {
  return mSocket > 0;
}

uint64_t ttv::OrbisSocket::TotalSent() {
  return mTotalSent;
}

uint64_t ttv::OrbisSocket::TotalReceived() {
  return mTotalRecieved;
}

// Returns true for the protocols this factory can create sockets for: "tcp", or an empty
// protocol string.
bool ttv::OrbisSocketFactory::IsProtocolSupported(const std::string& protocol) {
  if (protocol.empty()) {
    return true;
  }

  return protocol.compare("tcp") == 0;
}

// Creates an (unconnected) OrbisSocket for the host and port named in `uri`.
//
// `result` is cleared first and only assigned on success. A URI without a port yields a
// socket configured with port 0.
//
// Returns TTV_EC_SUCCESS, TTV_EC_INVALID_ARG when the port text cannot be parsed as an
// unsigned number, or TTV_EC_UNIMPLEMENTED when the protocol is neither empty nor "tcp".
TTV_ErrorCode ttv::OrbisSocketFactory::CreateSocket(const std::string& uri, std::shared_ptr<ttv::ISocket>& result) {
  result.reset();

  ttv::Uri parsed(uri);

  const auto& scheme = parsed.GetProtocol();
  const bool isTcp = scheme == "" || scheme == "tcp";
  if (!isTcp) {
    return TTV_EC_UNIMPLEMENTED;
  }

  uint32_t portNumber = 0;
  const auto& portText = parsed.GetPort();
  if (portText != "") {
    if (sscanf(portText.c_str(), "%u", &portNumber) != 1) {
      return TTV_EC_INVALID_ARG;
    }
  }

  result = std::make_shared<OrbisSocket>(parsed.GetHostName(), portNumber);
  return TTV_EC_SUCCESS;
}
