#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/internal/ingestsampledata.h"

/*
 *  This is a custom data format that is encoded as follows:
 *
 *  Preamble        32 bytes                Identifies this as the encoding type: "TwitchPlatformsSDKTestVideoV001:"
 *  SPSDataLength   4 bytes                 Length of SPS data represented. 32-bit unsigned integer
 *  PPSDataLength   4 bytes                 Length of PPS data represented. 32-bit unsigned integer
 *  FrameCount      4 bytes                 Number of frames represented. 32-bit unsigned integer
 *  FrameMetaData   8 bytes * FrameCount    Metadata about each frame, described below:
 *                      KeyFrame        4 bytes     Whether the frame is a key frame. 32-bit unsigned integer;
 *                                                  any non-zero value indicates a key frame.
 *                      FrameLength     4 bytes     Length of the data for this frame. 32-bit unsigned integer.
 *  SPSData         SPSDataLength bytes     The raw SPS data.
 *  PPSData         PPSDataLength bytes     The raw PPS data.
 *
 *  FrameData       FrameLength bytes       The raw frame data.
 *  (repeating for each frame)
 *
 *  The frame data for each frame appears in the same order as the FrameMetadata provided above.
 *
 *  All data is little-endian ordered.
 */

namespace {
// Magic string that must open every encoded buffer. Only its characters are compared; the terminating NUL that the
// string literal adds is not part of the encoded format.
constexpr char kPreamble[] = "TwitchPlatformsSDKTestVideoV001:";

// The format specifies a 32-byte preamble; keep the literal and the format in sync.
static_assert(sizeof(kPreamble) - 1 == 32, "The encoded preamble must be exactly 32 bytes long");
}  // namespace

// Public entry point: decodes `buffer` (of `length` bytes) via TryParse. When decoding fails, Clear() is called so
// that no SPS/PPS/frame data is left behind. Returns exactly the code produced by TryParse.
TTV_ErrorCode ttv::broadcast::IngestSampleData::Parse(const uint8_t* buffer, size_t length) {
  const TTV_ErrorCode result = TryParse(buffer, length);

  if (TTV_FAILED(result)) {
    // Discard anything stored before the failure was detected.
    Clear();
  }

  return result;
}

// Decodes the test-video container described in the comment at the top of this file.
//
// buffer/length: the complete encoded blob; a null buffer is rejected.
// Returns TTV_EC_SUCCESS when the whole buffer was consumed. In that case spsData, ppsData and frames are replaced
// (not appended to) with the decoded contents.
// Returns TTV_EC_INVALID_FORMAT in these cases, and leaves the members unmodified:
//   - a wrong preamble,
//   - truncated data,
//   - length fields that exceed the buffer,
//   - trailing bytes.
TTV_ErrorCode ttv::broadcast::IngestSampleData::TryParse(const uint8_t* buffer, size_t length) {
  if (buffer == nullptr) {
    return TTV_EC_INVALID_FORMAT;
  }

  const uint8_t* current = buffer;
  const uint8_t* const end = buffer + length;

  // Bytes left to read. Bounds are checked against this instead of computing `current + n`. That expression is
  // undefined behavior (and may wrap around) when a corrupt length field points past the end of the buffer.
  auto remaining = [&current, end]() -> size_t { return static_cast<size_t>(end - current); };

  // Reads one 32-bit unsigned little-endian integer and advances past it. The value is assembled from individual
  // bytes, so it works for any host byte order and any buffer alignment. Dereferencing a reinterpret_cast'ed
  // pointer does not. Returns false if fewer than 4 bytes remain.
  auto readUint32 = [&current, &remaining](uint32_t& value) -> bool {
    if (remaining() < 4) {
      return false;
    }
    value = static_cast<uint32_t>(current[0]) | (static_cast<uint32_t>(current[1]) << 8) |
            (static_cast<uint32_t>(current[2]) << 16) | (static_cast<uint32_t>(current[3]) << 24);
    current += 4;
    return true;
  };

  // Preamble: the characters of kPreamble, without its terminating NUL.
  constexpr size_t kPreambleLength = sizeof(kPreamble) - 1;
  if (remaining() < kPreambleLength || memcmp(current, kPreamble, kPreambleLength) != 0) {
    return TTV_EC_INVALID_FORMAT;
  }
  current += kPreambleLength;

  // SPSDataLength, PPSDataLength, FrameCount (read in that order).
  uint32_t spsDataLength = 0;
  uint32_t ppsDataLength = 0;
  uint32_t frameCount = 0;
  if (!readUint32(spsDataLength) || !readUint32(ppsDataLength) || !readUint32(frameCount)) {
    return TTV_EC_INVALID_FORMAT;
  }

  // FrameMetaData: 8 bytes per frame (KeyFrame, FrameLength). Rejecting a frameCount the buffer cannot hold before
  // reserving keeps a corrupt header from triggering a huge allocation.
  struct FrameMetaData {
    uint32_t size;
    bool keyFrame;
  };
  constexpr size_t kFrameMetaDataSize = 8;
  if (frameCount > remaining() / kFrameMetaDataSize) {
    return TTV_EC_INVALID_FORMAT;
  }

  std::vector<FrameMetaData> allFrameMetadata;
  allFrameMetadata.reserve(frameCount);
  for (uint32_t i = 0; i < frameCount; ++i) {
    uint32_t keyFrameFlag = 0;
    FrameMetaData metaData{};
    if (!readUint32(keyFrameFlag) || !readUint32(metaData.size)) {
      return TTV_EC_INVALID_FORMAT;
    }
    metaData.keyFrame = (keyFrameFlag != 0);  // any non-zero value marks a key frame
    allFrameMetadata.push_back(metaData);
  }

  // SPSData
  if (remaining() < spsDataLength) {
    return TTV_EC_INVALID_FORMAT;
  }
  decltype(spsData) parsedSpsData;
  parsedSpsData.assign(current, current + spsDataLength);
  current += spsDataLength;

  // PPSData
  if (remaining() < ppsDataLength) {
    return TTV_EC_INVALID_FORMAT;
  }
  decltype(ppsData) parsedPpsData;
  parsedPpsData.assign(current, current + ppsDataLength);
  current += ppsDataLength;

  // FrameData, laid out in the same order as the metadata entries.
  decltype(frames) parsedFrames;
  for (const auto& metaData : allFrameMetadata) {
    if (remaining() < metaData.size) {
      return TTV_EC_INVALID_FORMAT;
    }
    parsedFrames.emplace_back();
    auto& frame = parsedFrames.back();
    frame.keyFrame = metaData.keyFrame;
    frame.data.assign(current, current + metaData.size);
    current += metaData.size;
  }

  // Trailing bytes mean the header and the payload disagree.
  if (current != end) {
    return TTV_EC_INVALID_FORMAT;
  }

  // Commit only after the whole buffer has been validated. Swapping replaces any previously parsed data, so repeated
  // parses never accumulate stale frames.
  spsData.swap(parsedSpsData);
  ppsData.swap(parsedPpsData);
  frames.swap(parsedFrames);

  return TTV_EC_SUCCESS;
}

// Resets the object to its empty state by discarding all frames and both parameter-set buffers (PPS and SPS).
void ttv::broadcast::IngestSampleData::Clear() {
  frames.clear();
  ppsData.clear();
  spsData.clear();
}
