#include "context.h"

#include "staticinit.h"
#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/stringutilities.h"
#include "twitchsdk/core/thread.h"
#include "twitchsdk/core/tracer.h"

#include <regex>
#include <sstream>

using namespace ttv;

/// Describes a single command parameter: its display name and the type it is parsed as.
ParamDefinition::ParamDefinition(const std::string& name, ParamType::Enum type)
    : mName(name),
      mType(type) {}

// ParamValue constructors: each binds a parsed value to its parameter definition and
// asserts that the definition's declared ParamType matches the C++ type being stored.

/// String value; valid for both String and RemainderAsString parameters.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, const std::string& value)
    : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::String == mDefinition->GetType() || ParamType::RemainderAsString == mDefinition->GetType());
  mString.assign(value);
}

/// Unsigned 32-bit value.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, uint32_t value) : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::UInt32 == mDefinition->GetType());
  mUInt32 = value;
}

/// Signed 32-bit value.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, int32_t value) : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::Int32 == mDefinition->GetType());
  mInt32 = value;
}

/// Unsigned 64-bit value.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, uint64_t value) : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::UInt64 == mDefinition->GetType());
  mUInt64 = value;
}

/// Signed 64-bit value.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, int64_t value) : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::Int64 == mDefinition->GetType());
  mInt64 = value;
}

/// Single-precision floating point value.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, float value) : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::Float == mDefinition->GetType());
  mFloat = value;
}

/// Boolean value.
ParamValue::ParamValue(std::shared_ptr<ParamDefinition> def, bool value) : mDefinition(std::move(def)) {
  TTV_ASSERT(ParamType::Bool == mDefinition->GetType());
  mBool = value;
}

// Typed accessors. Each one asserts that the bound definition has the matching
// ParamType, then returns the corresponding stored member.

/// Returns the string value (String or RemainderAsString parameters).
std::string ParamValue::GetString() const {
  TTV_ASSERT(ParamType::String == mDefinition->GetType() || ParamType::RemainderAsString == mDefinition->GetType());
  return mString;
}

/// Returns the unsigned 32-bit value.
uint32_t ParamValue::GetUInt32() const {
  TTV_ASSERT(ParamType::UInt32 == mDefinition->GetType());
  return mUInt32;
}

/// Returns the signed 32-bit value.
int32_t ParamValue::GetInt32() const {
  TTV_ASSERT(ParamType::Int32 == mDefinition->GetType());
  return mInt32;
}

/// Returns the unsigned 64-bit value.
uint64_t ParamValue::GetUInt64() const {
  TTV_ASSERT(ParamType::UInt64 == mDefinition->GetType());
  return mUInt64;
}

/// Returns the signed 64-bit value.
int64_t ParamValue::GetInt64() const {
  TTV_ASSERT(ParamType::Int64 == mDefinition->GetType());
  return mInt64;
}

/// Returns the float value.
float ParamValue::GetFloat() const {
  TTV_ASSERT(ParamType::Float == mDefinition->GetType());
  return mFloat;
}

/// Returns the boolean value.
bool ParamValue::GetBool() const {
  TTV_ASSERT(ParamType::Bool == mDefinition->GetType());
  return mBool;
}

// Conversion operators. Each delegates to the matching typed getter, so the same
// ParamType assertions apply to implicit conversions.

ParamValue::operator std::string() const {
  return this->GetString();
}

ParamValue::operator uint32_t() const {
  return this->GetUInt32();
}

ParamValue::operator int32_t() const {
  return this->GetInt32();
}

ParamValue::operator uint64_t() const {
  return this->GetUInt64();
}

ParamValue::operator int64_t() const {
  return this->GetInt64();
}

ParamValue::operator float() const {
  return this->GetFloat();
}

ParamValue::operator bool() const {
  return this->GetBool();
}

/// Compares the stored string value with a C string.
/// A null `other` never compares equal; the String-type assertion still runs.
bool ParamValue::operator==(const char* other) const {
  // Fetch the value first so GetString()'s type assertion is evaluated even when `other` is null.
  const std::string value = GetString();
  // Comparing a std::string with a null `const char*` is undefined behavior, so treat null as "not equal".
  return other != nullptr && value == other;
}

/// Compares the stored string value with `other`.
bool ParamValue::operator==(const std::string& other) const {
  return GetString() == other;
}

/// Compares the stored boolean value with `other`.
bool ParamValue::operator==(bool other) const {
  return GetBool() == other;
}

/// Creates an empty overload ("flavor") of the command `def`, with no callback or parameters yet.
/// NOTE(review): CommandDefinition::AddFunction stores this object in the definition's mFlavors
/// while passing shared_from_this() here. If mDefinition is a std::shared_ptr (as Done()'s
/// dereference suggests), the two objects keep each other alive. Consider a weak_ptr; confirm
/// against the header.
CommandFunction::CommandFunction(std::shared_ptr<CommandDefinition> def)
    : mDefinition(def),
      mFunction(nullptr),
      mFunctionWithInstance(nullptr) {}

/// Sets the callback that receives only the parsed parameter values. Returns *this for chaining.
CommandFunction& CommandFunction::Function(CommandFunc&& func) {
  mFunction = std::move(func);
  return *this;
}

/// Sets the callback that also receives the invoking CommandInstance. Returns *this for chaining.
CommandFunction& CommandFunction::Function(CommandWithInstanceFunc&& func) {
  mFunctionWithInstance = std::move(func);
  return *this;
}

/// Sets the help text shown by CommandDefinition::PrintUsage. Returns *this for chaining.
CommandFunction& CommandFunction::Description(const std::string& desc) {
  mDescription = desc;
  return *this;
}

/// Appends a parameter named `name`, parsed as `type`. Returns *this for chaining.
CommandFunction& CommandFunction::AddParam(const std::string& name, ParamType::Enum type) {
  mParams.push_back(std::make_shared<ParamDefinition>(name, type));
  return *this;
}

/// Ends the builder chain and returns the owning command definition.
CommandDefinition& CommandFunction::Done() {
  return *mDefinition;
}

/**
 * Replaces, in place, every token that equals `$<name>` (compared case-insensitively)
 * with the value of variable `name`.
 *
 * Variables are applied in `variables` iteration order. A token replaced by one variable
 * can therefore be replaced again by a later variable if its new text matches.
 *
 * @param tokens Tokens to rewrite in place.
 * @param variables Variable names (without the leading '$') mapped to their values.
 */
void CommandFunction::SubstituteVariables(
  std::vector<std::string>& tokens, const std::map<std::string, std::string>& variables) const {
  for (auto varIt = variables.begin(); varIt != variables.end(); ++varIt) {
    std::string needle = "$" + varIt->first;
    ToLowerCase(needle);

    for (auto& token : tokens) {
      std::string candidate(token);
      ToLowerCase(candidate);
      if (candidate != needle) {
        continue;
      }
      token = varIt->second;
    }
  }
}

/**
 * Expands `$name` references in `str` using `variables`.
 *
 * A reference is replaced only when it is a whole word: it must be preceded by the start
 * of the string or a whitespace character, and followed by whitespace or the end of the
 * string. Matching of the name is case-sensitive.
 *
 * Expansion is repeated so that values which themselves contain `$other` references are
 * expanded as well. It stops as soon as a pass changes nothing, or after a fixed maximum
 * number of passes, so self-referencing variables cannot loop forever.
 *
 * @param str The text to expand.
 * @param variables Variable names (without the leading '$') mapped to their values.
 * @return The expanded text.
 */
std::string CommandFunction::SubstituteVariables(
  const std::string& str, const std::map<std::string, std::string>& variables) const {
  // Upper bound on expansion passes, so self-referencing variables can't loop forever.
  const int kMaxPasses = 10;

  // Characters with special meaning in ECMAScript regular expressions.
  const std::string regexSpecialChars("\\^$.|?*+()[]{}");

  // Compile one matcher per variable up front rather than once per pass.
  std::vector<std::pair<std::regex, const std::string*>> matchers;
  matchers.reserve(variables.size());
  for (const auto& kvp : variables) {
    // Escape the name so characters such as '.' or '(' match literally instead of being
    // read as regex syntax (which could also make std::regex throw).
    std::string escapedName;
    escapedName.reserve(kvp.first.size() * 2);
    for (char c : kvp.first) {
      if (regexSpecialChars.find(c) != std::string::npos) {
        escapedName += '\\';
      }
      escapedName += c;
    }

    // The trailing delimiter is a lookahead so it is not consumed. Otherwise the single space
    // in "$a $a" would be eaten by the first match and the second occurrence skipped.
    matchers.emplace_back(std::regex("(?:^|\\s)(\\$" + escapedName + ")(?=\\s|$)"), &kvp.second);
  }

  std::string result = str;
  for (int pass = 0; pass < kMaxPasses; ++pass) {
    const std::string before = result;

    for (const auto& matcher : matchers) {
      // Record the offset and length of every `$name` occurrence (capture group 1) first...
      std::vector<std::pair<size_t, size_t>> spans;
      const std::sregex_iterator end;
      for (std::sregex_iterator iter(result.cbegin(), result.cend(), matcher.first); iter != end; ++iter) {
        spans.emplace_back(static_cast<size_t>(iter->position(1)), static_cast<size_t>(iter->length(1)));
      }

      // ...then replace back to front so the recorded offsets of earlier matches stay valid.
      for (auto span = spans.rbegin(); span != spans.rend(); ++span) {
        result.replace(span->first, span->second, *matcher.second);
      }
    }

    if (result == before) {
      break;  // Fixed point reached: further passes would change nothing.
    }
  }

  return result;
}

bool CommandFunction::TryParse(const std::string& paramLine, const std::map<std::string, std::string>& variables,
  std::vector<ParamValue>& values) const {
  values.clear();

  std::string line = paramLine;
  ttv::Trim(line);
  line = SubstituteVariables(line, variables);

  std::vector<std::string> tokens;
  TokenizeParameters(line, tokens);

  // We need at least as any tokens as expected input params
  if (mParams.size() > tokens.size()) {
    return false;
  }

  bool remainderConsumed = false;

  for (size_t i = 0; i < mParams.size(); ++i) {
    std::shared_ptr<ParamDefinition> def = mParams[i];
    const std::string& token = tokens[i];

    switch (def->GetType()) {
      case ParamType::UInt32: {
        uint32_t value = 0;
        if (!ParseNum(token, value)) {
          return false;
        }
        values.push_back(ParamValue(def, value));
        break;
      }
      case ParamType::Int32: {
        int32_t value = 0;
        if (!ParseNum(token, value)) {
          return false;
        }
        values.push_back(ParamValue(def, value));
        break;
      }
      case ParamType::UInt64: {
        uint64_t value = 0;
        if (!ParseNum(token, value)) {
          return false;
        }
        values.push_back(ParamValue(def, value));
        break;
      }
      case ParamType::Int64: {
        int64_t value = 0;
        if (!ParseNum(token, value)) {
          return false;
        }
        values.push_back(ParamValue(def, value));
        break;
      }
      case ParamType::Float: {
        float value = 0;
        if (!ParseNum(token, value)) {
          return false;
        }
        values.push_back(ParamValue(def, value));
        break;
      }
      case ParamType::String: {
        values.push_back(ParamValue(def, token));
        break;
      }
      case ParamType::Bool: {
        std::string lower = token;
        ToLowerCase(lower);

        bool value = lower == "1" || lower == "true";
        values.push_back(ParamValue(def, value));
        break;
      }
      case ParamType::RemainderAsString: {
        // This is only valid as the last parameter
        if (i != mParams.size() - 1) {
          return false;
        }

        const char* remainder = line.c_str();
        for (size_t t = 0; t < i && remainder != nullptr; ++t) {
          remainder = AdvanceToNextWord(remainder);
        }

        if (remainder == nullptr) {
          return false;
        }

        values.push_back(ParamValue(def, std::string(remainder)));
        remainderConsumed = true;

        break;
      }
      default: {
        TTV_ASSERT(false);
        return false;
      }
    }
  }

  if (!remainderConsumed && mParams.size() != tokens.size()) {
    return false;
  }

  return true;
}

/**
 * Runs this flavor's callback with already-parsed parameter values.
 *
 * The values-only callback takes precedence. If it is unset, the instance-aware callback
 * is used instead.
 *
 * @return The callback's result, or false when neither callback has been set.
 */
bool CommandFunction::Invoke(const CommandInstance& instance, const std::vector<ParamValue>& values) const {
  if (mFunction) {
    return mFunction(values);
  }

  if (mFunctionWithInstance) {
    return mFunctionWithInstance(instance, values);
  }

  return false;
}

/// Creates a command named `name`. By default it is not flagged to run on the main thread.
/// The command's own name is registered as its first (lower-cased) alias.
CommandDefinition::CommandDefinition(const std::string& name)
    : mName(name),
      mExecuteOnMainThread(false) {
  AddAlias(mName);
}

/// Adds a new overload ("flavor") to this command and returns it for further configuration.
/// Uses shared_from_this(), so this definition must already be owned by a std::shared_ptr.
CommandFunction& CommandDefinition::AddFunction() {
  std::shared_ptr<CommandFunction> flavor = std::make_shared<CommandFunction>(shared_from_this());
  mFlavors.push_back(flavor);
  return *flavor;
}

/// Registers an additional, lower-cased name this command answers to. Returns *this for chaining.
CommandDefinition& CommandDefinition::AddAlias(const std::string& alias) {
  std::string normalized(alias);
  ToLowerCase(normalized);
  mAliases.push_back(normalized);
  return *this;
}

/// Flags the command so Context::EnqueueCommand invokes it immediately instead of queueing it.
CommandDefinition& CommandDefinition::RunOnMainThread() {
  mExecuteOnMainThread = true;
  return *this;
}

/// Returns true if `name`, compared case-insensitively, is one of this command's aliases.
/// The command's own name is always an alias (see the constructor).
bool CommandDefinition::IsCommand(const std::string& name) const {
  std::string needle(name);
  ToLowerCase(needle);
  return std::find(mAliases.begin(), mAliases.end(), needle) != mAliases.end();
}

/**
 * Writes usage text for every flavor of this command to std::cout.
 *
 * Output per flavor:
 *   "  <name> <param1> <param2> ..."
 *   "    <description>"
 *   "    Alias: <alias>"   (one line per alias, including the command's own name)
 */
void CommandDefinition::PrintUsage() const {
  for (const auto& flavor : mFlavors) {
    // Signature line: the command name followed by each parameter as <name>.
    std::cout << "  " << mName;
    for (const auto& param : flavor->GetParams()) {
      std::cout << " <" << param->GetName() << ">";
    }
    std::cout << std::endl;

    std::cout << "    " << flavor->GetDescription() << std::endl;

    for (const auto& name : mAliases) {
      std::cout << "    Alias: " << name << std::endl;
    }
  }
}

/// A concrete invocation of `definition`. `params` is the parameter text and `line` the full
/// input line. Echo is enabled by default. `params` is also handed to Split() with a ' '
/// separator to fill mParamTokens.
CommandInstance::CommandInstance(
  const std::shared_ptr<CommandDefinition>& definition, const std::string& params, const std::string& line)
    : mDefinition(definition),
      mParams(params),
      mLine(line),
      mEcho(true) {
  Split(mParams, mParamTokens, ' ');
}

/// Tries each flavor of the command in order. Returns true as soon as one parses the
/// parameters and its callback succeeds. Returns false if no flavor does.
bool CommandInstance::Invoke(VariableMap& variables) {
  std::vector<ParamValue> values;

  for (const auto& flavor : mDefinition->GetFlavors()) {
    const bool handled = flavor->TryParse(mParams, variables, values) && flavor->Invoke(*this, values);
    values.clear();

    if (handled) {
      return true;
    }
  }

  return false;
}

/// Prints usage for every flavor of the underlying command.
void CommandInstance::PrintUsage() const {
  mDefinition->PrintUsage();
}

CommandCategory::CommandCategory(const std::string& name, const std::string& description)
    : mName(name), mDescription(description) {}

bool CommandCategory::FindCommand(const std::string& name, std::shared_ptr<CommandDefinition>& result) {
  result.reset();

  auto iter = std::find_if(mDefinitions.begin(), mDefinitions.end(),
    [name](const std::shared_ptr<CommandDefinition>& cmd) -> bool { return cmd->IsCommand(name); });

  if (iter == mDefinitions.end()) {
    return false;
  } else {
    result = *iter;
    return true;
  }
}

CommandDefinition& CommandCategory::AddCommand(const std::string& name) {
  auto def = std::make_shared<CommandDefinition>(name);
  mDefinitions.push_back(def);

  return *def;
}

/// Prints the textual form of `ec` (from ttv::ErrorToString), indented by two spaces, to std::cout.
void ReportCommandResult(TTV_ErrorCode ec) {
  const auto message = ttv::ErrorToString(ec);
  std::cout << "  " << message << std::endl;
}

/// Creates a context with an update interval of 250 and the exit flag cleared.
/// NOTE(review): the unit of mUpdateInterval is not visible here (presumably milliseconds); confirm in the header.
Context::Context()
    : mUpdateInterval(250),
      mExit(false) {}

/// Registers `category` under its name, replacing any category previously registered with that name.
void Context::RegisterCommandCategory(std::shared_ptr<CommandCategory> category) {
  const auto name = category->GetName();
  mCommands[name] = category;
}

/// Appends `module`. Modules are initialized and updated in registration order.
void Context::RegisterModule(std::shared_ptr<ttv::IModule> module) {
  mModules.push_back(std::move(module));
}

/**
 * Looks up a registered module by name.
 *
 * @param name Compared with `==` against each module's GetModuleName().
 * @return The first registered module with that name, or nullptr if none matches.
 */
std::shared_ptr<ttv::IModule> Context::GetModule(const std::string& name) {
  // Take each element by const reference. Copying a shared_ptr per inspected module costs an
  // atomic reference-count increment/decrement. Likewise capture `name` by reference rather
  // than copying the string into the closure.
  auto iter = std::find_if(mModules.begin(), mModules.end(),
    [&name](const std::shared_ptr<ttv::IModule>& module) { return module->GetModuleName() == name; });

  if (iter != mModules.end()) {
    return *iter;
  }

  return nullptr;
}

void Context::RegisterLoggers(const char* const* loggers, size_t numLoggers) {
  for (size_t i = 0; i < numLoggers; ++i) {
    mLoggers.push_back(loggers[i]);
  }
}

void Context::InitializeModules() {
  for (auto module : mModules) {
    if (module->GetState() == ttv::IModule::State::Uninitialized) {
      module->Initialize(nullptr);
    }
  }
}

void Context::UpdateModules() {
  // Update modules in the order they were given
  for (auto module : mModules) {
    module->Update();
  }
}

/// Shuts every registered module down via ttv::ShutdownModulesSync, then forgets them all.
/// The list is reversed first, so teardown runs in the opposite order of init and update.
void Context::ShutdownModules() {
  std::reverse(mModules.begin(), mModules.end());
  ttv::ShutdownModulesSync(mModules);
  mModules.clear();
}

/// Sets the exit flag.
void Context::Exit() {
  mExit = true;
}

/// Creates or overwrites the shell variable `name`, which commands can reference as `$name`.
void Context::SetVariable(const std::string& name, const std::string& value) {
  auto& slot = mVariables[name];
  slot = value;
}

bool Context::GenerateCommand(const std::string& line, std::shared_ptr<CommandInstance>& result) {
  result.reset();

  std::string command;
  std::string params;

  bool valid = ParseCommand(line, command, params);
  if (!valid) {
    return false;
  }

  ToLowerCase(command);

  std::shared_ptr<CommandDefinition> definition;

  for (auto kvp : mCommands) {
    auto category = kvp.second;
    if (category->FindCommand(command, definition)) {
      break;
    }
  }

  if (definition != nullptr) {
    result = std::make_shared<CommandInstance>(definition, params, line);
    return true;
  } else {
    return false;
  }
}

/**
 * Turns `line` into a command and schedules it.
 *
 * Commands flagged RunOnMainThread are invoked immediately from this call, printing usage if no
 * flavor accepts the parameters. All other commands are pushed onto mCommandQueue for
 * ProcessCommand().
 *
 * @return true if `line` named a known command (even if its invocation then failed). Returns
 *         false, after printing "Invalid command", otherwise.
 */
bool Context::EnqueueCommand(const std::string& line, const std::string& workingDirectory, bool echo) {
  std::shared_ptr<CommandInstance> cmd;
  if (!GenerateCommand(line, cmd)) {
    std::cout << "Invalid command: " << line << std::endl;
    return false;
  }

  cmd->SetWorkingDirectory(workingDirectory);
  cmd->SetEcho(echo);

  if (!cmd->GetExecuteOnMainThread()) {
    mCommandQueue.push(cmd);
    return true;
  }

  if (!cmd->Invoke(mVariables)) {
    std::cout << "Usage:" << std::endl;
    cmd->PrintUsage();
  }

  return true;
}

bool Context::SetCommandAlias(const std::string& cmd, const std::string& alias) {
  std::shared_ptr<CommandDefinition> def;
  for (auto kvp : mCommands) {
    auto category = kvp.second;
    if (category->FindCommand(cmd, def)) {
      break;
    }
  }

  if (def != nullptr) {
    def->AddAlias(alias);
  }

  return def != nullptr;
}

/**
 * Prints a header and the usage of every command in the category whose registered name
 * matches `categoryName` case-insensitively.
 *
 * If no category matches, prints "Unknown category: <name>" on its own line.
 */
void Context::PrintHelpCategory(const std::string& categoryName) {
  auto lowerInput = categoryName;
  ToLowerCase(lowerInput);

  // Iterate by const reference to avoid copying each key string and shared_ptr.
  std::shared_ptr<CommandCategory> category;
  for (const auto& kvp : mCommands) {
    auto lower = kvp.first;
    ToLowerCase(lower);

    if (lowerInput == lower) {
      category = kvp.second;
      break;
    }
  }

  if (category == nullptr) {
    // Terminate the line, otherwise the next output is glued onto this message.
    std::cout << "Unknown category: " << categoryName << std::endl;
    return;
  }

  std::cout << "  " << category->GetName() << std::endl;
  std::cout << "  =============================================" << std::endl;
  std::cout << std::endl;

  const auto& definitions = category->GetCommands();
  for (const auto& def : definitions) {
    def->PrintUsage();
    std::cout << std::endl;
  }
}

/**
 * Prints help for every registered category.
 *
 * The "Shell" category is listed first when it is registered. All other categories follow,
 * in mCommands iteration order.
 */
void Context::PrintHelp() {
  const std::string shellCategory("Shell");

  std::cout << std::endl;
  std::cout << "Help" << std::endl;
  std::cout << std::endl;

  // Only list Shell up front if it actually exists. Otherwise PrintHelpCategory would report
  // "Unknown category: Shell" at the top of the help text.
  if (mCommands.find(shellCategory) != mCommands.end()) {
    PrintHelpCategory(shellCategory);
  }

  // Iterate by const reference to avoid copying each key string and shared_ptr.
  for (const auto& kvp : mCommands) {
    if (kvp.first != shellCategory) {
      PrintHelpCategory(kvp.first);
    }
  }
}

/**
 * Pops one queued command (if any) and runs it.
 *
 * The command line is echoed first when echo is enabled. If no flavor accepts the
 * parameters, usage is printed.
 *
 * @return false if the queue yielded no command, true otherwise.
 */
bool Context::ProcessCommand() {
  std::shared_ptr<CommandInstance> cmd;
  mCommandQueue.try_pop(cmd);

  if (!cmd) {
    return false;
  }

  if (cmd->GetEcho()) {
    std::cout << cmd->GetLine() << std::endl;
  }

  const bool succeeded = cmd->Invoke(mVariables);
  if (!succeeded) {
    std::cout << "Usage:" << std::endl;
    cmd->PrintUsage();
  }

  return true;
}
