/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/core/generated/jni_all.h"
#include "twitchsdk/core/httprequest.h"
#include "twitchsdk/core/java_classinfo.h"
#include "twitchsdk/core/mutex.h"
#include "twitchsdk/core/result.h"
#include "twitchsdk/core/types/errortypes.h"

#include <jni.h>
#include <unordered_map>

#include <map>
#include <vector>

// Early-return helpers for JNI entry points that report failures to Java as an ErrorCode object. Each expands to a
// braced block that, when its condition holds, returns ::ttv::binding::java::GetJavaInstance_ErrorCode(jni, err)
// from the enclosing function (so that function must return something constructible from a jobject).

// Returns the ErrorCode for `err` (after ASSERT_ON_ERROR(err)) when `ptr` is null.
#define TTV_JNI_RETURN_ON_NULL(jni, ptr, err)                           \
  {                                                                     \
    if ((ptr) == nullptr) {                                             \
      ASSERT_ON_ERROR(err);                                             \
      return ::ttv::binding::java::GetJavaInstance_ErrorCode(jni, err); \
    }                                                                   \
  }
// Returns the ErrorCode for `err` (after ASSERT_ON_ERROR(err)) when `ptr` is not null.
#define TTV_JNI_RETURN_ON_NOT_NULL(jni, ptr, err)                       \
  {                                                                     \
    if ((ptr) != nullptr) {                                             \
      ASSERT_ON_ERROR(err);                                             \
      return ::ttv::binding::java::GetJavaInstance_ErrorCode(jni, err); \
    }                                                                   \
  }
// Returns the ErrorCode for `err` when `val` is false. Unlike the two macros above, this does not assert.
// `val` is parenthesized so that expressions such as `a == b` are negated as a whole.
#define TTV_JNI_RETURN_ON_FALSE(jni, val, err)                          \
  {                                                                     \
    if (!(val)) {                                                       \
      return ::ttv::binding::java::GetJavaInstance_ErrorCode(jni, err); \
    }                                                                   \
  }

// Used to release local references when the variable goes out of scope.
// Declares `TYPE VAR = static_cast<TYPE>(VALUE);` followed by a JavaLocalReferenceDeleter named `VAR_Deleter`.
// The deleter receives VAR's value at this point (it takes the jobject by value), so reassigning VAR afterwards does
// not change which reference is released. Expands to two declarations, so it cannot be the body of an unbraced if.
#define AUTO_DELETE_LOCAL_REF(JENV, TYPE, VAR, VALUE) \
  TYPE VAR = static_cast<TYPE>(VALUE);                \
  ::ttv::binding::java::JavaLocalReferenceDeleter VAR##_Deleter(JENV, VAR, #VAR);

// Same as AUTO_DELETE_LOCAL_REF for an already-declared variable VAR: only the `VAR_Deleter` is declared.
// TYPE is accepted for symmetry with AUTO_DELETE_LOCAL_REF but is not used by the expansion.
#define AUTO_DELETE_LOCAL_REF_NO_DECLARE(JENV, TYPE, VAR) \
  ::ttv::binding::java::JavaLocalReferenceDeleter VAR##_Deleter(JENV, VAR, #VAR);

namespace ttv {
namespace binding {
namespace java {
extern JavaVM* gGlobalJavaVirtualMachine;  //!< The Java virtual machine
//! JNIEnv of the current call from Java into native code. It is cached on every call into native code so that it's
//! current. A JNIEnv is only valid on the thread it belongs to: never use this from another thread.
extern JNIEnv* gActiveJavaEnvironment;

class JavaLocalReferenceDeleter;
class ScopedJavaUTFStringConverter;
class ScopedJavaWcharStringConverter;
class ScopedJavaEnvironmentCacher;
class AutoJEnv;
class GlobalJavaObjectReference;

template <typename PROXY_TYPE>
struct ProxyContext;

template <typename PROXY_TYPE, typename LISTENER_TYPE>
struct ProxyContextWithListener;

template <typename NativeInstanceType, typename ContextType>
class JavaNativeProxyRegistry;

bool CacheJavaVirtualMachine(JNIEnv* jEnv);

void JniThreadInitialize();

JavaClassInfo& GetJavaClassInfo_Boolean(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_Integer(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_Long(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_Float(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_Double(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_String(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_Charset(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_HashSet(JNIEnv* jEnv);
JavaClassInfo& GetJavaClassInfo_HashMap(JNIEnv* jEnv);

jobject GetJavaInstance_Boolean(JNIEnv* jEnv, bool value);
jobject GetJavaInstance_Integer(JNIEnv* jEnv, int32_t value);
jobject GetJavaInstance_Integer(JNIEnv* jEnv, uint32_t value);
jobject GetJavaInstance_Long(JNIEnv* jEnv, uint64_t value);
jobject GetJavaInstance_Long(JNIEnv* jEnv, int64_t value);
jobject GetJavaInstance_Float(JNIEnv* jEnv, float value);
jobject GetJavaInstance_Double(JNIEnv* jEnv, double value);

jstring GetJavaInstance_StringWithEncoding(JNIEnv* jEnv, const std::string& str);
jstring GetJavaInstance_String(JNIEnv* jEnv, const char* str);
jstring GetJavaInstance_String(JNIEnv* jEnv, const std::string& str);

jobject GetJavaInstance_ErrorCode(JNIEnv* jEnv, TTV_ErrorCode err);
jobject GetJavaInstance_EnumValue(JNIEnv* jEnv, const ttv::EnumValue& value);
jobjectArray GetJavaInstance_EnumValueArray(JNIEnv* jEnv, const std::vector<ttv::EnumValue>& arr);
jobjectArray GetJavaInstance_StringArray(JNIEnv* jEnv, const std::vector<std::string>& arr);
jobject GetJavaInstance_StringHashMap(JNIEnv* jEnv, const std::map<std::string, std::string>& map);
jobject GetJavaInstance_ResultContainer(JNIEnv* jEnv);
jobject GetJavaInstance_GetResultFromResultContainer(JNIEnv* jEnv, jobject jResultContainer);

jobject GetJavaInstance_TaskId(JNIEnv* jEnv, TaskId taskId);

void GetNativeInstance_HttpRequestResult(JNIEnv* jEnv, jobject jRequestResult, uint& statusCode,
  std::map<std::string, std::string>& resultHeaders, std::vector<char>& response);
jobject GetJavaInstance_HttpRequestResult(JNIEnv* jEnv);
jobject GetJavaInstance_HttpParameter(JNIEnv* jEnv, const ttv::HttpParam& param);
jobject GetJavaInstance_HttpParameterArray(JNIEnv* jEnv, const std::vector<ttv::HttpParam>& params);

void GetNativeInstance_StringVector(JNIEnv* jEnv, jobjectArray jArray, std::vector<std::string>& result);

void SetResultContainerResult(JNIEnv* jEnv, jobject jResultContainer, jobject jResult);

jobjectArray GetJavaInstance_Array(JNIEnv* jEnv, JavaClassInfo& javaArrayTypeClassInfo, const uint32_t size,
  std::function<jobject(uint32_t index)> entryFunc);

void GetNativeFromJava_ByteArray(JNIEnv* jEnv, jbyteArray jSource, std::vector<uint8_t>& dest);

/**
 * Loads all Java class information required for the base utility classes.
 */
void LoadAllUtilityJavaClassInfo(JNIEnv* jEnv);

/**
 * Wraps a Java callback object in a std::function.
 *
 * The returned function holds a global reference to jCallback. When called, it invokes the method registered as
 * "invoke" in callbackInfo and passes its arguments straight through to JNI CallVoidMethod. callbackInfo is captured
 * by reference and must outlive the returned function.
 */
template <typename... ArgTypes>
std::function<void(ArgTypes...)> CreateJavaCallbackWrapper(
  JNIEnv* jEnv, jobject jCallback, JavaClassInfo& callbackInfo);

/**
 * Creates a Java HashMap (class info from GetJavaClassInfo_HashMap) populated from a native container.
 *
 * @param transformer Called with each element of container; returns a pair-like value whose `first` and `second` are
 *                    local references to the Java key and value. Both references are deleted after insertion.
 */
template <typename ContainerType, typename KVTransformerType>
jobject GetJavaInstance_HashMap(JNIEnv* jEnv, const ContainerType& container, const KVTransformerType& transformer);

/**
 * Converts a native enum value to a Java object by calling the static "lookupValue" method of `info` with the value
 * cast to jint.
 */
template <typename T>
jobject GetJavaInstance_SimpleEnum(JNIEnv* jEnv, JavaClassInfo& info, T val);

/**
 * Converts a Java enum object to a native value by calling its "getValue" method. Returns defaultValue when
 * jEnumValue is null.
 */
template <typename T>
T GetNativeFromJava_SimpleEnum(JNIEnv* jEnv, JavaClassInfo& info, jobject jEnumValue, T defaultValue);

/**
 * Creates a java result object from a C++ result object.
 *
 * @param result The C++ result object to be converted.
 * @param converter A callable object that takes a `const ResultType&` and returns the corresponding jobject.
 */
template <typename ResultType, typename ObjectConverter>
jobject GetJavaInstance_Result(JNIEnv* jEnv, const Result<ResultType>& result, ObjectConverter&& converter);

jobject GetJavaInstance_SuccessResult(JNIEnv* jEnv, jobject jResultValue);
jobject GetJavaInstance_ErrorResult(JNIEnv* jEnv, TTV_ErrorCode ec);

/**
 * Creates a C++ result object from a java result object.
 *
 * @param jResult The java result object to be converted.
 * @param converter A callable object that takes a jobject and returns a `ResultType`
 */
template <typename ResultType, typename ObjectConverter>
Result<ResultType> GetNativeInstance_Result(JNIEnv* jEnv, jobject jResult, ObjectConverter&& converter);
}  // namespace java
}  // namespace binding
}  // namespace ttv

/**
 * Releases a JNI local reference when it goes out of scope.
 *
 * Normally declared through AUTO_DELETE_LOCAL_REF / AUTO_DELETE_LOCAL_REF_NO_DECLARE, which pass the stringized
 * variable name as `name`. The jobject is captured by value at construction.
 *
 * Non-copyable: a copy would release the same local reference a second time.
 */
class ttv::binding::java::JavaLocalReferenceDeleter {
 public:
  JavaLocalReferenceDeleter(JNIEnv* jEnv, jobject jObject, const char* name);
  ~JavaLocalReferenceDeleter();

  JavaLocalReferenceDeleter(const JavaLocalReferenceDeleter&) = delete;
  JavaLocalReferenceDeleter& operator=(const JavaLocalReferenceDeleter&) = delete;

 protected:
  JNIEnv* mjEnv;
  jobject mObject;
  const char* mName;
};

/**
 * Provides native access to the contents of a Java string for the lifetime of this object.
 *
 * The text is available through GetNativeString() and through the implicit conversions to `const char*` and
 * `std::string`. GetCharacterLength() and GetByteLength() return the lengths recorded by the implementation.
 *
 * Non-copyable: the destructor presumably releases the native buffer obtained from the Java string (confirm in the
 * implementation), so a copy would release it twice and leave the other instance with a dangling pointer.
 */
class ttv::binding::java::ScopedJavaUTFStringConverter {
 public:
  ScopedJavaUTFStringConverter(JNIEnv* jEnv, jstring jstr);
  ~ScopedJavaUTFStringConverter();

  ScopedJavaUTFStringConverter(const ScopedJavaUTFStringConverter&) = delete;
  ScopedJavaUTFStringConverter& operator=(const ScopedJavaUTFStringConverter&) = delete;

  const char* GetNativeString();
  int GetCharacterLength() const { return mCharacterLength; }
  int GetByteLength() const { return mByteLength; }

  operator const char*() { return GetNativeString(); }
  operator std::string() { return GetNativeString(); }

 protected:
  JNIEnv* mjEnv;
  jstring mJavaString;
  const char* mNativeString;
  int mCharacterLength;
  int mByteLength;
};

/**
 * Provides the contents of a Java string as a native wide (wchar_t) string.
 *
 * Unlike ScopedJavaUTFStringConverter, no destructor is declared. The class has a std::wstring member
 * (mSTDWideString), presumably holding the converted text.
 *
 * NOTE(review): if mNativeString points into mSTDWideString, a copied instance would point into the source object's
 * buffer and dangle once the source is destroyed. Confirm against the implementation before copying instances.
 */
class ttv::binding::java::ScopedJavaWcharStringConverter {
 public:
  ScopedJavaWcharStringConverter(JNIEnv* jEnv, jstring jstr);

  const wchar_t* GetNativeString() const;

 protected:
  JNIEnv* mjEnv;
  jstring mJavaString;
  const wchar_t* mNativeString;
  std::wstring mSTDWideString;
};

/**
 * Scope guard meant to be created on entry into native code from Java.
 *
 * NOTE(review): presumably the constructor stores `jenv` into gActiveJavaEnvironment, and the static mCacheCount
 * tracks nested scopes so the destructor can restore/clear it. Confirm in the implementation. Given that, instances
 * should not be copied (a copy would run the destructor's bookkeeping twice). The single-argument constructor is
 * also not `explicit`.
 */
class ttv::binding::java::ScopedJavaEnvironmentCacher {
 public:
  ScopedJavaEnvironmentCacher(JNIEnv* jenv);
  ~ScopedJavaEnvironmentCacher();

  // NOTE(review): semantics are defined in the implementation; the name suggests the minimum JNI local reference
  // table capacity to request.
  static int GetMinLocalReferenceTableCapacity();

 protected:
  static int mCacheCount;
};

/**
 * Manages automatically locking and unlocking a JavaVM* instance to obtain a JNIEnv* instance.
 *
 * Typically used as a local, e.g. `AutoJEnv jEnv; jEnv->IsSameObject(a, b);`. The object converts implicitly to
 * JNIEnv* and forwards operator-> to it.
 *
 * NOTE(review): the default constructor presumably uses gGlobalJavaVirtualMachine. mNeedsDetach suggests the
 * destructor detaches the current thread only if Lock() had to attach it. If so, copying an AutoJEnv could detach
 * the thread twice. Confirm before copying instances, and consider deleting the copy operations.
 */
class ttv::binding::java::AutoJEnv {
 public:
  AutoJEnv();
  AutoJEnv(JavaVM* jvm);
  ~AutoJEnv();

  JavaVM* GetJvm() { return mJvm; }

  JNIEnv* operator->();
  operator JNIEnv*();

 private:
  bool Lock();
  void Unlock();

  JavaVM* mJvm;
  JNIEnv* mJEnv;
  bool mNeedsDetach;  // Presumably true when this object attached the thread and must detach it on Unlock().
};

/**
 * Holds a JNI global reference to a Java object so that it remains valid across JNI calls and threads.
 *
 * NOTE(review): the destructor is expected to release the global reference (confirm in the implementation). No copy
 * operations are declared, so a copy would share mObject with the original and could release it twice. Within this
 * file instances are only held in place (JavaNativeProxyRegistry::Entry, itself behind a std::shared_ptr) or via
 * std::make_shared (CreateJavaCallbackWrapper).
 */
class ttv::binding::java::GlobalJavaObjectReference {
 public:
  GlobalJavaObjectReference();
  virtual ~GlobalJavaObjectReference();

  /**
   * Captures the given object instance and creates a global reference.
   */
  bool Bind(JNIEnv* jEnv, jobject jLocalObject);
  // Release the global reference; the overload without a JNIEnv presumably obtains one itself.
  void Release(JNIEnv* jEnv);
  void Release();

  // The held global reference; may be nullptr (callers in this file check for that before use).
  jobject GetInstance() const { return mObject; }
  // Same reference cast to jbyteArray; only meaningful when the bound object is a Java byte[].
  jbyteArray GetInstanceAsByteArray() const { return static_cast<jbyteArray>(mObject); }

 protected:
  jobject mObject;
};

/**
 * Native context for a Java proxy object: shares ownership of the proxied native instance.
 */
template <typename PROXY_TYPE>
struct ttv::binding::java::ProxyContext {
  std::shared_ptr<PROXY_TYPE> instance;
};

/**
 * Native context for a Java proxy object that also needs a native listener: shares ownership of both the proxied
 * native instance and the listener object.
 */
template <typename PROXY_TYPE, typename LISTENER_TYPE>
struct ttv::binding::java::ProxyContextWithListener {
  std::shared_ptr<PROXY_TYPE> instance;
  std::shared_ptr<LISTENER_TYPE> nativeListener;
};

/**
 * Manages a mapping of a native pointer to a Java object reference.
 *
 * Each entry shares ownership of a native instance and its context, and holds a JNI global reference to the Java
 * object that proxies them. Entries stay registered until an Unregister() overload removes them.
 *
 * Access to the entry list is serialized by a ttv::IMutex created by the first Register(). Until then, every lookup
 * returns nullptr and Unregister() does nothing.
 */
template <typename NativeInstanceType, typename ContextType>
class ttv::binding::java::JavaNativeProxyRegistry {
 public:
  /**
   * Adds an entry for the given native instance, native context and Java object.
   *
   * A global reference to javaInstance is created for the entry, so a local reference may be passed.
   */
  void Register(const std::shared_ptr<NativeInstanceType>& nativeInstance,
    const std::shared_ptr<ContextType>& nativeContext, jobject javaInstance) {
    // NOTE(review): this lazy creation is not synchronized. Concurrent first calls to Register(), or a lookup racing
    // with the first Register(), race on mMutex itself.
    if (mMutex == nullptr) {
      TTV_ErrorCode ret = CreateMutex(mMutex, "JavaNativeProxyRegistry");
      ASSERT_ON_ERROR(ret);
    }

    AutoJEnv jEnv;

    auto entry = std::make_shared<Entry>();
    entry->nativeInstance = nativeInstance;
    entry->nativeContext = nativeContext;
    entry->javaInstance.Bind(jEnv, javaInstance);

    ttv::AutoMutex mutex(mMutex.get());
    mRegistry.push_back(entry);
  }

  /**
   * Removes the entry whose native instance lives at the address carried in jNativePointer.
   */
  void Unregister(jlong jNativePointer) { Unregister(reinterpret_cast<NativeInstanceType*>(jNativePointer)); }

  /**
   * Removes the first entry whose native instance is nativeInstance; does nothing if there is none.
   */
  void Unregister(NativeInstanceType* nativeInstance) {
    if (mMutex == nullptr) {
      return;
    }

    // Declared before the lock so it is destroyed after the lock is released. The removed entry may hold the last
    // reference to the native instance/context and owns the Java global reference. Destroying it after unlock keeps
    // those destructors from running while the registry is locked.
    std::shared_ptr<Entry> removed;

    ttv::AutoMutex mutex(mMutex.get());

    auto iter = FindByNativeInstance(nativeInstance);
    if (iter != mRegistry.end()) {
      removed = std::move(*iter);
      mRegistry.erase(iter);
    }
  }

  /**
   * Removes the first entry whose Java object is the same object as javaInstance; does nothing if there is none.
   */
  void Unregister(jobject javaInstance) {
    if (mMutex == nullptr) {
      return;
    }

    AutoJEnv jEnv;

    // As in Unregister(NativeInstanceType*): the removed entry is destroyed only after the lock is released (but
    // while jEnv is still alive).
    std::shared_ptr<Entry> removed;

    ttv::AutoMutex mutex(mMutex.get());

    auto iter = FindByJavaInstance(jEnv, javaInstance);
    if (iter != mRegistry.end()) {
      removed = std::move(*iter);
      mRegistry.erase(iter);
    }
  }

  /**
   * From the java instance find the corresponding native instance and returns it.
   *
   * @return The registered native instance, or nullptr if javaInstance is not registered.
   */
  std::shared_ptr<NativeInstanceType> LookupNativeInstance(jobject javaInstance) {
    if (mMutex == nullptr) {
      return nullptr;
    }

    AutoJEnv jEnv;
    ttv::AutoMutex mutex(mMutex.get());

    auto iter = FindByJavaInstance(jEnv, javaInstance);
    if (iter == mRegistry.end()) {
      return nullptr;
    }
    return (*iter)->nativeInstance;
  }

  /**
   * Returns the native context registered for the given Java object, or nullptr if it is not registered.
   */
  std::shared_ptr<ContextType> LookupNativeContext(jobject javaInstance) {
    if (mMutex == nullptr) {
      return nullptr;
    }

    AutoJEnv jEnv;
    ttv::AutoMutex mutex(mMutex.get());

    auto iter = FindByJavaInstance(jEnv, javaInstance);
    if (iter == mRegistry.end()) {
      return nullptr;
    }
    return (*iter)->nativeContext;
  }

  /**
   * Returns the native context registered for the native instance at the address carried in nativeInstance, or
   * nullptr if it is not registered.
   */
  std::shared_ptr<ContextType> LookupNativeContext(jlong nativeInstance) {
    if (mMutex == nullptr) {
      return nullptr;
    }

    // Pure pointer comparison: no JNIEnv is needed, so none is acquired.
    ttv::AutoMutex mutex(mMutex.get());

    auto iter = FindByNativeInstance(reinterpret_cast<NativeInstanceType*>(nativeInstance));
    if (iter == mRegistry.end()) {
      return nullptr;
    }
    return (*iter)->nativeContext;
  }

 private:
  struct Entry {
    std::shared_ptr<NativeInstanceType> nativeInstance;
    std::shared_ptr<ContextType> nativeContext;
    GlobalJavaObjectReference javaInstance;
  };

  using EntryIterator = typename std::vector<std::shared_ptr<Entry>>::iterator;

  // Returns the first entry whose Java object is the same object as javaInstance, or mRegistry.end().
  // The caller must hold mMutex.
  EntryIterator FindByJavaInstance(JNIEnv* jEnv, jobject javaInstance) {
    for (auto iter = mRegistry.begin(); iter != mRegistry.end(); ++iter) {
      if (jEnv->IsSameObject(javaInstance, (*iter)->javaInstance.GetInstance())) {
        return iter;
      }
    }
    return mRegistry.end();
  }

  // Returns the first entry whose native instance is nativeInstance, or mRegistry.end(). The caller must hold mMutex.
  EntryIterator FindByNativeInstance(const NativeInstanceType* nativeInstance) {
    for (auto iter = mRegistry.begin(); iter != mRegistry.end(); ++iter) {
      if ((*iter)->nativeInstance.get() == nativeInstance) {
        return iter;
      }
    }
    return mRegistry.end();
  }

  std::vector<std::shared_ptr<Entry>> mRegistry;
  std::unique_ptr<ttv::IMutex> mMutex;
};

/**
 * Wraps a Java callback object in a std::function that invokes its "invoke" method.
 *
 * A global reference to jCallback is created and shared by all copies of the returned function, so the callback
 * stays valid after the current JNI call returns. The arguments are passed straight through to JNI CallVoidMethod,
 * so each of ArgTypes must be a JNI-compatible type matching the Java method signature.
 */
template <typename... ArgTypes>
std::function<void(ArgTypes...)> ttv::binding::java::CreateJavaCallbackWrapper(
  JNIEnv* jEnv, jobject jCallback, JavaClassInfo& callbackInfo) {
  auto callbackReference = std::make_shared<GlobalJavaObjectReference>();
  callbackReference->Bind(jEnv, jCallback);

  // Capturing callbackInfo by reference here is fine because it is global
  // and immutable.
  return [callbackReference, &callbackInfo](ArgTypes... args) {
    jobject callback = callbackReference->GetInstance();
    if (callback == nullptr) {
      return;
    }

    // The returned function can be stored and invoked later, possibly on another thread. gActiveJavaEnvironment is
    // only valid on the thread that cached it, so obtain the calling thread's JNIEnv instead.
    AutoJEnv autoEnv;
    JNIEnv* callingEnv = autoEnv;
    if (callingEnv == nullptr) {
      return;
    }

    callingEnv->CallVoidMethod(callback, callbackInfo.methods["invoke"], args...);
  };
}

/**
 * Creates a Java HashMap populated from a native container.
 *
 * @param container Any iterable container; each element is passed to `transformer`.
 * @param transformer Callable taking a container element and returning a pair-like value whose `first` and `second`
 *                    are local references to the Java key and value. Both references are deleted after insertion.
 * @return A local reference to the new HashMap, or nullptr if the HashMap could not be constructed.
 */
template <typename ContainerType, typename KVTransformerType>
jobject ttv::binding::java::GetJavaInstance_HashMap(
  JNIEnv* jEnv, const ContainerType& container, const KVTransformerType& transformer) {
  JavaClassInfo& hashMapInfo = GetJavaClassInfo_HashMap(jEnv);
  jobject jHashMap = jEnv->NewObject(hashMapInfo.klass, hashMapInfo.methods["<init>"]);
  if (jHashMap == nullptr) {
    // NewObject failed and a Java exception is pending; making further JNI calls on the null map would be invalid.
    return nullptr;
  }

  // Resolve the method once instead of on every iteration.
  const auto putMethod = hashMapInfo.methods["put"];

  for (const auto& pair : container) {
    auto jPair = transformer(pair);
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jKey, jPair.first);
    AUTO_DELETE_LOCAL_REF(jEnv, jobject, jValue, jPair.second);

    // HashMap.put returns the previous value for the key (a new local reference, or null). Release it so that large
    // containers with repeated keys do not fill up the local reference table.
    AUTO_DELETE_LOCAL_REF(
      jEnv, jobject, jPreviousValue, jEnv->CallObjectMethod(jHashMap, putMethod, jKey, jValue));
  }

  return jHashMap;
}

/**
 * Maps a native enum value onto its Java counterpart.
 *
 * The value is widened to jint and handed to the Java class's static "lookupValue" method, whose result is returned.
 */
template <typename T>
jobject ttv::binding::java::GetJavaInstance_SimpleEnum(JNIEnv* jEnv, JavaClassInfo& info, T val) {
  const jint jNativeValue = static_cast<jint>(val);
  const auto lookupMethod = info.staticMethods["lookupValue"];
  return jEnv->CallStaticObjectMethod(info.klass, lookupMethod, jNativeValue);
}

/**
 * Maps a Java enum object back onto a native enum value.
 *
 * A null Java reference yields `defaultValue`; otherwise the object's "getValue" method supplies the integer that is
 * cast to T.
 */
template <typename T>
T ttv::binding::java::GetNativeFromJava_SimpleEnum(
  JNIEnv* jEnv, JavaClassInfo& info, jobject jEnumValue, T defaultValue) {
  if (jEnumValue != nullptr) {
    const jint jRawValue = jEnv->CallIntMethod(jEnumValue, info.methods["getValue"]);
    return static_cast<T>(jRawValue);
  }

  return defaultValue;
}

/**
 * Converts a native Result into its Java representation.
 *
 * Failures carry only the error code across. Successes are converted with `converter`, wrapped into a success
 * result, and the intermediate local reference is released when this function returns.
 */
template <typename ResultType, typename ObjectConverter>
jobject ttv::binding::java::GetJavaInstance_Result(
  JNIEnv* jEnv, const Result<ResultType>& result, ObjectConverter&& converter) {
  if (!result.IsSuccess()) {
    return GetJavaInstance_ErrorResult(jEnv, result.GetErrorCode());
  }

  AUTO_DELETE_LOCAL_REF(jEnv, jobject, jPayload, converter(result.GetResult()));
  return GetJavaInstance_SuccessResult(jEnv, jPayload);
}

/**
 * Converts a Java result object into a native Result.
 *
 * @param jResult The Java result; a null reference yields an error result with TTV_EC_UNKNOWN_ERROR.
 * @param converter Callable taking the jobject returned by "getResult" and producing the success value.
 * @return A success result built from converter's output when "isSuccess" returns true. Otherwise an error result
 *         with the code from "getErrorCode" (TTV_EC_UNKNOWN_ERROR if that code is null).
 */
template <typename ResultType, typename ObjectConverter>
ttv::Result<ResultType> ttv::binding::java::GetNativeInstance_Result(
  JNIEnv* jEnv, jobject jResult, ObjectConverter&& converter) {
  if (jResult == nullptr) {
    // Calling a method on a null object reference is invalid in JNI (typically a crash); report a failure instead.
    return ttv::MakeErrorResult(TTV_EC_UNKNOWN_ERROR);
  }

  JavaClassInfo& resultInfo = GetJavaClassInfo_Result(jEnv);
  const jboolean jIsSuccess = jEnv->CallBooleanMethod(jResult, resultInfo.methods["isSuccess"]);
  if (jIsSuccess == JNI_TRUE) {
    AUTO_DELETE_LOCAL_REF(
      jEnv, jobject, jResultValue, jEnv->CallObjectMethod(jResult, resultInfo.methods["getResult"]));
    return ttv::MakeSuccessResult(converter(jResultValue));
  }

  AUTO_DELETE_LOCAL_REF(
    jEnv, jobject, jErrorCode, jEnv->CallObjectMethod(jResult, resultInfo.methods["getErrorCode"]));

  return ttv::MakeErrorResult(GetNativeFromJava_SimpleEnum<TTV_ErrorCode>(
    jEnv, GetJavaClassInfo_ErrorCode(jEnv), jErrorCode, TTV_EC_UNKNOWN_ERROR));
}
