/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/audioconvert/audioconvertcontext.h"
#include "twitchsdk/broadcast/internal/audioconvert/audioconvertpipeline.h"

#include <array>

#include "gtest/gtest.h"

namespace {
// Compile-time options type used as the template argument of ttv::AudioConvertContext by the tests in this file.
struct AudioConvertOptions {
  // NOTE(review): presumably the number of taps used by the resampling filter -- confirm against
  // AudioConvertContext.
  static constexpr size_t FilterTapCount = 32;

  // Ditherer type selected for these tests.
  using Ditherer = ttv::RoundingDitherer;
};

// Builds a std::array with one element per argument, converting every argument to ElementType with static_cast.
// The array length is deduced from the number of arguments supplied.
template <typename ElementType, typename... Args>
std::array<ElementType, sizeof...(Args)> MakeArray(Args... args) {
  using Result = std::array<ElementType, sizeof...(Args)>;
  Result converted{{static_cast<ElementType>(args)...}};
  return converted;
}

// Checks (with non-fatal expectations) that two ranges have the same length and that corresponding elements differ
// by no more than errorMargin.
//
// leftBegin/leftEnd and rightBegin/rightEnd delimit the two ranges; the iterators must support subtraction because
// the lengths are computed as end - begin. Elements are converted to double before being compared.
//
// Every pair in the common prefix of the two ranges is checked, so a single call reports each mismatching position.
// Iteration stops at the end of the shorter range, so a length mismatch is reported without reading past the end
// of either range.
template <typename LeftIterator, typename RightIterator>
void CheckNearEqual(LeftIterator leftBegin, LeftIterator leftEnd, RightIterator rightBegin, RightIterator rightEnd,
  double errorMargin = 0.0001) {
  EXPECT_EQ(leftEnd - leftBegin, rightEnd - rightBegin);

  size_t index = 0;
  for (; leftBegin != leftEnd && rightBegin != rightEnd; ++leftBegin, ++rightBegin, ++index) {
    const double left = static_cast<double>(*leftBegin);
    const double right = static_cast<double>(*rightBegin);
    EXPECT_NEAR(left, right, errorMargin) << "at element index " << index;
  }
}
}  // namespace

// Exercises ttv::SampleCache with a capacity of 8 samples: an initial fill, moving the cached range forward, moving
// it back again, and finally jumping to a range that shares no indexes with the previous one. After each step the
// test checks which indexes the cache asked the generator for, the reported range, and the cached values.
TEST(AudioConvert, SampleCache) {
  using IndexList = std::vector<size_t>;

  ttv::SampleCache<int16_t, 8> cache;

  // Shared generator: remembers every index it is asked for and produces the square of that index.
  IndexList requested;
  auto squareAndRecord = [&requested](size_t index) {
    requested.push_back(index);
    return index * index;
  };

  // Initial fill starting at index 0: all eight indexes are expected to be generated.
  cache.Populate(0, squareAndRecord);

  ASSERT_EQ(requested, IndexList({0, 1, 2, 3, 4, 5, 6, 7}));

  ttv::SampleRange window = cache.GetRange();
  ASSERT_EQ(window.startIndex, 0);
  ASSERT_EQ(window.sampleCount, 8);

  ASSERT_EQ(cache[0], 0);
  ASSERT_EQ(cache[1], 1);
  ASSERT_EQ(cache[2], 4);
  ASSERT_EQ(cache[3], 9);
  ASSERT_EQ(cache[4], 16);
  ASSERT_EQ(cache[5], 25);
  ASSERT_EQ(cache[6], 36);
  ASSERT_EQ(cache[7], 49);

  // Shift the range forward by two: only the two new trailing indexes are expected to be generated.
  requested.clear();
  cache.Populate(2, squareAndRecord);

  ASSERT_EQ(requested, IndexList({8, 9}));

  window = cache.GetRange();
  ASSERT_EQ(window.startIndex, 2);
  ASSERT_EQ(window.sampleCount, 8);

  ASSERT_EQ(cache[2], 4);
  ASSERT_EQ(cache[3], 9);
  ASSERT_EQ(cache[4], 16);
  ASSERT_EQ(cache[5], 25);
  ASSERT_EQ(cache[6], 36);
  ASSERT_EQ(cache[7], 49);
  ASSERT_EQ(cache[8], 64);
  ASSERT_EQ(cache[9], 81);

  // Shift the range back to the start: only the two leading indexes are expected to be generated.
  requested.clear();
  cache.Populate(0, squareAndRecord);

  ASSERT_EQ(requested, IndexList({0, 1}));

  window = cache.GetRange();
  ASSERT_EQ(window.startIndex, 0);
  ASSERT_EQ(window.sampleCount, 8);

  ASSERT_EQ(cache[0], 0);
  ASSERT_EQ(cache[1], 1);
  ASSERT_EQ(cache[2], 4);
  ASSERT_EQ(cache[3], 9);
  ASSERT_EQ(cache[4], 16);
  ASSERT_EQ(cache[5], 25);
  ASSERT_EQ(cache[6], 36);
  ASSERT_EQ(cache[7], 49);

  // Jump to a range with no overlap: all eight indexes are expected to be generated again.
  requested.clear();
  cache.Populate(15, squareAndRecord);

  ASSERT_EQ(requested, IndexList({15, 16, 17, 18, 19, 20, 21, 22}));

  window = cache.GetRange();
  ASSERT_EQ(window.startIndex, 15);
  ASSERT_EQ(window.sampleCount, 8);

  ASSERT_EQ(cache[15], 225);
  ASSERT_EQ(cache[16], 256);
  ASSERT_EQ(cache[17], 289);
  ASSERT_EQ(cache[18], 324);
  ASSERT_EQ(cache[19], 361);
  ASSERT_EQ(cache[20], 400);
  ASSERT_EQ(cache[21], 441);
  ASSERT_EQ(cache[22], 484);
}

// Checks the ordering that ttv::CyclingGenerator imposes on a lookup table. The wrapped generator simply returns
// the index it is given, so each table entry reveals which source index was used for that slot.
TEST(AudioConvert, CyclingGenerator) {
  // Generator whose value for a slot is the slot's own index; the length argument is ignored.
  struct IdentityGenerator {
    using ReturnType = size_t;

    static ReturnType Generate(size_t index, size_t /*length*/) { return index; }
  };

  // Cycling parameter 3 over a 12-entry table: expected source indexes step by 3 and wrap around.
  {
    ttv::LookupTable<ttv::CyclingGenerator<IdentityGenerator, 3>, 12> cycledByThree;

    ASSERT_EQ(cycledByThree[0], 0);
    ASSERT_EQ(cycledByThree[1], 3);
    ASSERT_EQ(cycledByThree[2], 6);
    ASSERT_EQ(cycledByThree[3], 9);
    ASSERT_EQ(cycledByThree[4], 1);
    ASSERT_EQ(cycledByThree[5], 4);
    ASSERT_EQ(cycledByThree[6], 7);
    ASSERT_EQ(cycledByThree[7], 10);
    ASSERT_EQ(cycledByThree[8], 2);
    ASSERT_EQ(cycledByThree[9], 5);
    ASSERT_EQ(cycledByThree[10], 8);
    ASSERT_EQ(cycledByThree[11], 11);
  }

  // Cycling parameter 5 over a 15-entry table: expected source indexes step by 5 and wrap around.
  {
    ttv::LookupTable<ttv::CyclingGenerator<IdentityGenerator, 5>, 15> cycledByFive;

    ASSERT_EQ(cycledByFive[0], 0);
    ASSERT_EQ(cycledByFive[1], 5);
    ASSERT_EQ(cycledByFive[2], 10);
    ASSERT_EQ(cycledByFive[3], 1);
    ASSERT_EQ(cycledByFive[4], 6);
    ASSERT_EQ(cycledByFive[5], 11);
    ASSERT_EQ(cycledByFive[6], 2);
    ASSERT_EQ(cycledByFive[7], 7);
    ASSERT_EQ(cycledByFive[8], 12);
    ASSERT_EQ(cycledByFive[9], 3);
    ASSERT_EQ(cycledByFive[10], 8);
    ASSERT_EQ(cycledByFive[11], 13);
    ASSERT_EQ(cycledByFive[12], 4);
    ASSERT_EQ(cycledByFive[13], 9);
    ASSERT_EQ(cycledByFive[14], 14);
  }
}

// Runs a pipeline whose input and output formats are identical, with both the bound input range and the requested
// output range starting at sample index 73, and expects the samples to come through unchanged.
TEST(AudioConvert, TestBufferWithOffsets) {
  using Format = ttv::BufferFormat<int8_t, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<Format, Format>(context);

  auto source = MakeArray<int8_t>(1, -2, 3, -4, 5, -6, 7, -8);

  // Deliberately larger than the input so the pipeline decides how many samples get written.
  std::vector<int8_t> destination(100);

  pipeline.BindInputBuffer(source.data(), {73, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(destination.data(), {73, destination.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 73);
  ASSERT_EQ(written.sampleCount, 8);

  // Trim the destination down to the samples that were actually produced before comparing.
  destination.resize(written.sampleCount);

  CheckNearEqual(source.begin(), source.end(), destination.begin(), destination.end());
}

// Runs a pipeline whose input and output formats are the same two-channel format and expects the buffer contents to
// come through unchanged. Sample ranges are expressed in samples, so buffer lengths are divided by the channel
// count when building ranges and multiplied by it when sizing buffers.
TEST(AudioConvert, TestBufferWithTwoChannels) {
  constexpr size_t kChannelCount = 2;
  using Format = ttv::BufferFormat<int8_t, 8000, 2>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<Format, Format>(context);

  auto source = MakeArray<int8_t>(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16);
  size_t sampleCount = source.size() / kChannelCount;

  std::vector<int8_t> destination(100);

  pipeline.BindInputBuffer(source.data(), {0, sampleCount});
  const ttv::SampleRange written =
    pipeline.TransferToOutputBuffer(destination.data(), {0, destination.size() / kChannelCount});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, sampleCount);

  // Keep only the elements that were written: one element per channel for every sample.
  destination.resize(written.sampleCount * kChannelCount);

  CheckNearEqual(source.begin(), source.end(), destination.begin(), destination.end());
}

// Converts signed 8-bit samples to unsigned 8-bit samples at the same sample rate. The expected values are the
// inputs shifted up by 128 (for example 1 -> 129 and -2 -> 126).
TEST(AudioConvert, TestSignedToUnsignedOperator) {
  using SignedFormat = ttv::BufferFormat<int8_t, 8000>;
  using UnsignedFormat = ttv::BufferFormat<uint8_t, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<SignedFormat, UnsignedFormat>(context);

  auto source = MakeArray<int8_t>(1, -2, 3, -4, 5, -6, 7, -8);
  auto expected = MakeArray<uint8_t>(129, 126, 131, 124, 133, 122, 135, 120);

  std::vector<uint8_t> converted(100);

  pipeline.BindInputBuffer(source.data(), {0, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(converted.data(), {0, converted.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, source.size());

  // Drop the unused tail of the output buffer before comparing.
  converted.resize(written.sampleCount);

  CheckNearEqual(converted.begin(), converted.end(), expected.begin(), expected.end());
}

// Converts unsigned 8-bit samples to signed 8-bit samples at the same sample rate. The expected values are the
// inputs shifted down by 128 (for example 129 -> 1 and 126 -> -2).
TEST(AudioConvert, TestUnsignedToSignedOperator) {
  using UnsignedFormat = ttv::BufferFormat<uint8_t, 8000>;
  using SignedFormat = ttv::BufferFormat<int8_t, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<UnsignedFormat, SignedFormat>(context);

  auto source = MakeArray<uint8_t>(129, 126, 131, 124, 133, 122, 135, 120);
  auto expected = MakeArray<int8_t>(1, -2, 3, -4, 5, -6, 7, -8);

  std::vector<int8_t> converted(100);

  pipeline.BindInputBuffer(source.data(), {0, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(converted.data(), {0, converted.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, source.size());

  // Drop the unused tail of the output buffer before comparing.
  converted.resize(written.sampleCount);

  CheckNearEqual(converted.begin(), converted.end(), expected.begin(), expected.end());
}

// Converts signed 8-bit samples to double samples at the same sample rate. The expected values are the inputs
// divided by 128 (for example 64 -> 0.5 and -96 -> -0.75).
TEST(AudioConvert, TestIntegerToFloatOperator) {
  using IntegerFormat = ttv::BufferFormat<int8_t, 8000>;
  using FloatFormat = ttv::BufferFormat<double, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<IntegerFormat, FloatFormat>(context);

  auto source = MakeArray<int8_t>(64, 32, -96, -16, 48, -112, 80, 0);
  auto expected = MakeArray<double>(0.5, 0.25, -0.75, -0.125, 0.375, -0.875, 0.625, 0.0);

  std::vector<double> converted(100);

  pipeline.BindInputBuffer(source.data(), {0, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(converted.data(), {0, converted.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, source.size());

  // Drop the unused tail of the output buffer before comparing.
  converted.resize(written.sampleCount);

  CheckNearEqual(converted.begin(), converted.end(), expected.begin(), expected.end());
}

// Converts double samples to signed 8-bit samples at the same sample rate. The expected values are the inputs
// multiplied by 128 (for example 0.5 -> 64 and -0.75 -> -96).
TEST(AudioConvert, TestFloatToIntegerOperator) {
  using FloatFormat = ttv::BufferFormat<double, 8000>;
  using IntegerFormat = ttv::BufferFormat<int8_t, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<FloatFormat, IntegerFormat>(context);

  auto source = MakeArray<double>(0.5, 0.25, -0.75, -0.125, 0.375, -0.875, 0.625, 0.0);
  auto expected = MakeArray<int8_t>(64, 32, -96, -16, 48, -112, 80, 0);

  std::vector<int8_t> converted(100);

  pipeline.BindInputBuffer(source.data(), {0, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(converted.data(), {0, converted.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, source.size());

  // Drop the unused tail of the output buffer before comparing.
  converted.resize(written.sampleCount);

  CheckNearEqual(converted.begin(), converted.end(), expected.begin(), expected.end());
}

// Converts signed 8-bit samples to signed 16-bit samples at the same sample rate. The expected values are the
// inputs multiplied by 256 (for example 1 -> 256 and -2 -> -512).
TEST(AudioConvert, TestIncreaseBitDepthOperator) {
  using NarrowFormat = ttv::BufferFormat<int8_t, 8000>;
  using WideFormat = ttv::BufferFormat<int16_t, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<NarrowFormat, WideFormat>(context);

  auto source = MakeArray<int8_t>(1, -2, 3, -4, 5, -6, 7, -8);
  auto expected = MakeArray<int16_t>(256, -512, 768, -1024, 1280, -1536, 1792, -2048);

  std::vector<int16_t> converted(100);

  pipeline.BindInputBuffer(source.data(), {0, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(converted.data(), {0, converted.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, source.size());

  // Drop the unused tail of the output buffer before comparing.
  converted.resize(written.sampleCount);

  CheckNearEqual(converted.begin(), converted.end(), expected.begin(), expected.end());
}

// Converts signed 16-bit samples to signed 8-bit samples at the same sample rate, using the RoundingDitherer
// selected by AudioConvertOptions. The expected values are the inputs divided by 256 and rounded to the nearest
// integer.
TEST(AudioConvert, TestDecreaseBitDepthOperator) {
  using WideFormat = ttv::BufferFormat<int16_t, 8000>;
  using NarrowFormat = ttv::BufferFormat<int8_t, 8000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<WideFormat, NarrowFormat>(context);

  // Some inputs sit one step above or below an exact multiple of 256 (767, -1025, 1281, -1537) so that the test
  // fails if values are not rounded to the nearest output step.
  auto source = MakeArray<int16_t>(256, -512, 767, -1025, 1281, -1537, 1792, -2048);
  auto expected = MakeArray<int8_t>(1, -2, 3, -4, 5, -6, 7, -8);

  std::vector<int8_t> converted(100);

  pipeline.BindInputBuffer(source.data(), {0, source.size()});
  const ttv::SampleRange written = pipeline.TransferToOutputBuffer(converted.data(), {0, converted.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, source.size());

  // Drop the unused tail of the output buffer before comparing.
  converted.resize(written.sampleCount);

  CheckNearEqual(converted.begin(), converted.end(), expected.begin(), expected.end());
}

// Downsamples a 440 Hz sine wave from 8000 Hz to 4000 Hz in three steps:
//   1. a first 64-sample buffer starting at input index 0,
//   2. the directly following 64-sample buffer (input index 64),
//   3. the first buffer again, bound at input index 1024, which is not adjacent to the previous one.
// Each step checks the output range reported by the pipeline and compares the produced samples with recorded
// reference values.
TEST(AudioConvert, TestDownsampleOperator) {
  using SourceFormat = ttv::BufferFormat<double, 8000>;
  using TargetFormat = ttv::BufferFormat<double, 4000>;

  ttv::AudioConvertContext<AudioConvertOptions> context;
  auto pipeline = ttv::MakeAudioConvertPipeline<SourceFormat, TargetFormat>(context);

  // Resampling involves low-pass filtering, so the input has to be a realistic signal rather than arbitrary values.
  // Fills `samples` with a 440 Hz sine wave of amplitude 0.7 sampled at 8000 Hz; the first element corresponds to
  // sample number `firstSampleIndex` of the wave.
  auto fillSineWave = [](std::array<double, 64>& samples, size_t firstSampleIndex) {
    for (size_t offset = 0; offset < samples.size(); ++offset) {
      double time = static_cast<double>(offset + firstSampleIndex) / 8000.0;
      samples[offset] = 0.7 * sin(440.0 * 2.0 * ttv::kPi * time);
    }
  };

  std::array<double, 64> firstBlock;
  fillSineWave(firstBlock, 0);

  // Reference output for the first block; it is also the reference for the non-adjacent step at the end.
  auto firstExpected = MakeArray<double>(-0.559785, -0.163421, 0.307949, 0.637978, 0.675192, 0.402511, -0.0549122,
    -0.487132, -0.695771, -0.58507, -0.205837, 0.26787, 0.618631, 0.685458, 0.437677, -0.0109859);

  std::vector<double> resampled(100);

  // Step 1: the first block is expected to yield 16 output samples starting at output index 0.
  pipeline.BindInputBuffer(firstBlock.data(), {0, firstBlock.size()});
  ttv::SampleRange written = pipeline.TransferToOutputBuffer(resampled.data(), {0, resampled.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 0);
  ASSERT_EQ(written.sampleCount, 16);

  resampled.resize(written.sampleCount);

  CheckNearEqual(resampled.begin(), resampled.end(), firstExpected.begin(), firstExpected.end());

  // Step 2: continue the same sine wave with the next 64 input samples.
  std::array<double, 64> secondBlock;
  fillSineWave(secondBlock, 64);

  auto secondExpected = MakeArray<double>(-0.454607, -0.689575, -0.608046, -0.247441, 0.226733, 0.596843, 0.693018,
    0.471115, 0.0329837, -0.420287, -0.680657, -0.628623, -0.288068, 0.184702, 0.572699, 0.697843, 0.502695,
    0.0768232, -0.384308, -0.669052, -0.646719, -0.327559, 0.141942, 0.546295, 0.699914, 0.53229, 0.120359,
    -0.346813, -0.654807, -0.662263, -0.365757, 0.0986217);

  resampled.resize(100);

  pipeline.BindInputBuffer(secondBlock.data(), {64, secondBlock.size()});
  written = pipeline.TransferToOutputBuffer(resampled.data(), {0, resampled.size()});
  pipeline.UnbindInputBuffer();

  // The output continues where step 1 stopped and is expected to contain 32 samples.
  ASSERT_EQ(written.startIndex, 16);
  ASSERT_EQ(written.sampleCount, 32);

  // The output buffer starts at output index 0, so the new samples sit at offset written.startIndex.
  resampled.resize(written.startIndex + written.sampleCount);

  CheckNearEqual(
    resampled.begin() + written.startIndex, resampled.end(), secondExpected.begin(), secondExpected.end());

  // Step 3: bind a non-adjacent input range and make sure no residual samples from the earlier steps are used;
  // the result has to match the output of step 1.
  resampled.resize(2000);

  pipeline.BindInputBuffer(firstBlock.data(), {1024, firstBlock.size()});
  written = pipeline.TransferToOutputBuffer(resampled.data(), {0, resampled.size()});
  pipeline.UnbindInputBuffer();

  ASSERT_EQ(written.startIndex, 512);
  ASSERT_EQ(written.sampleCount, 16);

  resampled.resize(written.startIndex + written.sampleCount);

  CheckNearEqual(
    resampled.begin() + written.startIndex, resampled.end(), firstExpected.begin(), firstExpected.end());
}
