/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/win8/winappsocket.h"

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/thread.h"

#include <robuffer.h>

#include <codecvt>

using namespace ABI::Windows::Foundation;
using namespace ABI::Windows::Networking;
using namespace ABI::Windows::Networking::WinAppSockets;
using namespace ABI::Windows::Storage::Streams;
using namespace Microsoft::WRL;
using namespace Microsoft::WRL::Wrappers;

// Returns `ret` from the enclosing function unless `hr` is exactly S_OK.
// NOTE: success codes other than S_OK (e.g. S_FALSE) are treated as errors.
// The do/while(0) wrapper makes the expansion a single statement, so it is
// safe inside an unbraced if/else; `hr` is parenthesized so expression
// arguments are compared as a whole.
#define RETURN_ON_ERROR(hr, ret) \
  do {                           \
    if ((hr) != S_OK) {          \
      return (ret);              \
    }                            \
  } while (0)

namespace ttv {
/**
 * Blocks the calling thread until the asynchronous operation behind
 * `asyncInfo` leaves AsyncStatus::Started, polling its status every 10 ms.
 *
 * @param asyncInfo The IAsyncInfo view of the operation to wait on.
 * @return S_OK if the operation completed; the operation's own error code if
 *         it ended in AsyncStatus::Error (E_FAIL if that code is not a
 *         failure); E_ABORT if it was canceled; E_POINTER if `asyncInfo` is
 *         null; otherwise the failing HRESULT of an IAsyncInfo query.
 */
static HRESULT WaitForAsyncOpComplete(const ComPtr<IAsyncInfo>& asyncInfo) {
  if (!asyncInfo) {
    return E_POINTER;
  }

  AsyncStatus status = AsyncStatus::Started;
  HRESULT hr = asyncInfo->get_Status(&status);
  while (SUCCEEDED(hr) && status == AsyncStatus::Started) {
    // Polling keeps the API synchronous for callers; it blocks this thread.
    Sleep(10);
    hr = asyncInfo->get_Status(&status);
  }
  if (FAILED(hr)) {
    return hr;
  }

  switch (status) {
    case AsyncStatus::Completed:
      return S_OK;
    case AsyncStatus::Canceled:
      return E_ABORT;
    case AsyncStatus::Error: {
      // Propagate the operation's failure so callers that never call
      // GetResults() (or call it later) still see the error here.
      HRESULT opError = S_OK;
      hr = asyncInfo->get_ErrorCode(&opError);
      if (FAILED(hr)) {
        return hr;
      }
      return FAILED(opError) ? opError : E_FAIL;
    }
    default:
      return E_UNEXPECTED;
  }
}
}  // namespace ttv

// Starts out disconnected: no socket object, zeroed traffic counters and no
// read pending. The remaining per-connection fields are reset by Connect().
ttv::WinAppSocket::WinAppSocket()
    : mSocket(nullptr),
      mLastSocketError(0),
      mTotalSent(0),
      mTotalRecieved(0),
      mAsyncReadInProgress(false) {}

// Tears down any live connection. Disconnect()'s result is deliberately
// discarded: a destructor has no way to report it.
ttv::WinAppSocket::~WinAppSocket() {
  static_cast<void>(Disconnect());
}

/**
 * Records the remote endpoint that a later Connect() will use. No network
 * activity happens here.
 *
 * @param host Remote host name (converted from UTF-8 by Connect()).
 * @param port Remote port / service name (converted from UTF-8 by Connect()).
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::WinAppSocket::Initialize(const std::string& host, const std::string& port) {
  mHostName = host;
  mPort = port;
  // An explicit return is required: flowing off the end of a non-void
  // function is undefined behavior, and CreateSocket() tests this result.
  return TTV_EC_SUCCESS;
}

/**
 * Creates a new StreamSocket and synchronously connects it to the host/port
 * recorded by Initialize().
 *
 * Blocks the calling thread until the connect operation finishes. On any
 * failure after the socket object exists, Disconnect() is called so that
 * Connected() reports false afterwards.
 *
 * @return TTV_EC_SUCCESS on success, TTV_EC_SOCKET_EALREADY if already
 *         connected, TTV_EC_SOCKET_CONNECT_FAILED otherwise.
 */
TTV_ErrorCode ttv::WinAppSocket::Connect() {
  TTV_ASSERT(!Connected());
  if (Connected()) {
    return TTV_EC_SOCKET_EALREADY;
  }

  HRESULT hr = Windows::Foundation::ActivateInstance<ComPtr<IStreamSocket>>(
    HStringReference(RuntimeClass_Windows_Networking_Sockets_StreamSocket).Get(), &mSocket);
  TTV_ASSERT(SUCCEEDED(hr));
  if (FAILED(hr) || !mSocket) {
    // Connected() keys off mSocket, so never leave a half-created one behind.
    mSocket.Reset();
    return TTV_EC_SOCKET_CONNECT_FAILED;
  }

  mLastSocketError = 0;
  mLastFlushTime = 0;
  mCachePos = 0;
  mTotalSent = 0;
  mTotalRecieved = 0;

  // Drop any read left over from a previous connection so Recv() does not
  // harvest a stale operation against this new socket.
  mAsyncReadInProgress = false;
  mAsyncReadOp.Reset();
  mInputBuffer.Reset();

  // NOTE(review): from_bytes throws std::range_error on malformed UTF-8 input.
  std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;

  ComPtr<IHostNameFactory> hostNameFactory;
  hr = Windows::Foundation::GetActivationFactory(
    HStringReference(RuntimeClass_Windows_Networking_HostName).Get(), &hostNameFactory);

  HString hostStr;
  if (SUCCEEDED(hr)) {
    hr = hostStr.Set(converter.from_bytes(mHostName).c_str());
  }

  ComPtr<IHostName> hostName;
  if (SUCCEEDED(hr)) {
    hr = hostNameFactory->CreateHostName(hostStr.Get(), &hostName);
  }

  HString portStr;
  if (SUCCEEDED(hr)) {
    hr = portStr.Set(converter.from_bytes(mPort).c_str());
  }

  ComPtr<IAsyncAction> connectOp;
  if (SUCCEEDED(hr)) {
    hr = mSocket->ConnectAsync(hostName.Get(), portStr.Get(), &connectOp);
  }

  ComPtr<IAsyncInfo> info;
  if (SUCCEEDED(hr)) {
    hr = connectOp.As(&info);
  }

  if (SUCCEEDED(hr)) {
    hr = WaitForAsyncOpComplete(info);
  }

  if (SUCCEEDED(hr)) {
    // Surfaces the connect operation's own outcome.
    hr = connectOp->GetResults();
  }

  if (FAILED(hr)) {
    (void)Disconnect();
    return TTV_EC_SOCKET_CONNECT_FAILED;
  }

  return TTV_EC_SUCCESS;
}

/**
 * Closes and releases the socket, if one exists, and discards any pending
 * asynchronous read state. Safe to call when already disconnected.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::WinAppSocket::Disconnect() {
  if (mSocket) {
    ComPtr<IClosable> closableSocket;
    if (SUCCEEDED(mSocket.As(&closableSocket))) {
      // Best effort: there is nothing useful to do if Close() fails.
      (void)closableSocket->Close();
    }
    mSocket.Reset();
  }

  // A read queued against the old socket must not be harvested by Recv()
  // after a later Connect(); drop the operation and its buffer.
  mAsyncReadInProgress = false;
  mAsyncReadOp.Reset();
  mInputBuffer.Reset();

  return TTV_EC_SUCCESS;
}

/**
 * Synchronously writes `length` bytes from `buffer` to the connected socket.
 *
 * Copies the data into a WinRT buffer, issues WriteAsync on the socket's
 * output stream and blocks until it finishes. If the write pipeline fails,
 * the socket is disconnected.
 *
 * @return TTV_EC_SUCCESS when every byte was written; TTV_EC_SOCKET_ENOTCONN
 *         if not connected; TTV_EC_SOCKET_ERR if `length` does not fit in a
 *         UINT32; TTV_EC_SOCKET_CONNECT_FAILED if the write fails.
 */
TTV_ErrorCode ttv::WinAppSocket::Send(const uint8_t* buffer, size_t length) {
  TTV_ASSERT(buffer);
  TTV_ASSERT(length > 0);

  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  // WinRT buffers are sized with UINT32. Reject lengths that would truncate,
  // since the memcpy below copies `length` bytes into the created buffer.
  const UINT32 length32 = static_cast<UINT32>(length);
  if (static_cast<size_t>(length32) != length) {
    return TTV_EC_SOCKET_ERR;
  }

  // Reuse the buffer factory cached by Recv() (or cache it here first).
  HRESULT hr = S_OK;
  if (!mBufferFactory) {
    hr = Windows::Foundation::GetActivationFactory(
      HStringReference(RuntimeClass_Windows_Storage_Streams_Buffer).Get(), &mBufferFactory);
  }

  ComPtr<IBuffer> outputBuffer;
  if (SUCCEEDED(hr)) {
    hr = mBufferFactory->Create(length32, &outputBuffer);
  }

  ComPtr<Windows::Storage::Streams::IBufferByteAccess> bufferByteAccess;
  if (SUCCEEDED(hr)) {
    hr = outputBuffer.As(&bufferByteAccess);
  }

  BYTE* bufferBytes = nullptr;
  if (SUCCEEDED(hr)) {
    hr = bufferByteAccess->Buffer(&bufferBytes);
  }

  if (SUCCEEDED(hr)) {
    memcpy(bufferBytes, buffer, length);
    hr = outputBuffer->put_Length(length32);
  }

  ComPtr<IOutputStream> outputStream;
  if (SUCCEEDED(hr)) {
    hr = mSocket->get_OutputStream(&outputStream);
  }

  ComPtr<IAsyncOperationWithProgress<UINT32, UINT32>> writeOp;
  if (SUCCEEDED(hr)) {
    hr = outputStream->WriteAsync(outputBuffer.Get(), &writeOp);
  }

  ComPtr<IAsyncInfo> info;
  if (SUCCEEDED(hr)) {
    hr = writeOp.As(&info);
  }

  if (SUCCEEDED(hr)) {
    hr = WaitForAsyncOpComplete(info);
  }

  // The operation's result reports the outcome of the write and how many
  // bytes were written; a short write would silently lose data.
  UINT32 bytesWritten = 0;
  if (SUCCEEDED(hr)) {
    hr = writeOp->GetResults(&bytesWritten);
  }
  if (SUCCEEDED(hr) && bytesWritten != length32) {
    hr = E_FAIL;
  }

  if (FAILED(hr)) {
    (void)Disconnect();
    return TTV_EC_SOCKET_CONNECT_FAILED;
  }

  mTotalSent += bytesWritten;
  return TTV_EC_SUCCESS;
}

/**
 * Non-blocking receive built on WinRT's asynchronous ReadAsync.
 *
 * Each call first harvests the read issued by a previous call (if it has
 * completed) into `buffer`, then queues a new read of up to `length` bytes so
 * data is ready for the next call. While a read is still pending the call
 * returns TTV_EC_SOCKET_EWOULDBLOCK.
 *
 * @param buffer   Destination for received bytes.
 * @param length   Capacity of `buffer`; also the size of the next read request.
 * @param received Set to the number of bytes copied into `buffer`.
 * @return TTV_EC_SOCKET_ENOTCONN when not connected. Otherwise the resulting
 *         code; when SOCKET_FAILED(ret) holds it is reported as
 *         TTV_EC_SOCKET_RECV_ERROR and the socket is disconnected.
 */
TTV_ErrorCode ttv::WinAppSocket::Recv(uint8_t* buffer, size_t length, size_t& received) {
  TTV_ASSERT(buffer);
  received = 0;

  if (!Connected()) {
    return TTV_EC_SOCKET_ENOTCONN;
  }

  HRESULT hr = S_OK;
  TTV_ErrorCode ret = TTV_EC_SUCCESS;
  bool receivedData = false;

  // Harvest the read issued by a previous call, if there is one.
  if (mAsyncReadInProgress) {
    ComPtr<IAsyncInfo> info;
    hr = mAsyncReadOp.As(&info);
    TTV_ASSERT(info);

    AsyncStatus status = AsyncStatus::Started;
    if (SUCCEEDED(hr)) {
      hr = info->get_Status(&status);
    }

    if (SUCCEEDED(hr)) {
      switch (status) {
        case AsyncStatus::Completed: {
          // The operation's result buffer is not guaranteed to be the buffer
          // that was passed to ReadAsync, so read from the one it returns.
          ComPtr<IBuffer> resultBuffer;
          hr = mAsyncReadOp->GetResults(&resultBuffer);
          if (SUCCEEDED(hr) && resultBuffer) {
            mInputBuffer = resultBuffer;
          }

          UINT32 readLength = 0;
          if (SUCCEEDED(hr)) {
            hr = ReadAndResetInputBuffer(buffer, length, readLength);
          }
          if (FAILED(hr)) {
            ret = TTV_EC_SOCKET_RECV_ERROR;
            break;
          }

          received = readLength;
          mTotalRecieved += readLength;

          mAsyncReadInProgress = false;
          mAsyncReadOp.Reset();
          receivedData = true;
          break;
        }
        case AsyncStatus::Started: {
          // Read operation is still in progress; data not available yet.
          ret = TTV_EC_SOCKET_EWOULDBLOCK;
          break;
        }
        default: {
          // Error or Canceled: drop the failed operation and its buffer so a
          // later read does not trip over stale state.
          ret = TTV_EC_SOCKET_ERR;
          mAsyncReadInProgress = false;
          mAsyncReadOp.Reset();
          mInputBuffer.Reset();
          break;
        }
      }
    } else {
      ret = TTV_EC_SOCKET_RECV_ERROR;
    }
  }

  // Preemptively request more data so it will be ready for the next call.
  if (!mAsyncReadInProgress && TTV_SUCCEEDED(ret)) {
    hr = S_OK;
    if (!mBufferFactory) {
      hr = Windows::Foundation::GetActivationFactory(
        HStringReference(RuntimeClass_Windows_Storage_Streams_Buffer).Get(), &mBufferFactory);
    }

    // WinRT buffer sizes are UINT32; clamp instead of truncating so a huge
    // `length` cannot wrap around to a tiny (or zero) request.
    const UINT32 capacity = (length > 0xFFFFFFFFull) ? 0xFFFFFFFFu : static_cast<UINT32>(length);

    if (SUCCEEDED(hr)) {
      TTV_ASSERT(mBufferFactory);
      TTV_ASSERT(!mInputBuffer);
      hr = mBufferFactory->Create(capacity, &mInputBuffer);
      TTV_ASSERT(FAILED(hr) || mInputBuffer.Get() != nullptr);
    }

    ComPtr<IInputStream> inputStream;
    if (SUCCEEDED(hr)) {
      hr = mSocket->get_InputStream(&inputStream);
    }

    if (SUCCEEDED(hr)) {
      hr = inputStream->ReadAsync(
        mInputBuffer.Get(), capacity, InputStreamOptions::InputStreamOptions_Partial, &mAsyncReadOp);
    }

    if (SUCCEEDED(hr)) {
      mAsyncReadInProgress = true;
    } else {
      // Could not queue the next read: release anything allocated so the next
      // call starts clean. Data harvested above is still delivered; with
      // nothing to deliver, report the failure instead of a silent success.
      mInputBuffer.Reset();
      mAsyncReadOp.Reset();
      if (!receivedData) {
        ret = TTV_EC_SOCKET_RECV_ERROR;
      }
    }
  }

  if (SOCKET_FAILED(ret)) {
    ret = TTV_EC_SOCKET_RECV_ERROR;
    (void)Disconnect();
  }

  return ret;
}

/**
 * Copies the contents of mInputBuffer into `buffer` and releases mInputBuffer.
 *
 * @param buffer     Destination; must have room for `length` bytes.
 * @param length     Capacity of `buffer`.
 * @param readLength Set to the number of bytes copied (0 on failure).
 * @return S_OK on success; E_POINTER if there is no input buffer;
 *         HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER) if the pending data is
 *         larger than `length`; otherwise the failing HRESULT from the buffer
 *         accessors. mInputBuffer is released on every path.
 */
HRESULT ttv::WinAppSocket::ReadAndResetInputBuffer(uint8_t* buffer, size_t length, UINT32& readLength) {
  TTV_ASSERT(mInputBuffer);
  readLength = 0;
  if (!mInputBuffer) {
    return E_POINTER;
  }

  UINT32 available = 0;
  HRESULT hr = mInputBuffer->get_Length(&available);
  TTV_ASSERT(FAILED(hr) || static_cast<size_t>(available) <= length);

  if (SUCCEEDED(hr) && static_cast<size_t>(available) > length) {
    // The read was sized by an earlier Recv() call; never overrun a smaller
    // destination. There is nowhere to keep the excess, so fail instead.
    hr = HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER);
  }

  ComPtr<Windows::Storage::Streams::IBufferByteAccess> bufferByteAccess;
  if (SUCCEEDED(hr)) {
    hr = mInputBuffer.As(&bufferByteAccess);
  }

  BYTE* bufferBytes = nullptr;
  if (SUCCEEDED(hr)) {
    hr = bufferByteAccess->Buffer(&bufferBytes);
  }

  if (SUCCEEDED(hr)) {
    if (available > 0) {
      memcpy(buffer, bufferBytes, available);
    }
    readLength = available;
  }

  // The buffer's contents have either been consumed or are unusable.
  mInputBuffer.Reset();
  return hr;
}

// Reports whether this object currently holds a socket (mSocket is non-null).
bool ttv::WinAppSocket::Connected() {
  return mSocket.Get() != nullptr;
}

uint64_t ttv::WinAppSocket::TotalSent() const {
  return mTotalSent;
}

uint64_t ttv::WinAppSocket::TotalReceived() const {
  return mTotalRecieved;
}

// The factory owns no resources of its own; nothing to release.
ttv::WinAppSocketFactory::~WinAppSocketFactory() = default;

// Only raw TCP is handled; an empty protocol is accepted as TCP as well.
bool ttv::WinAppSocketFactory::IsProtocolSupported(const std::string& protocol) {
  return protocol.empty() || protocol == "tcp";
}

/**
 * Builds a WinAppSocket for `uri` when its protocol is empty or "tcp".
 *
 * `result` is always cleared first and is assigned only on success. An
 * unsupported protocol, or a failed Initialize(), yields TTV_EC_UNIMPLEMENTED.
 */
TTV_ErrorCode ttv::WinAppSocketFactory::CreateSocket(const std::string& uri, std::shared_ptr<ttv::ISocket>& result) {
  result.reset();

  Uri url(uri);
  const auto& protocol = url.GetProtocol();
  const bool isTcp = (protocol == "") || (protocol == "tcp");
  if (!isTcp) {
    return TTV_EC_UNIMPLEMENTED;
  }

  auto newSocket = std::make_shared<WinAppSocket>();
  const TTV_ErrorCode initResult = newSocket->Initialize(url.GetHostName(), url.GetPort());
  if (TTV_SUCCEEDED(initResult)) {
    result = newSocket;
    return TTV_EC_SUCCESS;
  }

  return TTV_EC_UNIMPLEMENTED;
}
