/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/broadcast/internal/audioconvert/audioconvertoptions.h"
#include "twitchsdk/broadcast/internal/audioconvert/changefloatingpointdepthoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/converttofloatingpointoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/converttointegraltypeoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/converttosignedoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/converttounsignedoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/decreasebitdepthoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/dsputilities.h"
#include "twitchsdk/broadcast/internal/audioconvert/increasebitdepthoperator.h"
#include "twitchsdk/broadcast/internal/audioconvert/operatorchainbuilder.h"
#include "twitchsdk/broadcast/internal/audioconvert/pcmbufferaudiosource.h"
#include "twitchsdk/broadcast/internal/audioconvert/resampleoperator.h"

namespace ttv {
/**
 * This class builds a chain of operators to convert audio from a given input buffer format to an output buffer format.
 *
 * The chain is computed entirely at compile time. Each conversion step is first described by a lightweight "wrapper"
 * type that only exposes the SampleRate and SampleType that step would produce. Only the wrappers on the selected path
 * are finally turned into real operator types through their Unwrap<> member (see Type at the bottom).
 *
 * @tparam ContextType The type of AudioConvertContext for the operators to use. See audioconvertcontext.h
 * @tparam InputBufferFormat The input buffer format. See BufferFormat in dsputilities.h
 * @tparam OutputBufferFormat The output buffer format. See BufferFormat in dsputilities.h
 * @tparam SelectedChannelIndex Which input audio channel this chain will process.
 */
template <typename ContextType, typename InputBufferFormat, typename OutputBufferFormat, size_t SelectedChannelIndex>
struct OperatorChainBuilder {
  /**
   * std::make_signed chokes if we pass it a non-integral type. We sometimes want to pass floating-point types
   * in the algorithm below, so this is a more lenient version that is just a no-op for floating-point types.
   */
  template <typename ScalarType, typename Enable = void>
  struct MakeSignedIfIntegralSelector {
    using Type = ScalarType;
  };

  template <typename ScalarType>
  struct MakeSignedIfIntegralSelector<ScalarType, std::enable_if_t<std::is_integral<ScalarType>::value>> {
    using Type = std::make_signed_t<ScalarType>;
  };

  template <typename ScalarType>
  using MakeSignedIfIntegral = typename MakeSignedIfIntegralSelector<ScalarType>::Type;

  /**
   * Unsigned counterpart of MakeSignedIfIntegral: std::make_unsigned for integral types, a no-op otherwise.
   *
   * This is needed because every alias handed to std::conditional_t below is substituted even when its branch is
   * not selected. Some of those aliases query SampleType on the resulting wrapper (via sizeof / std::is_signed),
   * which instantiates it. For a floating-point input and an unsigned integral output, ConvertToUnsignedWrapper is
   * instantiated on a floating-point source on such an unselected path. Using std::make_unsigned_t directly there
   * is ill-formed and broke compilation of float -> unsigned conversions.
   */
  template <typename ScalarType, typename Enable = void>
  struct MakeUnsignedIfIntegralSelector {
    using Type = ScalarType;
  };

  template <typename ScalarType>
  struct MakeUnsignedIfIntegralSelector<ScalarType, std::enable_if_t<std::is_integral<ScalarType>::value>> {
    using Type = std::make_unsigned_t<ScalarType>;
  };

  template <typename ScalarType>
  using MakeUnsignedIfIntegral = typename MakeUnsignedIfIntegralSelector<ScalarType>::Type;

  /**
   * All of these operator wrappers are here to keep us from actually instantiating the operator classes
   * in the "false" branches of the conditionals below. Instead, the algorithm just chains a bunch of these
   * wrappers together and then unwraps them at the end.
   */
  struct InputSourceWrapper {
    template <typename T = void>
    using Unwrap = PCMBufferAudioSource<InputBufferFormat, ContextType, SelectedChannelIndex>;

    static constexpr size_t SampleRate = InputBufferFormat::SampleRate;
    using SampleType = typename InputBufferFormat::SampleType;
  };

  template <typename InputSource>
  struct ConvertToSignedWrapper {
    template <typename T = void>
    using Unwrap = ConvertToSignedOperator<typename InputSource::template Unwrap<>, ContextType>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    using SampleType = MakeSignedIfIntegral<typename InputSource::SampleType>;
  };

  template <typename InputSource>
  struct ConvertToUnsignedWrapper {
    template <typename T = void>
    using Unwrap = ConvertToUnsignedOperator<typename InputSource::template Unwrap<>, ContextType>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    // Lenient (no-op for floating point): see MakeUnsignedIfIntegral for why this wrapper may see floats.
    using SampleType = MakeUnsignedIfIntegral<typename InputSource::SampleType>;
  };

  template <typename InputSource, typename OutputSampleType>
  struct ChangeFloatingPointDepthWrapper {
    template <typename T = void>
    using Unwrap =
      ChangeFloatingPointDepthOperator<typename InputSource::template Unwrap<>, OutputSampleType, ContextType>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    using SampleType = OutputSampleType;
  };

  template <typename InputSource, typename OutputSampleType>
  struct ConvertToIntegralTypeWrapper {
    // NOTE(review): unlike the other wrappers, Unwrap is made dependent on T via enable_if_t<is_void<T>>,
    // presumably so the operator type is not formed until Unwrap<> is actually used — confirm why only this one
    // needs it.
    template <typename T = void>
    using Unwrap = std::enable_if_t<std::is_void<T>::value,
      ConvertToIntegralTypeOperator<typename InputSource::template Unwrap<>, OutputSampleType, ContextType>>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    using SampleType = OutputSampleType;
  };

  template <typename InputSource, typename OutputSampleType>
  struct ConvertToFloatingTypeWrapper {
    template <typename T = void>
    using Unwrap =
      ConvertToFloatingPointOperator<typename InputSource::template Unwrap<>, OutputSampleType, ContextType>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    using SampleType = OutputSampleType;
  };

  template <typename InputSource, typename OutputSampleType>
  struct IncreaseBitDepthWrapper {
    template <typename T = void>
    using Unwrap = IncreaseBitDepthOperator<typename InputSource::template Unwrap<>, OutputSampleType, ContextType>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    using SampleType = OutputSampleType;
  };

  template <typename InputSource, typename OutputSampleType>
  struct DecreaseBitDepthWrapper {
    template <typename T = void>
    using Unwrap = DecreaseBitDepthOperator<typename InputSource::template Unwrap<>, OutputSampleType, ContextType>;

    static constexpr size_t SampleRate = InputSource::SampleRate;
    using SampleType = OutputSampleType;
  };

  template <typename InputSource, size_t SampleRateArg>
  struct ResampleWrapper {
    template <typename T = void>
    using Unwrap = ResampleOperator<typename InputSource::template Unwrap<>, SampleRateArg, ContextType>;

    static constexpr size_t SampleRate = SampleRateArg;
    using SampleType = typename InputSource::SampleType;
  };

  // Converts unsigned input to signed; any other input (signed integral or floating point) passes through.
  template <typename Input>
  using MakeSigned = std::conditional_t<
    /* if   */ std::is_unsigned<typename Input::SampleType>::value,
    /* then */ ConvertToSignedWrapper<Input>,
    /* else */ Input>;

  // Signed input becomes unsigned and vice versa.
  template <typename Input>
  using FlipSignedness = std::conditional_t<
    /* if   */ std::is_signed<typename Input::SampleType>::value,
    /* then */ ConvertToUnsignedWrapper<Input>,
    /* else */ ConvertToSignedWrapper<Input>>;

  // Flips signedness only when it differs from OutputSampleType's.
  template <typename Input, typename OutputSampleType>
  using MatchSignedness = std::conditional_t<
    /* if   */ std::is_signed<typename Input::SampleType>::value == std::is_signed<OutputSampleType>::value,
    /* then */ Input,
    /* else */ FlipSignedness<Input>>;

  // Floating-point input: change floating-point depth, or convert to a signed integer of the output width and then
  // fix up signedness.
  template <typename Input, typename OutputSampleType>
  using ConvertFloatingToSampleType = std::conditional_t<
    /* if   */ std::is_floating_point<OutputSampleType>::value,
    /* then */ ChangeFloatingPointDepthWrapper<Input, OutputSampleType>,
    /* else */
    MatchSignedness<ConvertToIntegralTypeWrapper<Input, MakeSignedIfIntegral<OutputSampleType>>, OutputSampleType>>;

  template <typename Input, typename OutputSampleType>
  using ConvertIntegralDepth = std::conditional_t<
    /* if   */ sizeof(typename Input::SampleType) < sizeof(OutputSampleType),
    /* then */ IncreaseBitDepthWrapper<Input, OutputSampleType>,
    /* else */ DecreaseBitDepthWrapper<Input, OutputSampleType>>;

  // Changes integral bit depth only when the sizes differ.
  template <typename Input, typename OutputSampleType>
  using MatchIntegralDepth = std::conditional_t<
    /* if   */ sizeof(typename Input::SampleType) == sizeof(OutputSampleType),
    /* then */ Input,
    /* else */ ConvertIntegralDepth<Input, OutputSampleType>>;

  // Integral input: match signedness first, then convert to floating point or adjust bit depth.
  template <typename Input, typename OutputSampleType>
  using ConvertIntegralToSampleType = std::conditional_t<
    /* if   */ std::is_floating_point<OutputSampleType>::value,
    /* then */ ConvertToFloatingTypeWrapper<MatchSignedness<Input, OutputSampleType>, OutputSampleType>,
    /* else */ MatchIntegralDepth<MatchSignedness<Input, OutputSampleType>, OutputSampleType>>;

  // Dispatches on whether the current sample type is floating point or integral.
  template <typename Input, typename OutputSampleType>
  using ConvertSampleType = std::conditional_t<
    /* if   */ std::is_floating_point<typename Input::SampleType>::value,
    /* then */ ConvertFloatingToSampleType<Input, OutputSampleType>,
    /* else */ ConvertIntegralToSampleType<Input, OutputSampleType>>;

  // No-op when the sample type already matches.
  template <typename Input, typename OutputSampleType>
  using MatchSampleType = std::conditional_t<
    /* if   */ std::is_same<typename Input::SampleType, OutputSampleType>::value,
    /* then */ Input,
    /* else */ ConvertSampleType<Input, OutputSampleType>>;

  // If we are increasing our bit depth, we should upscale before resampling. If we are decreasing our bit depth,
  // we should resample first, then change our bit depth. This minimizes quantization error.
  template <typename Input, typename Output>
  using ConvertSampleRateAndType = std::conditional_t<
    /* if   */ (sizeof(typename Input::SampleType) < sizeof(typename Output::SampleType)),
    /* then */
    MatchSignedness<
      ResampleWrapper<MatchSampleType<Input, MakeSignedIfIntegral<typename Output::SampleType>>, Output::SampleRate>,
      typename Output::SampleType>,
    /* else */ MatchSampleType<ResampleWrapper<MakeSigned<Input>, Output::SampleRate>, typename Output::SampleType>>;

  // Skips resampling entirely when the sample rates already agree.
  template <typename Input, typename Output>
  using ConvertFormat = std::conditional_t<
    /* if   */ Input::SampleRate == Output::SampleRate,
    /* then */ MatchSampleType<Input, typename Output::SampleType>,
    /* else */ ConvertSampleRateAndType<Input, Output>>;

  // The selected chain, still in wrapper form.
  using WrappedType = ConvertFormat<InputSourceWrapper, OutputBufferFormat>;

  // The real operator chain type, produced by unwrapping only the selected path.
  using Type = typename WrappedType::template Unwrap<>;

  static_assert(std::is_same<typename OutputBufferFormat::SampleType, typename Type::SampleType>::value,
    "Sample type mismatch after building operator chain.");
  static_assert(
    OutputBufferFormat::SampleRate == Type::SampleRate, "Sample rate mismatch after building operator chain.");
};

/**
 * Shorthand for the operator chain type that converts input channel ChannelIndex of InFormat into OutFormat,
 * i.e. OperatorChainBuilder<Context, InFormat, OutFormat, ChannelIndex>::Type.
 */
template <typename Context, typename InFormat, typename OutFormat, size_t ChannelIndex>
using BuildOperatorChain = typename OperatorChainBuilder<Context, InFormat, OutFormat, ChannelIndex>::Type;

template <typename ContextType, typename InputBufferFormat, typename OutputBufferFormat, size_t RemainingChannelCount,
  typename... CurrentChannels>
struct OperatorChainTupleBuilder {
  using AppendedOperatorChain =
    BuildOperatorChain<ContextType, InputBufferFormat, OutputBufferFormat, RemainingChannelCount - 1>;
  using Type = typename OperatorChainTupleBuilder<ContextType, InputBufferFormat, OutputBufferFormat,
    RemainingChannelCount - 1, AppendedOperatorChain, CurrentChannels...>::Type;
};

template <typename ContextType, typename InputBufferFormat, typename OutputBufferFormat, typename... CurrentChannels>
struct OperatorChainTupleBuilder<ContextType, InputBufferFormat, OutputBufferFormat, 0, CurrentChannels...> {
  using Type = std::tuple<CurrentChannels...>;
};

/**
 * Tuple holding one operator chain per input channel (InFormat::ChannelCount of them), ordered by channel index.
 * Each chain converts its channel from InFormat to OutFormat.
 */
template <typename Context, typename InFormat, typename OutFormat>
using BuildOperatorChainTuple =
  typename OperatorChainTupleBuilder<Context, InFormat, OutFormat, InFormat::ChannelCount>::Type;

/**
 * Base case: one argument has been collected for every tuple element, so construct the tuple from them.
 *
 * This overload is deliberately declared before the recursive one. The recursive call below is a dependent call,
 * so ordinary unqualified lookup only considers declarations that precede the template definition. If this overload
 * came second, it would be reachable only through argument-dependent lookup. The recursion would then only
 * terminate when ContextType's associated namespaces include this one.
 */
template <typename TupleType, typename ContextType, typename... Args>
std::enable_if_t<sizeof...(Args) == std::tuple_size<TupleType>::value, TupleType> PopulateTupleArgs(
  ContextType& /*context*/, Args&... args) {
  return TupleType{args...};
}

/**
 * Recursive case: fewer arguments than tuple elements have been collected. Prepend another reference to `context`
 * and recurse until there is exactly one per element.
 *
 * @tparam TupleType The tuple to build; each element is constructed from a ContextType& lvalue.
 * @param context Object passed by reference to every element's constructor.
 * @param args References collected so far.
 * @return The fully constructed tuple (returned as a prvalue all the way down, so no copy/move is required).
 */
template <typename TupleType, typename ContextType, typename... Args>
std::enable_if_t<sizeof...(Args) < std::tuple_size<TupleType>::value, TupleType> PopulateTupleArgs(
  ContextType& context, Args&... args) {
  return PopulateTupleArgs<TupleType>(context, context, args...);
}

/**
 * Builds a TupleType in which every element is constructed from the same `context` reference.
 *
 * @tparam TupleType The tuple to build; each element must be constructible from a ContextType& lvalue.
 * @param context Shared by reference with every element.
 * @return The constructed tuple.
 */
template <typename TupleType, typename ContextType>
auto MakeOperatorTuple(ContextType& context) -> TupleType {
  // Kick off the recursion with no constructor arguments collected yet.
  return PopulateTupleArgs<TupleType>(context);
}
}  // namespace ttv
