/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/java_socket.h"

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/generated/jni_all.h"

#define DONT_CALL_INTO_JAVA 0

ttv::binding::java::JavaSocketBase::JavaSocketBase(JNIEnv* jEnv, jobject jInstance) : mByteArraySize(0) {
  mSocketInstance.Bind(jEnv, jInstance);

  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jResultContainer, GetJavaInstance_ResultContainer(jEnv));
  mSentReceivedResultContainer.Bind(jEnv, jResultContainer);
}

void ttv::binding::java::JavaSocketBase::AllocateByteArray(size_t size) {
  mByteArrayInstance.Release();

  AutoJEnv jEnv;

  jbyteArray jArray = jEnv->NewByteArray(static_cast<jsize>(size));
  mByteArrayInstance.Bind(jEnv, jArray);
  mByteArraySize = size;
}

ttv::binding::java::JavaSocketFactoryBase::JavaSocketFactoryBase(
  JNIEnv* jEnv, jobject jInstance, jmethodID isProtocolSupportedId, jmethodID createMethodId)
    : mIsProtocolSupportedId(isProtocolSupportedId), mCreateMethodId(createMethodId) {
  mJavaInstance.Bind(jEnv, jInstance);
}

bool ttv::binding::java::JavaSocketFactoryBase::IsProtocolSupported(const std::string& protocol) {
  bool ret = false;

#if !DONT_CALL_INTO_JAVA

  AutoJEnv jEnv;

  AUTO_DELETE_LOCAL_REF(jEnv, jstring, jProtocol, GetJavaInstance_String(jEnv, protocol));
  jboolean jBoolean = jEnv->CallBooleanMethod(mJavaInstance.GetInstance(), mIsProtocolSupportedId, jProtocol);
  ret = jBoolean != 0;

#else
  (void)protocol;

#endif

  return ret;
}

/**
 * Invokes the Java factory's create method for `uri` and extracts the created Java socket object.
 *
 * @param uri    URI forwarded to the Java factory as a String.
 * @param result Set to the created Java object (a local reference the caller must delete), or nullptr.
 * @return TTV_EC_SUCCESS when an object was produced. Otherwise returns the code reported by Java
 *         (TTV_EC_UNKNOWN_ERROR if unmappable), which may still be a success code when no object was
 *         produced. Returns TTV_EC_UNIMPLEMENTED when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaSocketFactoryBase::CreateSocket(const std::string& uri, jobject& result) {
  result = nullptr;

#if DONT_CALL_INTO_JAVA
  (void)uri;
  return TTV_EC_UNIMPLEMENTED;

#else
  AutoJEnv jEnv;

  AUTO_DELETE_LOCAL_REF(jEnv, jstring, jUriString, GetJavaInstance_String(jEnv, uri));
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jContainer, GetJavaInstance_ResultContainer(jEnv));

  // Java constructs the socket and deposits it into jContainer; the call itself returns an error code.
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
    jEnv->CallObjectMethod(mJavaInstance.GetInstance(), mCreateMethodId, jUriString, jContainer));
  TTV_ErrorCode ec = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);

  // NOTE: ownership of this local reference passes to the caller, which must release it.
  jobject jCreated = TTV_SUCCEEDED(ec) ? GetJavaInstance_GetResultFromResultContainer(jEnv, jContainer) : nullptr;
  if (jCreated != nullptr) {
    result = jCreated;
    ec = TTV_EC_SUCCESS;
  }

  return ec;

#endif
}

/**
 * Wraps a Java ISocket object; all setup is done by JavaSocketBase.
 *
 * @param jEnv      JNI environment of the calling thread.
 * @param jInstance The Java ISocket object to wrap.
 */
ttv::binding::java::JavaSocket::JavaSocket(JNIEnv* jEnv, jobject jInstance)
    : JavaSocketBase(jEnv, jInstance) {
}

/**
 * Calls the Java ISocket's connect() and logs the outcome at debug level.
 *
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SOCKET_CONNECT_FAILED when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaSocket::Connect() {
#if DONT_CALL_INTO_JAVA
  return TTV_EC_SOCKET_CONNECT_FAILED;

#else
  AutoJEnv jEnv;

  JavaClassInfo& socketClass = GetJavaClassInfo_ISocket(jEnv);

  ttv::trace::Message("Core", MessageLevel::Debug, "Calling into java to connect to socket...");
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), socketClass.methods["connect"]));

  const TTV_ErrorCode status = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);
  ttv::trace::Message(
    "Core", MessageLevel::Debug, "Done calling into java to connect to socket %s", ErrorToString(status));

  return status;

#endif
}

/**
 * Calls the Java ISocket's disconnect().
 *
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SUCCESS when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaSocket::Disconnect() {
#if DONT_CALL_INTO_JAVA
  return TTV_EC_SUCCESS;

#else
  AutoJEnv jEnv;

  JavaClassInfo& socketClass = GetJavaClassInfo_ISocket(jEnv);
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), socketClass.methods["disconnect"]));

  const TTV_ErrorCode status = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);
  return status;

#endif
}

/**
 * Copies `length` bytes from `buffer` into the cached Java byte[] and passes them to the Java ISocket's send().
 *
 * @param buffer Bytes to send.
 * @param length Number of bytes in `buffer`.
 * @param sent   Set to the byte count Java reports through the result container when the call succeeds; 0 otherwise.
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SOCKET_ENOTCONN when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaSocket::Send(const uint8_t* buffer, size_t length, size_t& sent) {
  sent = 0;

#if DONT_CALL_INTO_JAVA
  (void)buffer;
  (void)length;
  return TTV_EC_SOCKET_ENOTCONN;

#else
  AutoJEnv jEnv;

  JavaClassInfo& socketClass = GetJavaClassInfo_ISocket(jEnv);
  JavaClassInfo& integerClass = GetJavaClassInfo_Integer(jEnv);

  // Grow the reusable Java array only when it is missing or too small for this payload.
  const bool cacheTooSmall = mByteArrayInstance.GetInstance() == nullptr || mByteArraySize < length;
  if (cacheTooSmall) {
    AllocateByteArray(length);
  }

  const auto jPayload = mByteArrayInstance.GetInstanceAsByteArray();
  jEnv->SetByteArrayRegion(jPayload, 0, static_cast<jsize>(length), reinterpret_cast<const jbyte*>(buffer));

  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), socketClass.methods["send"], jPayload,
      static_cast<jint>(length), mSentReceivedResultContainer.GetInstance()));
  const TTV_ErrorCode status = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);

  if (TTV_SUCCEEDED(status)) {
    // Java boxes the sent count as an Integer inside the shared result container.
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jSentCount,
      GetJavaInstance_GetResultFromResultContainer(jEnv, mSentReceivedResultContainer.GetInstance()));
    sent = static_cast<size_t>(jEnv->CallIntMethod(jSentCount, integerClass.methods["intValue"]));
  }

  return status;

#endif
}

/**
 * Reads up to `length` bytes from the Java ISocket's recv() into `buffer`.
 *
 * @param buffer   Destination; must be able to hold `length` bytes.
 * @param length   Maximum number of bytes to read.
 * @param received Set to the number of bytes copied into `buffer`; 0 unless the call succeeds with a valid count.
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped or if Java reports a byte
 *         count outside [0, length]), or TTV_EC_SOCKET_ENOTCONN when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaSocket::Recv(uint8_t* buffer, size_t length, size_t& received) {
  TTV_ErrorCode ec = TTV_EC_SOCKET_ENOTCONN;

  received = 0;

#if !DONT_CALL_INTO_JAVA

  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_ISocket(jEnv);
  JavaClassInfo& integerInfo = GetJavaClassInfo_Integer(jEnv);

  // Make sure the cached byte array is large enough
  if (mByteArrayInstance.GetInstance() == nullptr || mByteArraySize < length) {
    AllocateByteArray(length);
  }

  // Get some bytes from Java; the byte count comes back through mSentReceivedResultContainer
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jErrorCode,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), info.methods["recv"],
      mByteArrayInstance.GetInstanceAsByteArray(), static_cast<jint>(length),
      mSentReceivedResultContainer.GetInstance()));
  ec = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jErrorCode, TTV_EC_UNKNOWN_ERROR);

  if (TTV_SUCCEEDED(ec)) {
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jReceived,
      GetJavaInstance_GetResultFromResultContainer(jEnv, mSentReceivedResultContainer.GetInstance()));
    jsize jSize = jEnv->CallIntMethod(jReceived, integerInfo.methods["intValue"]);

    // The cached Java array is only ever grown, so it can be larger than `buffer`. A count above `length`
    // would therefore overflow the caller's buffer, and a negative count would wrap to a huge size_t.
    // Do not rely on TTV_ASSERT alone (it may not be active in every build); reject the result explicitly.
    const bool validCount = jSize >= 0 && static_cast<size_t>(jSize) <= length;
    TTV_ASSERT(validCount);
    if (validCount) {
      received = static_cast<size_t>(jSize);
      jEnv->GetByteArrayRegion(mByteArrayInstance.GetInstanceAsByteArray(), 0, jSize, reinterpret_cast<jbyte*>(buffer));
    } else {
      ec = TTV_EC_UNKNOWN_ERROR;
    }
  }

#else
  (void)buffer;
  (void)length;

#endif

  return ec;
}

uint64_t ttv::binding::java::JavaSocket::TotalSent() {
#if !DONT_CALL_INTO_JAVA
  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_ISocket(jEnv);

  jint jInt = jEnv->CallIntMethod(mSocketInstance.GetInstance(), info.methods["totalSent"]);
  return static_cast<uint64_t>(jInt);

#else
  return 0;

#endif
}

uint64_t ttv::binding::java::JavaSocket::TotalReceived() {
#if !DONT_CALL_INTO_JAVA
  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_ISocket(jEnv);

  jint jInt = jEnv->CallIntMethod(mSocketInstance.GetInstance(), info.methods["totalReceived"]);
  return static_cast<uint64_t>(jInt);

#else
  return 0;

#endif
}

bool ttv::binding::java::JavaSocket::Connected() {
#if !DONT_CALL_INTO_JAVA
  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_ISocket(jEnv);

  jboolean jBoolean = jEnv->CallBooleanMethod(mSocketInstance.GetInstance(), info.methods["connected"]);
  return jBoolean != 0;

#else
  return false;

#endif
}

/**
 * Wraps a Java ISocketFactory, resolving the method IDs of its isProtocolSupported and createSocket
 * methods for the shared JavaSocketFactoryBase implementation.
 */
ttv::binding::java::JavaSocketFactory::JavaSocketFactory(JNIEnv* jEnv, jobject jInstance)
    : JavaSocketFactoryBase(jEnv,
        jInstance,
        GetJavaClassInfo_ISocketFactory(jEnv).methods["isProtocolSupported"],
        GetJavaClassInfo_ISocketFactory(jEnv).methods["createSocket"]) {
}

/**
 * Forwards to JavaSocketFactoryBase, which calls the Java factory's isProtocolSupported method.
 */
bool ttv::binding::java::JavaSocketFactory::IsProtocolSupported(const std::string& protocol) {
  const bool supported = JavaSocketFactoryBase::IsProtocolSupported(protocol);
  return supported;
}

/**
 * Creates a JavaSocket for `uri` through the Java ISocketFactory.
 *
 * @param uri    URI forwarded to the Java factory.
 * @param result Receives the new socket on success; left empty otherwise.
 * @return The success code from the base factory when a socket was created. Otherwise returns
 *         TTV_EC_UNIMPLEMENTED (any failure code is collapsed to this). Returns TTV_EC_SUCCESS with an
 *         empty result when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaSocketFactory::CreateSocket(
  const std::string& uri, std::shared_ptr<ISocket>& result) {
  result.reset();

#if DONT_CALL_INTO_JAVA
  (void)uri;
  return TTV_EC_SUCCESS;

#else
  AutoJEnv jEnv;

  jobject jSocket = nullptr;
  const TTV_ErrorCode createResult = JavaSocketFactoryBase::CreateSocket(uri, jSocket);
  // The base hands back a local reference that this function owns; delete it when leaving scope
  // (JavaSocket binds the instance in its constructor).
  AUTO_DELETE_LOCAL_REF_NO_DECLARE(jEnv, jobject, jSocket);

  if (TTV_SUCCEEDED(createResult) && jSocket != nullptr) {
    result = std::make_shared<JavaSocket>(jEnv, jSocket);
    return createResult;
  }

  return TTV_EC_UNIMPLEMENTED;

#endif
}

/**
 * Wraps a Java IWebSocket. In addition to the byte-count container created by JavaSocketBase, this sets up
 * a second result container through which Java reports the message type from recv/peek.
 *
 * @param jEnv      JNI environment of the calling thread.
 * @param jInstance The Java IWebSocket object to wrap.
 */
ttv::binding::java::JavaWebSocket::JavaWebSocket(JNIEnv* jEnv, jobject jInstance)
    : JavaSocketBase(jEnv, jInstance) {
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jTypeContainer, GetJavaInstance_ResultContainer(jEnv));
  mMessageTypeResultContainer.Bind(jEnv, jTypeContainer);
}

/**
 * Calls the Java IWebSocket's connect() and logs the outcome at debug level.
 *
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SOCKET_CONNECT_FAILED when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaWebSocket::Connect() {
#if DONT_CALL_INTO_JAVA
  return TTV_EC_SOCKET_CONNECT_FAILED;

#else
  AutoJEnv jEnv;

  JavaClassInfo& webSocketClass = GetJavaClassInfo_IWebSocket(jEnv);

  ttv::trace::Message("Core", MessageLevel::Debug, "Calling into java to connect to websocket...");
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), webSocketClass.methods["connect"]));

  const TTV_ErrorCode status = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);
  ttv::trace::Message(
    "Core", MessageLevel::Debug, "Done calling into java to connect to websocket %s", ErrorToString(status));

  return status;

#endif
}

/**
 * Calls the Java IWebSocket's disconnect().
 *
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SUCCESS when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaWebSocket::Disconnect() {
#if DONT_CALL_INTO_JAVA
  return TTV_EC_SUCCESS;

#else
  AutoJEnv jEnv;

  JavaClassInfo& webSocketClass = GetJavaClassInfo_IWebSocket(jEnv);
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), webSocketClass.methods["disconnect"]));

  const TTV_ErrorCode status = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);
  return status;

#endif
}

/**
 * Sends one websocket message of the given type through the Java IWebSocket's send().
 *
 * @param type   Message type, converted to the Java WebSocketMessageType enum.
 * @param buffer Payload bytes.
 * @param length Number of bytes in `buffer`.
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SOCKET_ENOTCONN when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaWebSocket::Send(MessageType type, const uint8_t* buffer, size_t length) {
#if DONT_CALL_INTO_JAVA
  (void)type;
  (void)buffer;
  (void)length;
  return TTV_EC_SOCKET_ENOTCONN;

#else
  AutoJEnv jEnv;

  JavaClassInfo& webSocketClass = GetJavaClassInfo_IWebSocket(jEnv);
  JavaClassInfo& messageTypeClass = GetJavaClassInfo_WebSocketMessageType(jEnv);

  // Grow the reusable Java array only when it is missing or too small for this payload.
  const bool cacheTooSmall = mByteArrayInstance.GetInstance() == nullptr || mByteArraySize < length;
  if (cacheTooSmall) {
    AllocateByteArray(length);
  }

  TTV_ErrorCode status = TTV_EC_SOCKET_ENOTCONN;
  {
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jType, GetJavaInstance_SimpleEnum(jEnv, messageTypeClass, type));

    const auto jPayload = mByteArrayInstance.GetInstanceAsByteArray();
    jEnv->SetByteArrayRegion(jPayload, 0, static_cast<jsize>(length), reinterpret_cast<const jbyte*>(buffer));

    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jStatus,
      jEnv->CallObjectMethod(mSocketInstance.GetInstance(), webSocketClass.methods["send"], jType, jPayload,
        static_cast<jint>(length)));
    status = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
      jEnv, GetJavaClassInfo_ErrorCode(jEnv), jStatus, TTV_EC_UNKNOWN_ERROR);
  }

  return status;

#endif
}

/**
 * Receives message bytes and the message type from the Java IWebSocket's recv().
 *
 * @param type     Set to the message type Java reports; MessageType::Unknown if nothing valid was received.
 * @param buffer   Destination for the payload; must be able to hold `length` bytes.
 * @param length   Capacity of `buffer` in bytes.
 * @param received Set to the number of bytes copied into `buffer`. It is 0 on failure, or when Java
 *                 reports a negative count.
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped or if Java reports more
 *         bytes than `length`), or TTV_EC_SOCKET_ENOTCONN when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaWebSocket::Recv(
  MessageType& type, uint8_t* buffer, size_t length, size_t& received) {
  received = 0;
  type = MessageType::Unknown;

  TTV_ErrorCode ec = TTV_EC_SOCKET_ENOTCONN;

#if !DONT_CALL_INTO_JAVA

  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_IWebSocket(jEnv);
  JavaClassInfo& integerInfo = GetJavaClassInfo_Integer(jEnv);

  // Make sure the cached byte array is large enough
  if (mByteArrayInstance.GetInstance() == nullptr || mByteArraySize < length) {
    AllocateByteArray(length);
  }

  // Get some bytes from Java. Java reports the message type through mMessageTypeResultContainer (read back
  // below) and the byte count through mSentReceivedResultContainer. Peek() passes the containers in the same
  // order (type first, count last).
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jErrorCode,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), info.methods["recv"],
      mMessageTypeResultContainer.GetInstance(), mByteArrayInstance.GetInstanceAsByteArray(),
      static_cast<jint>(length), mSentReceivedResultContainer.GetInstance()));
  ec = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jErrorCode, TTV_EC_UNKNOWN_ERROR);

  if (TTV_SUCCEEDED(ec)) {
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jReceived,
      GetJavaInstance_GetResultFromResultContainer(jEnv, mSentReceivedResultContainer.GetInstance()));
    jsize jSize = jEnv->CallIntMethod(jReceived, integerInfo.methods["intValue"]);
    if (jSize >= 0) {
      // The cached Java array is only ever grown, so it may be larger than `buffer`. Never copy more than
      // the caller's buffer can hold; TTV_ASSERT alone may not be active in every build.
      TTV_ASSERT(static_cast<size_t>(jSize) <= length);
      if (static_cast<size_t>(jSize) > length) {
        ec = TTV_EC_UNKNOWN_ERROR;
      } else {
        received = static_cast<size_t>(jSize);

        AUTO_DELETE_LOCAL_REF(jEnv, jobject, jMessageType,
          GetJavaInstance_GetResultFromResultContainer(jEnv, mMessageTypeResultContainer.GetInstance()));
        type = GetNativeFromJava_SimpleEnum<MessageType>(
          jEnv, GetJavaClassInfo_WebSocketMessageType(jEnv), jMessageType, MessageType::Unknown);

        jEnv->GetByteArrayRegion(mByteArrayInstance.GetInstanceAsByteArray(), 0, jSize, reinterpret_cast<jbyte*>(buffer));
      }
    }
  }

#else
  (void)buffer;
  (void)length;

#endif

  return ec;
}

/**
 * Queries the Java IWebSocket's peek() for the type and size of the next message.
 *
 * @param type   Set to the reported message type. It is MessageType::None if the call fails, and
 *               MessageType::Unknown if the type cannot be mapped.
 * @param length Set to the reported message size in bytes. It is 0 on failure, or when Java reports a
 *               negative size.
 * @return The error code returned by Java (TTV_EC_UNKNOWN_ERROR if it cannot be mapped), or
 *         TTV_EC_SOCKET_ENOTCONN when Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaWebSocket::Peek(MessageType& type, size_t& length) {
  length = 0;
  type = MessageType::None;

  TTV_ErrorCode ec = TTV_EC_SOCKET_ENOTCONN;

#if !DONT_CALL_INTO_JAVA

  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_IWebSocket(jEnv);
  JavaClassInfo& integerInfo = GetJavaClassInfo_Integer(jEnv);

  // Ask Java for the next message's type and size; both come back through the result containers.
  // The type container is a plain Java object, so pass it with GetInstance(), not as a byte array.
  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jErrorCode,
    jEnv->CallObjectMethod(mSocketInstance.GetInstance(), info.methods["peek"],
      mMessageTypeResultContainer.GetInstance(), mSentReceivedResultContainer.GetInstance()));
  ec = GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jErrorCode, TTV_EC_UNKNOWN_ERROR);

  if (TTV_SUCCEEDED(ec)) {
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jMessageType,
      GetJavaInstance_GetResultFromResultContainer(jEnv, mMessageTypeResultContainer.GetInstance()));
    type = GetNativeFromJava_SimpleEnum<MessageType>(
      jEnv, GetJavaClassInfo_WebSocketMessageType(jEnv), jMessageType, MessageType::Unknown);

    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jReceived,
      GetJavaInstance_GetResultFromResultContainer(jEnv, mSentReceivedResultContainer.GetInstance()));
    jsize jSize = jEnv->CallIntMethod(jReceived, integerInfo.methods["intValue"]);

    // A negative size would wrap to a huge size_t; leave `length` at 0 instead (as Recv() does for counts).
    if (jSize >= 0) {
      length = static_cast<size_t>(jSize);
    }
  }

#endif

  return ec;
}

bool ttv::binding::java::JavaWebSocket::Connected() {
#if !DONT_CALL_INTO_JAVA

  AutoJEnv jEnv;

  JavaClassInfo& info = GetJavaClassInfo_IWebSocket(jEnv);

  jboolean jBoolean = jEnv->CallBooleanMethod(mSocketInstance.GetInstance(), info.methods["connected"]);
  return jBoolean != 0;

#else
  return false;

#endif
}

/**
 * Wraps a Java IWebSocketFactory, resolving the method IDs of its isProtocolSupported and createWebSocket
 * methods for the shared JavaSocketFactoryBase implementation.
 */
ttv::binding::java::JavaWebSocketFactory::JavaWebSocketFactory(JNIEnv* jEnv, jobject jInstance)
    : JavaSocketFactoryBase(jEnv,
        jInstance,
        GetJavaClassInfo_IWebSocketFactory(jEnv).methods["isProtocolSupported"],
        GetJavaClassInfo_IWebSocketFactory(jEnv).methods["createWebSocket"]) {
}

/**
 * Forwards to JavaSocketFactoryBase, which calls the Java factory's isProtocolSupported method.
 */
bool ttv::binding::java::JavaWebSocketFactory::IsProtocolSupported(const std::string& protocol) {
  const bool supported = JavaSocketFactoryBase::IsProtocolSupported(protocol);
  return supported;
}

/**
 * Creates a JavaWebSocket for `uri` through the Java IWebSocketFactory.
 *
 * @param uri    URI forwarded to the Java factory.
 * @param result Receives the new websocket on success; left empty otherwise.
 * @return TTV_EC_SUCCESS with a non-empty `result`, or the failure code from the base factory if creation
 *         failed. Returns TTV_EC_UNIMPLEMENTED if Java reported success without producing a socket, or when
 *         Java calls are compiled out.
 */
TTV_ErrorCode ttv::binding::java::JavaWebSocketFactory::CreateWebSocket(
  const std::string& uri, std::shared_ptr<IWebSocket>& result) {
  TTV_ErrorCode ec = TTV_EC_UNIMPLEMENTED;

  result.reset();

#if !DONT_CALL_INTO_JAVA

  AutoJEnv jEnv;

  jobject jSocket = nullptr;
  ec = JavaSocketFactoryBase::CreateSocket(uri, jSocket);
  // The base hands back a local reference that this function owns; delete it when leaving scope.
  AUTO_DELETE_LOCAL_REF_NO_DECLARE(jEnv, jobject, jSocket);

  if (TTV_SUCCEEDED(ec) && jSocket != nullptr) {
    result = std::make_shared<JavaWebSocket>(jEnv, jSocket);
    ec = TTV_EC_SUCCESS;
  } else if (TTV_SUCCEEDED(ec)) {
    // Java reported success but produced no socket. Never report success with an empty result;
    // JavaSocketFactory::CreateSocket maps this case to TTV_EC_UNIMPLEMENTED as well.
    ec = TTV_EC_UNIMPLEMENTED;
  }

#else
  (void)uri;

#endif

  return ec;
}
