/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/core/socket.h"

#include <cstddef>
#include <functional>
#include <queue>
#include <string>
#include <utility>

namespace ttv {
namespace test {

// Forward declaration so the mock below can be defined with its fully qualified name.
class TestWebSocket;

}  // namespace test
}  // namespace ttv

/**
 * A mock object that simulates sending and receiving data over a web socket.
 *
 * Data the client sends is reported to the tester through the callback installed with SetSendCallback(). Data the
 * client should receive is queued by the tester with PushReceivedPayload(). SetShouldFail() lets the tester simulate
 * a socket failure.
 */
class ttv::test::TestWebSocket : public ttv::IWebSocket {
 public:
  TestWebSocket();

  virtual TTV_ErrorCode Connect() override;
  virtual TTV_ErrorCode Disconnect() override;

  virtual TTV_ErrorCode Send(MessageType type, const uint8_t* buffer, size_t length) override;
  virtual TTV_ErrorCode Recv(MessageType& type, uint8_t* buffer, size_t length, size_t& received) override;
  virtual TTV_ErrorCode Peek(MessageType& type, size_t& length) override;

  virtual bool Connected() override;

  /**
   * Push received data back to the client. If non-null, the callback will be invoked after the client has received all
   * data.
   *
   * @param payload  Anything a std::string can be constructed from; it is appended to the pending-receive queue.
   * @param callback Anything a std::function<void()> can be constructed from. Defaults to nullptr (no callback), so
   *                 callers that do not need a completion notification may omit it.
   */
  template <typename PayloadType, typename CallbackType = std::nullptr_t>
  void PushReceivedPayload(PayloadType&& payload, CallbackType&& callback = nullptr) {
    mReceivedPayloadQueue.emplace(std::forward<PayloadType>(payload), std::forward<CallbackType>(callback));
  }

  /**
   * This callback is called whenever the client sends data over the web socket.
   *
   * @param callback Anything assignable to std::function<void(const std::string&)>; replaces any previous callback.
   */
  template <typename CallbackType>
  void SetSendCallback(CallbackType&& callback) {
    mSentCallback = std::forward<CallbackType>(callback);
  }

  /**
   * The tester can set this value to simulate a web socket failure. Any subsequent attempts to read or write will
   * return an error.
   */
  bool GetShouldFail() const { return mShouldFail; }
  void SetShouldFail(bool shouldFail) { mShouldFail = shouldFail; }

 private:
  // Installed via SetSendCallback(); an empty std::function until the tester sets one.
  std::function<void(const std::string&)> mSentCallback;
  // Payloads queued via PushReceivedPayload(), each paired with its (possibly empty) completion callback.
  std::queue<std::pair<std::string, std::function<void()>>> mReceivedPayloadQueue;
  // Default member initializers keep these flags well-defined even if the out-of-line constructor does not set them;
  // a mem-initializer in the constructor still takes precedence.
  // NOTE(review): mConnected presumably backs Connected(); confirm against the .cpp.
  bool mConnected = false;
  bool mShouldFail = false;
};
