/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/internal/muxers/rtmpinitializestate.h"

#include "twitchsdk/core/socket.h"
#include "twitchsdk/core/stringutilities.h"
namespace {
// URL scheme prefixes accepted for RTMP ingest URLs; each is checked against the start of the URL
// in the order listed. A plain array is used rather than `auto x = {...}`, which would silently
// deduce std::initializer_list<const char*>.
constexpr const char* kSupportedProtocols[] = {"rtmp://", "rtmps://"};
}  // namespace

/**
 * Parses the ingest URL stored in the RTMP context, creates and connects the socket, and picks the
 * next state.
 *
 * Expected URL shape: <protocol><host>[:<port>]/<application>/<stream name>, where <protocol> is
 * one of kSupportedProtocols. When no port is given, 443 is used for "rtmps://" and 1935 for
 * "rtmp://"; secure URLs are connected through an "ssl://" socket URI.
 *
 * On success mApplication, mStreamName, mHostName, mPort and mSocket are filled in and the next
 * state is Handshake. On failure mLastError is set and the next state is Error:
 *  - TTV_EC_BROADCAST_RTMP_WRONG_PROTOCOL_IN_URL if the URL has no supported protocol prefix,
 *  - TTV_EC_INVALID_ARG if the application/stream separators are missing or the port is invalid,
 *  - otherwise whatever socket creation or connection returned.
 */
void ttv::broadcast::RtmpInitializeState::OnEnterInternal() {
  // TODO: Tokenize this nicely.
  // TODO: Should this live in ttv::Socket?

  const std::string& url = GetContext()->mURL;

  // Determine which supported protocol prefix the URL starts with; protocolEnd is its length.
  size_t protocolEnd = std::string::npos;
  for (const auto& supportedProtocol : kSupportedProtocols) {
    const size_t protocolLength = strlen(supportedProtocol);
    if (url.compare(0, protocolLength, supportedProtocol) == 0) {
      protocolEnd = protocolLength;
      break;
    }
  }

  if (protocolEnd == std::string::npos) {
    GetContext()->mLastError = TTV_EC_BROADCAST_RTMP_WRONG_PROTOCOL_IN_URL;
    GetContext()->SetNextState(RtmpContext::State::Error);
    return;
  }

  // The matched prefix tells us whether this is a secure connection.
  const bool isSecure = url.compare(0, protocolEnd, "rtmps://") == 0;

  // The host (and optional port) ends at the first '/' after the protocol; the application name
  // runs from there up to the next '/'.
  const size_t domainEnd = url.find('/', protocolEnd);
  const size_t appEnd = (domainEnd == std::string::npos) ? std::string::npos : url.find('/', domainEnd + 1);
  if (domainEnd == std::string::npos || appEnd == std::string::npos) {
    // Without both separators the application and stream name cannot be extracted.
    GetContext()->mLastError = TTV_EC_INVALID_ARG;
    GetContext()->SetNextState(RtmpContext::State::Error);
    return;
  }

  const size_t streamNameStart = url.find_last_of('/');
  // Only <application>/<stream name> is expected after the host.
  assert(appEnd == streamNameStart);

  GetContext()->mApplication = url.substr(domainEnd + 1, appEnd - domainEnd - 1);
  GetContext()->mStreamName = url.substr(streamNameStart + 1);

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // Only a ':' inside the host section introduces a port; one in the path must not be mistaken
  // for a port separator.
  size_t hostEnd = domainEnd;
  const size_t portStart = url.find(':', protocolEnd);
  if (portStart != std::string::npos && portStart < domainEnd) {
    // Take the port text (between ':' and '/') before narrowing the host range to end at ':'.
    const std::string port = url.substr(portStart + 1, domainEnd - portStart - 1);
    if (!ParseNum(port, GetContext()->mPort)) {
      ec = TTV_EC_INVALID_ARG;
    }
    hostEnd = portStart;
  } else if (isSecure) {
    GetContext()->mPort = 443;
  } else {
    GetContext()->mPort = 1935;
  }

  if (TTV_SUCCEEDED(ec)) {
    GetContext()->mHostName = url.substr(protocolEnd, hostEnd - protocolEnd);

    std::string uri = GetContext()->mHostName + ":" + std::to_string(GetContext()->mPort);
    if (isSecure) {
      uri = "ssl://" + uri;
    }

    std::shared_ptr<ISocket> socket;
    ec = ttv::CreateSocket(uri, socket);
    if (TTV_SUCCEEDED(ec)) {
      GetContext()->mSocket.Bind(socket);
    }
    ASSERT_ON_ERROR(ec);
  }

  if (TTV_SUCCEEDED(ec)) {
    ec = GetContext()->mSocket.Connect();
    ASSERT_ON_ERROR(ec);
  }

  if (TTV_FAILED(ec)) {
    // Record the error before transitioning, matching the early-exit error paths above.
    GetContext()->mLastError = ec;
    GetContext()->SetNextState(RtmpContext::State::Error);
  } else {
    GetContext()->SetNextState(RtmpContext::State::Handshake);
  }
}
