/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/internal/wavefilewriter.h"

ttv::broadcast::WaveFileWriter::WaveFileWriter()
    : mFile(nullptr), mSampleRate(0), mNumChannels(0), mDuration(0), mIsFloat(false) {}

// Finalizes (header patch) and closes any file still open when the writer is destroyed.
ttv::broadcast::WaveFileWriter::~WaveFileWriter() {
  // The status returned by Close() cannot be reported from a destructor, so it is deliberately discarded.
  static_cast<void>(Close());
}

bool ttv::broadcast::WaveFileWriter::Open(const std::string& path, uint sampleRate, uint numChannels, bool isFloat) {
  Close();

  mSampleRate = sampleRate;
  mNumChannels = numChannels;
  mIsFloat = isFloat;
  mFile = fopen(path.c_str(), "wb");

  return mFile != nullptr;
}

// Appends `numSamples` 32-bit float samples to the open file and accumulates the written duration.
//
// `numSamples` counts individual values across all channels (interleaved), so duration is
// samples / sampleRate / numChannels.
// NOTE(review): this overload does not check mIsFloat. Writing float data into a file opened as
// 16-bit PCM would produce a header that does not match the data.
//
// @return false if no file is open, `samples` is null with a nonzero count, or fewer than
//         `numSamples` values were written; true otherwise.
bool ttv::broadcast::WaveFileWriter::WriteSamples(const float* samples, uint numSamples) {
  if (mFile == nullptr) {
    return false;
  }

  if (numSamples == 0) {
    return true;
  }

  if (samples == nullptr) {
    return false;
  }

  const size_t written = fwrite(samples, sizeof(float), numSamples, mFile);

  // Only count what actually reached the file, and avoid dividing by a zero rate/channel count.
  if (mSampleRate != 0 && mNumChannels != 0) {
    mDuration += static_cast<float>(written) / static_cast<float>(mSampleRate) / static_cast<float>(mNumChannels);
  }

  return written == numSamples;
}

// Appends `numSamples` 16-bit integer samples to the open file and accumulates the written duration.
//
// `numSamples` counts individual values across all channels (interleaved), so duration is
// samples / sampleRate / numChannels.
// NOTE(review): this overload does not check mIsFloat. Writing int16 data into a file opened as
// float would produce a header that does not match the data.
//
// @return false if no file is open, `samples` is null with a nonzero count, or fewer than
//         `numSamples` values were written; true otherwise.
bool ttv::broadcast::WaveFileWriter::WriteSamples(const int16_t* samples, uint numSamples) {
  if (mFile == nullptr) {
    return false;
  }

  if (numSamples == 0) {
    return true;
  }

  if (samples == nullptr) {
    return false;
  }

  const size_t written = fwrite(samples, sizeof(int16_t), numSamples, mFile);

  // Only count what actually reached the file, and avoid dividing by a zero rate/channel count.
  if (mSampleRate != 0 && mNumChannels != 0) {
    mDuration += static_cast<float>(written) / static_cast<float>(mSampleRate) / static_cast<float>(mNumChannels);
  }

  return written == numSamples;
}

bool ttv::broadcast::WaveFileWriter::Close() {
  if (mFile != nullptr) {
    // update the wave header with the file length
    uint fileSize = static_cast<uint>(ftell(mFile));
    fseek(mFile, SEEK_SET, 0);
    WriteWaveHeader(fileSize);

    fclose(mFile);
    mFile = nullptr;
  }

  mSampleRate = 0;
  mNumChannels = 0;
  mDuration = 0;
  mIsFloat = false;

  return true;
}

// Writes a canonical 44-byte RIFF/WAVE header at the current position of mFile.
//
// The format chunk is derived from mIsFloat, mNumChannels and mSampleRate:
//   - format tag 3 (IEEE float, 32-bit) when mIsFloat, otherwise 1 (PCM, 16-bit);
//   - byte rate = channels * sampleRate * bytesPerSample;
//   - block align = channels * bytesPerSample.
// The function does not seek; callers position the stream first.
// Reference: http://mathmatrix.narod.ru/Wavefmt.html
//
// @param fileLength Total file length in bytes including this header. Values shorter than the
//                    header are treated as an empty data chunk so the size fields cannot underflow.
void ttv::broadcast::WaveFileWriter::WriteWaveHeader(uint fileLength) {
  const uint32_t kHeaderSize = 44;
  const uint32_t bytesPerSample = static_cast<uint32_t>(mIsFloat ? sizeof(float) : sizeof(int16_t));
  const uint32_t totalLength = fileLength < kHeaderSize ? kHeaderSize : static_cast<uint32_t>(fileLength);

  // Serialize into a buffer with explicit little-endian byte order (as RIFF requires) so the
  // output is identical regardless of host endianness.
  uint8_t header[kHeaderSize] = {};
  size_t offset = 0;

  auto putTag = [&](const char* tag) {
    for (int i = 0; i < 4; ++i) {
      header[offset++] = static_cast<uint8_t>(tag[i]);
    }
  };
  auto putU16 = [&](uint32_t value) {
    header[offset++] = static_cast<uint8_t>(value & 0xFFu);
    header[offset++] = static_cast<uint8_t>((value >> 8) & 0xFFu);
  };
  auto putU32 = [&](uint32_t value) {
    header[offset++] = static_cast<uint8_t>(value & 0xFFu);
    header[offset++] = static_cast<uint8_t>((value >> 8) & 0xFFu);
    header[offset++] = static_cast<uint8_t>((value >> 16) & 0xFFu);
    header[offset++] = static_cast<uint8_t>((value >> 24) & 0xFFu);
  };

  putTag("RIFF");
  putU32(totalLength - 8);  // RIFF chunk size excludes the "RIFF" tag and this size field
  putTag("WAVE");

  putTag("fmt ");
  putU32(16);                                         // size of the fmt chunk body
  putU16(mIsFloat ? 3u : 1u);                         // format tag: 3 = IEEE float, 1 = PCM
  putU16(mNumChannels);                               // channel count (truncated to 16 bits)
  putU32(mSampleRate);                                // samples per second
  putU32(mNumChannels * mSampleRate * bytesPerSample);  // bytes per second
  putU16(mNumChannels * bytesPerSample);              // block alignment
  putU16(8 * bytesPerSample);                         // bits per sample

  putTag("data");
  putU32(totalLength - kHeaderSize);  // size of the sample data that follows the header

  fwrite(header, 1, sizeof(header), mFile);
}
