/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#pragma once

#include "twitchsdk/broadcast/internal/audioconvert/dsputilities.h"
#include "twitchsdk/broadcast/internal/audioconvert/samplecache.h"

#include <ratio>

namespace ttv {
/**
 * Converts an audio source to one with a lower sample rate.
 *
 * @tparam InputSource Conforms to the AudioSource concept.
 * @tparam OutputSampleRate The sample rate to convert to.
 *
 * Operator Type Properties:
 *     Operator Type: Unary
 *     Data Integrity: Lossy
 *     SampleType[in] == SampleType[out]
 *     SampleRate[in] != SampleRate[out]
 *     StartOffset[in] != StartOffset[out]
 *     Length[in] != Length[out]
 */
template <typename InputSource, size_t OutputSampleRate, typename ContextType>
class ResampleOperator {
 public:
  ResampleOperator(ContextType& context) : mInputSource(context) {
    mGainNormalization = 0.0;
    for (size_t i = 0; i < TapCount; i++) {
      mGainNormalization += mCoefficients[i];
    }
  }

  using SampleType = typename InputSource::SampleType;
  static_assert(std::is_signed<SampleType>::value, "Input sample type must be floating point.");

  static constexpr size_t InputSampleRate = InputSource::SampleRate;
  static constexpr size_t SampleRate = OutputSampleRate;

  using FilterOptions = typename ContextType::FilterOptions;

  static constexpr size_t TapCount = FilterOptions::TapCount;
  using OutputToInputRatio = std::ratio<OutputSampleRate, InputSampleRate>;

  InputSource& GetInputSource() { return mInputSource; }

  void Unbind() {
    SampleRange inputRange = mInputSource.GetSampleRange();

    // Make sure to populate the cache with the last samples before unbinding.
    size_t inputIndex = inputRange.startIndex + inputRange.sampleCount - TapCount;
    mSampleCache.Populate(inputIndex, [this](size_t index) { return mInputSource[index]; });
  }

  SampleRange GetSampleRange() const {
    SampleRange sampleRange = ExtendSampleRange(mInputSource.GetSampleRange(), mSampleCache.GetRange());

    size_t inputIndex = sampleRange.startIndex;
    size_t inputLength = sampleRange.sampleCount;
    inputLength -= TapCount;

    size_t subdivisionIndex = inputIndex * OutputToInputRatio::num;

    size_t outputIndex = subdivisionIndex / OutputToInputRatio::den;
    if (subdivisionIndex % OutputToInputRatio::den != 0) {
      outputIndex++;
    }

    size_t lastInputIndex = inputIndex + inputLength;

    size_t lastSubdivisionIndex = lastInputIndex * OutputToInputRatio::num;
    size_t lastOutputIndex = lastSubdivisionIndex / OutputToInputRatio::den;

    if ((OutputToInputRatio::den - (lastSubdivisionIndex % OutputToInputRatio::den) < OutputToInputRatio::num)) {
      lastOutputIndex++;
    }

    size_t outputLength = lastOutputIndex - outputIndex;

    return {outputIndex, outputLength};
  }

  SampleType operator[](size_t index) const {
    TTV_ASSERT(index >= GetSampleRange().startIndex);
    TTV_ASSERT(index < GetSampleRange().startIndex + GetSampleRange().sampleCount);

    size_t outputSubdivisionIndex = index * OutputToInputRatio::den;

    size_t inputIndex = outputSubdivisionIndex / OutputToInputRatio::num;
    size_t inputPhase = outputSubdivisionIndex % OutputToInputRatio::num;

    if (inputPhase != 0) {
      inputIndex++;
      inputPhase = OutputToInputRatio::num - inputPhase;
    }

    size_t coefficientIndex = TapCount * inputPhase;

    mSampleCache.Populate(inputIndex, [this](size_t index) { return mInputSource[index]; });

    double value = 0.0;

    for (size_t i = 0; i < TapCount; i++) {
      double coefficient = mCoefficients[coefficientIndex + i];
      value += coefficient * mSampleCache[inputIndex + i];
    };

    // Unity gain normalization.
    value /= mGainNormalization;

    return ClampAndCastSample<SampleType>(value);
  }

 private:
  InputSource mInputSource;

  struct SincFunctionParameters {
    static constexpr double Cutoff = FilterOptions::Cutoff * static_cast<double>(OutputToInputRatio::num) /
                                     static_cast<double>(OutputToInputRatio::den);
    static constexpr double Range = static_cast<double>(TapCount);
  };

  using ImpulseResponseFunction = SincFunctionGenerator<SincFunctionParameters>;

  using WindowFunction = typename FilterOptions::WindowFunction;

  using CoefficientFunction = ProductGenerator<ImpulseResponseFunction, WindowFunction>;
  static constexpr size_t CoefficientPhaseSubdivisions = OutputToInputRatio::num * TapCount;

  using CycledCoefficientFunction = CyclingGenerator<CoefficientFunction, OutputToInputRatio::num>;

  LookupTable<CycledCoefficientFunction, CoefficientPhaseSubdivisions> mCoefficients;

  mutable SampleCache<double, TapCount> mSampleCache;
  double mGainNormalization;
};
}  // namespace ttv
