/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/core/java_utility.h"
#include "twitchsdk/core/socket.h"

#include <jni.h>

namespace ttv {
namespace binding {
namespace java {
// Socket adapters: a shared base plus the plain-socket and WebSocket flavours.
class JavaSocketBase;
class JavaSocket;
class JavaWebSocket;

// Factory adapters: a shared base plus the plain-socket and WebSocket flavours.
class JavaSocketFactoryBase;
class JavaSocketFactory;
class JavaWebSocketFactory;
}  // namespace java
}  // namespace binding
}  // namespace ttv

/**
 * State shared by the Java-backed socket adapters (JavaSocket and JavaWebSocket).
 *
 * Holds global references to the wrapped Java socket implementation and to the
 * Java-side helper objects the adapters use when exchanging data with it.
 */
class ttv::binding::java::JavaSocketBase {
 public:
  virtual ~JavaSocketBase() = default;

  /**
   * Returns the global reference to the Java object that implements the socket.
   */
  GlobalJavaObjectReference& GetJavaInstance() { return mSocketInstance; }

 protected:
  /**
   * @param jEnv The JNI environment.
   * @param jInstance The Java socket implementation to wrap.
   */
  JavaSocketBase(JNIEnv* jEnv, jobject jInstance);

  /**
   * Manages the scratch byte array for a request of `size` bytes.
   * NOTE(review): presumably (re)allocates mByteArrayInstance and updates
   * mByteArraySize; confirm the growth policy against the implementation.
   */
  void AllocateByteArray(size_t size);

  GlobalJavaObjectReference mSocketInstance;  //!< The java implementation of the socket.
  // NOTE(review): presumably a Java-side holder used to pass sent/received byte
  // counts back from Send/Recv; confirm against the implementation.
  GlobalJavaObjectReference mSentReceivedResultContainer;
  GlobalJavaObjectReference mByteArrayInstance;  //!< The scratch byte array.
  // Default-initialized so the size is well defined (no scratch array yet) even
  // if a constructor does not set it explicitly; reading an indeterminate size_t is UB.
  size_t mByteArraySize = 0;  //!< The size of mByteArrayInstance.
};

/**
 * Shared implementation for the Java-backed socket factories.
 *
 * Wraps a Java factory object together with the two JNI method IDs supplied at
 * construction: one for the protocol-support query and one for socket creation.
 */
class ttv::binding::java::JavaSocketFactoryBase {
 public:
  /**
   * Returns the global reference to the wrapped Java factory object.
   */
  GlobalJavaObjectReference& GetJavaInstance() {
    return mJavaInstance;
  }

 protected:
  /**
   * @param jEnv The JNI environment.
   * @param jInstance The Java factory implementation to wrap.
   * @param isProtocolSupportedId JNI method ID stored for protocol-support queries.
   * @param createMethodId JNI method ID stored for socket creation.
   */
  JavaSocketFactoryBase(JNIEnv* jEnv, jobject jInstance, jmethodID isProtocolSupportedId, jmethodID createMethodId);

  virtual ~JavaSocketFactoryBase() = default;

  /**
   * Reports whether `protocol` is supported.
   * NOTE(review): presumably answered by the Java factory via mIsProtocolSupportedId.
   */
  bool IsProtocolSupported(const std::string& protocol);

  /**
   * Creates a socket for `uri`, writing the resulting Java object to `result`.
   * NOTE(review): presumably invokes the Java factory via mCreateMethodId.
   */
  TTV_ErrorCode CreateSocket(const std::string& uri, jobject& result);

  GlobalJavaObjectReference mJavaInstance;  //!< The java implementation of the factory.

 private:
  jmethodID mIsProtocolSupportedId;  //!< The isProtocolSupportedId passed to the constructor.
  jmethodID mCreateMethodId;         //!< The createMethodId passed to the constructor.
};

/**
 * ttv::ISocket adapter whose operations are backed by a wrapped Java socket object.
 */
class ttv::binding::java::JavaSocket : public ttv::ISocket, public JavaSocketBase {
 public:
  /**
   * @param jEnv The JNI environment.
   * @param jInstance The Java socket implementation to wrap.
   */
  JavaSocket(JNIEnv* jEnv, jobject jInstance);

  // --- ttv::ISocket overrides (`override` already implies virtual) ---
  TTV_ErrorCode Connect() override;
  TTV_ErrorCode Disconnect() override;
  TTV_ErrorCode Send(const uint8_t* buffer, size_t length, size_t& sent) override;
  TTV_ErrorCode Recv(uint8_t* buffer, size_t length, size_t& received) override;
  uint64_t TotalSent() override;
  uint64_t TotalReceived() override;
  bool Connected() override;
};

/**
 * ttv::ISocketFactory adapter whose operations are backed by a wrapped Java factory object.
 *
 * The ISocketFactory overrides below hide the same-named protected helpers of
 * JavaSocketFactoryBase; those helpers must be reached with explicit qualification.
 */
class ttv::binding::java::JavaSocketFactory : public ttv::ISocketFactory, public JavaSocketFactoryBase {
 public:
  /**
   * @param jEnv The JNI environment.
   * @param jInstance The Java factory implementation to wrap.
   */
  JavaSocketFactory(JNIEnv* jEnv, jobject jInstance);

  // --- ttv::ISocketFactory overrides (`override` already implies virtual) ---
  bool IsProtocolSupported(const std::string& protocol) override;
  TTV_ErrorCode CreateSocket(const std::string& uri, std::shared_ptr<ISocket>& result) override;
};

/**
 * ttv::IWebSocket adapter whose operations are backed by a wrapped Java WebSocket object.
 */
class ttv::binding::java::JavaWebSocket : public ttv::IWebSocket, public JavaSocketBase {
 public:
  /**
   * @param jEnv The JNI environment.
   * @param jInstance The Java WebSocket implementation to wrap.
   */
  JavaWebSocket(JNIEnv* jEnv, jobject jInstance);

  // --- ttv::IWebSocket overrides (`override` already implies virtual) ---
  TTV_ErrorCode Connect() override;
  TTV_ErrorCode Disconnect() override;
  TTV_ErrorCode Send(MessageType type, const uint8_t* buffer, size_t length) override;
  TTV_ErrorCode Recv(MessageType& type, uint8_t* buffer, size_t length, size_t& received) override;
  TTV_ErrorCode Peek(MessageType& type, size_t& length) override;
  bool Connected() override;

 private:
  // NOTE(review): presumably a Java-side holder used to return the MessageType
  // from Recv/Peek; confirm against the implementation.
  GlobalJavaObjectReference mMessageTypeResultContainer;
};

/**
 * ttv::IWebSocketFactory adapter whose operations are backed by a wrapped Java factory object.
 */
class ttv::binding::java::JavaWebSocketFactory : public ttv::IWebSocketFactory, public JavaSocketFactoryBase {
 public:
  /**
   * @param jEnv The JNI environment.
   * @param jInstance The Java WebSocket factory implementation to wrap.
   */
  JavaWebSocketFactory(JNIEnv* jEnv, jobject jInstance);

  // --- ttv::IWebSocketFactory overrides (`override` already implies virtual) ---
  bool IsProtocolSupported(const std::string& protocol) override;
  TTV_ErrorCode CreateWebSocket(const std::string& uri, std::shared_ptr<IWebSocket>& result) override;
};
