/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/broadcast/internal/muxers/rtmp.h"
#include "twitchsdk/broadcast/internal/muxers/rtmpcontext.h"
#include "twitchsdk/core/systemclock.h"

#define __STDC_FORMAT_MACROS
#include <inttypes.h>

namespace ttv {
namespace broadcast {
// Abstract base of the RTMP state machine.
class RtmpState;

// Individual states. RtmpIdleState and RtmpErrorState derive from RtmpState in
// this header; the rest are presumably RtmpState subclasses defined elsewhere.
class RtmpIdleState;
class RtmpInitializeState;
class RtmpHandshakeState;
class RtmpConnectState;
class RtmpCreateStreamState;
class RtmpPublishState;
class RtmpSendVideoState;
class RtmpShutdownState;
class RtmpErrorState;

// Fixed-size 64 KiB (65536-byte) buffer type; RtmpState keeps one static
// instance of it as its input buffer.
using rtmpinputbuffer_t = std::array<uint8_t, 64 * 1024>;
}  // namespace broadcast
}  // namespace ttv

//////////////////////////////////////////////////////////////////////////
// Base class for the RTMP state machine.
//
// Derived states customize behavior through the private virtual hooks
// OnEnterInternal(), OnExitInternal() and GetTimeoutDuration(), and through
// the protected HandleIncoming*() handlers. The public OnEnter()/OnExit()/
// Update() entry points wrap those hooks with shared bookkeeping (state start
// time, timeout detection, input polling).
//////////////////////////////////////////////////////////////////////////
class ttv::broadcast::RtmpState {
 public:
  /**
   * @param context Shared state-machine context. This object stores the
   *                pointer but never deletes it, so the context must outlive
   *                the state.
   */
  RtmpState(RtmpContext* context) : mContext(context), mStateStartTime(GetSystemClockTime()) {}
  virtual ~RtmpState() = default;

  /**
   * Marks the start of this state: records the current clock time (used by
   * Update() for timeout detection), logs it, then runs OnEnterInternal().
   */
  void OnEnter() {
    mStateStartTime = GetSystemClockTime();
    ttv::trace::Message(
      "rtmp", MessageLevel::Debug, "RtmpState::OnEnter - state start time: %" PRId64, GetSystemTimeMilliseconds());

    OnEnterInternal();
  }

  /// Runs the state-specific OnExitInternal() hook.
  void OnExit() { OnExitInternal(); }

  /**
   * Performs one tick of this state.
   *
   * Timeout path: the time spent in the state is measured from OnEnter(), or
   * from construction if OnEnter() was never called. If it exceeds
   * GetTimeoutDuration(), Update() stores TTV_EC_BROADCAST_RTMP_TIMEOUT as the
   * context's last error and sets the next state to Error.
   *
   * Poll path: otherwise Update() polls for input. A failing poll result is
   * stored as the last error and also moves the machine to Error.
   */
  void Update() {
    const auto elapsedMs = SystemTimeToMs(GetSystemClockTime() - mStateStartTime);

    if (elapsedMs > GetTimeoutDuration()) {
      ttv::trace::Message("rtmp", MessageLevel::Error, "RtmpState::Update - Rtmp state timed out at time: %" PRId64,
        GetSystemTimeMilliseconds());

      GetContext()->mLastError = TTV_EC_BROADCAST_RTMP_TIMEOUT;
      GetContext()->SetNextState(RtmpContext::State::Error);
    } else {
      // Poll failures are recorded on the context rather than dropped.
      const TTV_ErrorCode ret = PollForInput();
      if (TTV_FAILED(ret)) {
        GetContext()->mLastError = ret;
        GetContext()->SetNextState(RtmpContext::State::Error);
      }
    }
  }

  /// Sets the context's remaining chunk space (mChunkSpace) to 0, ending the current chunk.
  // TODO: This should be private/protected
  void EndChunk() { GetContext()->mChunkSpace = 0; }

  // TODO: This should be private/protected
  // TODO: This doesn't belong in rtmpstate at all
  /// Appends `length` bytes from `buffer` as chunk data described by `chunkDetails`
  /// (defined out of line). Does not end the chunk; see EndChunk().
  TTV_ErrorCode AppendChunkData(const uint8_t* buffer, size_t length, RtmpMessageDetails& chunkDetails);

  /// Convenience overload: appends the raw bytes of all elements of `buffer`.
  template <class T>
  TTV_ErrorCode AppendChunkData(const std::vector<T>& buffer, RtmpMessageDetails& chunkDetails) {
    // data() instead of &buffer[0]: indexing an empty vector is undefined behavior.
    return AppendChunkData(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size() * sizeof(T), chunkDetails);
  }

 protected:
  /// Polls for incoming data (defined out of line). A failure result makes Update()
  /// enter the Error state.
  TTV_ErrorCode PollForInput();

  /**
   * Appends `length` bytes via AppendChunkData(), then always ends the chunk
   * with EndChunk() — even when appending failed.
   * @return The result of AppendChunkData().
   */
  TTV_ErrorCode SendChunkData(const uint8_t* buffer, size_t length, RtmpMessageDetails& chunkDetails) {
    TTV_ErrorCode ret = AppendChunkData(buffer, length, chunkDetails);
    EndChunk();
    return ret;
  }

  /// Convenience overload: sends the raw bytes of all elements of `buffer` as one chunk.
  template <class T>
  TTV_ErrorCode SendChunkData(const std::vector<T>& buffer, RtmpMessageDetails& chunkDetails) {
    // data() instead of &buffer[0]: indexing an empty vector is undefined behavior.
    return SendChunkData(reinterpret_cast<const uint8_t*>(buffer.data()), buffer.size() * sizeof(T), chunkDetails);
  }

  /// Returns the shared context; asserts that it is non-null.
  RtmpContext* GetContext() {
    assert(mContext);
    return mContext;
  }
  /// Const overload of GetContext(); asserts that the context is non-null.
  const RtmpContext* GetContext() const {
    assert(mContext);
    return mContext;
  }

  // Handlers for incoming messages, named after the message type they handle.
  // Virtual so individual states can override them; defined out of line.
  virtual void HandleIncomingChunkSize(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAbortMsg(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingBytesRead(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingControl(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingWinacksize(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingPeerBW(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingEdgeOrigin(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAudio(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingVideo(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAmf3Data(ChunkHeader header, const uint8_t* data);
  // NOTE(review): unlike its siblings this takes non-const data; changing it would
  // silently stop existing overrides from overriding, so it is left as is.
  virtual void HandleIncomingAmf3SharedObject(ChunkHeader header, uint8_t* data);
  virtual void HandleIncomingAmf3(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAmf0Data(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAmf0SharedObject(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAmf0(ChunkHeader header, const uint8_t* data);
  virtual void HandleIncomingAggregate(ChunkHeader header, const uint8_t* data);

 private:
  /// State-specific hook run at the end of OnEnter(); no-op by default.
  virtual void OnEnterInternal() {}
  /// State-specific hook run by OnExit(); no-op by default.
  virtual void OnExitInternal() {}

  /// Limit compared against the elapsed SystemTimeToMs() value in Update().
  /// The default (max uint64_t) means the state never times out.
  virtual uint64_t GetTimeoutDuration() const { return std::numeric_limits<uint64_t>::max(); }

  // NOTE(review): a leading underscore followed by an uppercase letter is a reserved
  // identifier; renaming requires touching the out-of-line definition.
  TTV_ErrorCode _PollForInput();

  // Input buffer and its position; static, so shared by every RtmpState instance.
  static rtmpinputbuffer_t mInputBuffer;
  static rtmpinputbuffer_t::size_type mInputBufferPos;

  RtmpContext* mContext;      // Not owned.
  uint64_t mStateStartTime;   // GetSystemClockTime() at the last OnEnter() (or at construction).
};

//////////////////////////////////////////////////////////////////////////
// Idle state. It overrides none of RtmpState's hooks, so it has no-op
// enter/exit behavior and never times out (Update() still polls for input).
//////////////////////////////////////////////////////////////////////////
class ttv::broadcast::RtmpIdleState : public RtmpState {
 public:
  /// @param context Shared state-machine context, forwarded unchanged to RtmpState.
  RtmpIdleState(RtmpContext* context) : RtmpState{context} {}

  /// Declared only; the definition lives outside this header.
  ~RtmpIdleState() override;
};

//////////////////////////////////////////////////////////////////////////
// Error state. RtmpState::Update() switches to it after a timeout or a failed
// input poll. On entry it logs the context's last error and disconnects the
// context's socket.
//////////////////////////////////////////////////////////////////////////
class ttv::broadcast::RtmpErrorState : public RtmpState {
 public:
  /// @param context Shared state-machine context, forwarded to RtmpState.
  RtmpErrorState(RtmpContext* context) : RtmpState(context) {}

  /// Declared only; the definition lives outside this header.
  ~RtmpErrorState() override;

  /// Logs GetContext()->mLastError at Error level, then disconnects the context's socket.
  // NOTE(review): this overrides a *private* hook of RtmpState but is public here, so it can
  // be called directly, bypassing RtmpState::OnEnter()'s start-time bookkeeping. Kept public
  // only to avoid breaking any existing direct callers.
  void OnEnterInternal() override {
    ttv::trace::Message(
      "rtmp", MessageLevel::Error, "RTMP Entered error state with error %s", ErrorToString(GetContext()->mLastError));

    GetContext()->mSocket.Disconnect();
  }
};
