/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/core/mutex.h"
#include "twitchsdk/core/socket.h"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

// Forward-declare TestSocket inside ttv::test so the class can be defined further down using its fully qualified name
// (`class ttv::test::TestSocket`) instead of reopening both namespaces around the definition.
namespace ttv {
namespace test {
class TestSocket;
}  // namespace test
}  // namespace ttv

/**
 * A mock object that simulates sending and receiving data over a raw socket.
 *
 * A test drives the mock by queueing payloads for the client to receive (PushReceivedPayload), by installing hooks
 * that observe the client's send / receive attempts (SetSendCallback, SetRecvCallback), and by toggling a simulated
 * failure (SetShouldFail). The constructor and all ISocket overrides are defined out of line.
 */
class ttv::test::TestSocket : public ttv::ISocket {
 public:
  TestSocket();

  // ISocket interface. The implementations live outside this header.
  virtual TTV_ErrorCode Connect() override;
  virtual TTV_ErrorCode Disconnect() override;

  virtual TTV_ErrorCode Send(const uint8_t* buffer, size_t length, size_t& sent) override;
  virtual TTV_ErrorCode Recv(uint8_t* buffer, size_t length, size_t& received) override;

  virtual uint64_t TotalSent() override;
  virtual uint64_t TotalReceived() override;

  virtual bool Connected() override;

  /**
   * Stores the maximum bitrate, in kilobits per second, that the mock should use. This setter only records the value;
   * how the limit is applied is decided by the out-of-line implementation.
   *
   * @param kbps The new value for the bitrate cap.
   */
  void SetMaxBitrate(uint32_t kbps) { mMaxBitrateKbps = kbps; }

  /**
   * Push received data back to the client. If non-null, the callback will be invoked after the client has received all
   * data.
   *
   * The entry is appended to the payload queue while holding mMutex. Both arguments are perfect-forwarded into a
   * QueueEntry, so the payload must be usable to construct a std::string and the callback a std::function<void()>.
   *
   * @param payload The data to hand to the client.
   * @param callback Completion callback stored alongside the payload.
   */
  template <typename PayloadType, typename CallbackType>
  void PushReceivedPayload(PayloadType&& payload, CallbackType&& callback) {
    // NOTE(review): assumes mMutex has been created by the constructor - confirm that AutoMutex is never handed null.
    AutoMutex lock(mMutex.get());

    mReceivedPayloadQueue.emplace(std::forward<PayloadType>(payload), std::forward<CallbackType>(callback));
  }

  /**
   * This callback is called whenever the client sends data over the socket.
   *
   * The assignment is not performed under mMutex.
   * NOTE(review): presumably meant to be installed before the client starts using the socket - confirm.
   *
   * @param callback Anything assignable to std::function<TTV_ErrorCode(const uint8_t*, size_t, size_t&)>.
   */
  template <typename CallbackType>
  void SetSendCallback(CallbackType&& callback) {
    mSendCallback = std::forward<CallbackType>(callback);
  }

  /**
   * This callback is called whenever the client attempts to receive data over the socket.
   *
   * The assignment is not performed under mMutex.
   * NOTE(review): presumably meant to be installed before the client starts using the socket - confirm.
   *
   * @param callback Anything assignable to std::function<TTV_ErrorCode(uint8_t*, size_t, size_t&)>.
   */
  template <typename CallbackType>
  void SetRecvCallback(CallbackType&& callback) {
    mRecvCallback = std::forward<CallbackType>(callback);
  }

  /**
   * The tester can set this value to simulate a socket failure. Any subsequent attempts to read or write will return an
   * error
   *
   * @return The current state of the simulated-failure flag.
   */
  bool GetShouldFail() const { return mShouldFail; }

  /**
   * Turns the simulated socket failure on or off. The flag is written without taking mMutex.
   *
   * @param shouldFail True to simulate a failure, false to clear it.
   */
  void SetShouldFail(bool shouldFail) { mShouldFail = shouldFail; }

 private:
  // Hooks installed through SetSendCallback / SetRecvCallback. They are empty until a test installs them.
  std::function<TTV_ErrorCode(const uint8_t* buffer, size_t length, size_t& sent)> mSendCallback;
  std::function<TTV_ErrorCode(uint8_t* buffer, size_t length, size_t& received)> mRecvCallback;

  // A queued payload together with the callback supplied to PushReceivedPayload.
  using QueueEntry = std::pair<std::string, std::function<void()>>;
  std::queue<QueueEntry> mReceivedPayloadQueue;

  // Held by PushReceivedPayload while it modifies mReceivedPayloadQueue.
  std::unique_ptr<IMutex> mMutex;

  // The scalar state below carries in-class defaults so that no member is ever read with an indeterminate value.
  // An explicit initializer in the out-of-line constructor still takes precedence over these defaults.
  uint64_t mTotalSent = 0;      // Presumably what TotalSent() reports - the body is out of line.
  uint64_t mTotalReceived = 0;  // Presumably what TotalReceived() reports - the body is out of line.
  uint32_t mMaxBitrateKbps = 0;  // Written by SetMaxBitrate.
  bool mConnected = false;       // Presumably what Connected() reports - the body is out of line.
  bool mShouldFail = false;      // Read and written by GetShouldFail / SetShouldFail.
};
