/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/broadcast/internal/pch.h"

#include "twitchsdk/broadcast/internal/rgbyuv_sse.h"

#include "twitchsdk/core/assertion.h"

#include <intrin.h>

// Immediate operands for _mm_shuffle_epi32. Each 2-bit field, starting from the least significant bits, names the
// source 32-bit element copied to that output position (x=0, y=1, z=2, w=3); the macro name spells the resulting
// element order from output position 0 to 3.
#define swizzle_xzyw 0xD8  // 11 01 10 00 -> [x, z, y, w]
#define swizzle_ywxz 0x8D  // 10 00 11 01 -> [y, w, x, z]
#define swizzle_zwxy 0x4E  // 01 00 11 10 -> [z, w, x, y]

namespace {
using namespace ttv::broadcast;

// "Mask zero": a shuffle-control byte whose high bit is set (bit pattern 0x81). _mm_shuffle_epi8 writes 0 to every
// output byte whose control byte has the high bit set, so the masks below use mz wherever the output must be cleared.
const char mz = -127;

//----------------------------------------------------------------------------------
// Computes one colour component (Y, U or V, depending on `factors`) for the four 32-bit pixels packed in `bgr`.
//
//   bgr        - four pixels, four 8-bit channels each.
//   factors    - eight signed 16-bit weights: one per channel, repeated for two pixels.
//   normalises - per-byte bias added (with 8-bit wrap-around) after the weighted sum has been rounded.
//   outputMask - _mm_shuffle_epi8 control that selects and places the result bytes in the returned vector.
//
// For every pixel the 32-bit value (sum(channel * weight) + 128) is formed; its byte 1 is therefore the sum divided
// by 256, to which the bias from `normalises` is added. The four 32-bit elements are held in the order
// [pixel0, pixel2, pixel1, pixel3], so callers undo that ordering through `outputMask`.
__m128i BGRtoYUVComponent(
  const __m128i bgr, const __m128i factors, const __m128i normalises, const __m128i& outputMask) {
  // Widen the 8-bit channels to 16 bits by interleaving with zero bytes.
  //   pixels01 => 16 [ c0 c1 c2 c3 of pixel 0, c0 c1 c2 c3 of pixel 1 ]
  //   pixels23 => 16 [ c0 c1 c2 c3 of pixel 2, c0 c1 c2 c3 of pixel 3 ]
  const __m128i zero = _mm_setzero_si128();
  const __m128i pixels01 = _mm_unpacklo_epi8(bgr, zero);
  const __m128i pixels23 = _mm_unpackhi_epi8(bgr, zero);

  // Multiply each channel by its weight and add adjacent products, leaving two partial sums per pixel.
  //   partial01 => 32 [ p0: c0*f0 + c1*f1, p0: c2*f2 + c3*f3, p1: c0*f0 + c1*f1, p1: c2*f2 + c3*f3 ]
  //   partial23 => 32 [ same layout for pixels 2 and 3 ]
  const __m128i partial01 = _mm_madd_epi16(pixels01, factors);
  const __m128i partial23 = _mm_madd_epi16(pixels23, factors);

  // Reorder to [ first(pA), first(pB), second(pA), second(pB) ] so that the two halves belonging to one pixel can
  // be lined up in the same element of two vectors.
  const __m128i grouped01 = _mm_shuffle_epi32(partial01, swizzle_xzyw);
  const __m128i grouped23 = _mm_shuffle_epi32(partial23, swizzle_xzyw);

  //   firstHalves  => 32 [ first(p0),  first(p2),  first(p1),  first(p3)  ]
  //   secondHalves => 32 [ second(p0), second(p2), second(p1), second(p3) ]
  const __m128i firstHalves = _mm_unpacklo_epi32(grouped01, grouped23);
  const __m128i secondHalves = _mm_unpackhi_epi32(grouped01, grouped23);

  // Complete the weighted sum and add 128 so that taking byte 1 (i.e. dividing by 256) rounds to nearest.
  const __m128i rounding = _mm_setr_epi32(128, 128, 128, 128);
  const __m128i weighted = _mm_add_epi32(_mm_add_epi32(firstHalves, secondHalves), rounding);

  // The bias is added per byte, so it never carries into neighbouring bytes.
  // TODO (carried over from the original): fold the rounding constant and the bias into a single addition.
  const __m128i biased = _mm_add_epi8(weighted, normalises);

  // Keep only the bytes the caller asked for, already in their final positions.
  return _mm_shuffle_epi8(biased, outputMask);
}

//---------------------------------------------------------------------------------
// Converts sixteen pixels (four vectors of four pixels each) into sixteen consecutive 8-bit Y values.
//
//   p0..p3   - consecutive groups of four 32-bit pixels.
//   YFactors - per-channel luma weights as expected by BGRtoYUVComponent.
//
// Returns a vector whose bytes 0-3 hold the values for p0, bytes 4-7 for p1, bytes 8-11 for p2 and bytes 12-15
// for p3.
__m128i CalcY(__m128i& p0, __m128i& p1, __m128i& p2, __m128i& p3, __m128i& YFactors) {
  // +16 is added to byte 1 of every 32-bit element, which is the byte that carries the component value.
  const __m128i lumaBias = _mm_setr_epi8(0, 16, 0, 0, 0, 16, 0, 0, 0, 16, 0, 0, 0, 16, 0, 0);

  // Each mask gathers the result bytes 1, 9, 5, 13 (pixel order 0, 1, 2, 3) into a different four-byte slot and
  // zeroes everything else, so the four partial vectors never overlap.
  const __m128i slot0 = _mm_setr_epi8(1, 9, 5, 13, mz, mz, mz, mz, mz, mz, mz, mz, mz, mz, mz, mz);
  const __m128i slot1 = _mm_setr_epi8(mz, mz, mz, mz, 1, 9, 5, 13, mz, mz, mz, mz, mz, mz, mz, mz);
  const __m128i slot2 = _mm_setr_epi8(mz, mz, mz, mz, mz, mz, mz, mz, 1, 9, 5, 13, mz, mz, mz, mz);
  const __m128i slot3 = _mm_setr_epi8(mz, mz, mz, mz, mz, mz, mz, mz, mz, mz, mz, mz, 1, 9, 5, 13);

  const __m128i y0 = BGRtoYUVComponent(p0, YFactors, lumaBias, slot0);
  const __m128i y1 = BGRtoYUVComponent(p1, YFactors, lumaBias, slot1);
  const __m128i y2 = BGRtoYUVComponent(p2, YFactors, lumaBias, slot2);
  const __m128i y3 = BGRtoYUVComponent(p3, YFactors, lumaBias, slot3);

  // Because the non-zero bytes are disjoint, adding the vectors simply merges them.
  const __m128i lowerEight = _mm_add_epi64(y0, y1);
  const __m128i upperEight = _mm_add_epi64(y2, y3);
  return _mm_add_epi64(lowerEight, upperEight);
}

//---------------------------------------------------------------------------------
// Fills the Y plane for a whole image.
//
//   rgbBuffer    - tightly packed 32-bit pixels, `width * height` of them; read with aligned 128-bit loads.
//   width/height - image size in pixels. `width` must be a multiple of 16 (see the assertion below).
//   YBuffer      - destination, one byte per pixel; written with aligned 128-bit stores.
//   YFactors     - per-channel luma weights as expected by CalcY.
//   verticalFlip - when true, source row 0 is written to the last destination row, row 1 to the one before, etc.
void DoY(const uint8_t* rgbBuffer, uint width, uint height, uint8_t* YBuffer, __m128i& YFactors, bool verticalFlip) {
  // Every inner-loop iteration consumes four 128-bit loads (16 pixels) and produces one aligned 16-byte store.
  // A width that is only a multiple of 4 would make the last iteration of each row run into the next row and would
  // leave destination rows misaligned, so the real requirement is a multiple of 16.
  TTV_ASSERT(width % 16 == 0);

  // Nothing to convert; returning here also keeps `height - 1` below from wrapping around.
  if (width == 0 || height == 0) {
    return;
  }

  const uint widthIn128 = width / 4;  // 4 32-bit pixels fit in a m128

  const __m128i* input = reinterpret_cast<const __m128i*>(rgbBuffer);
  __m128i* output = reinterpret_cast<__m128i*>(YBuffer);

  int nextRowOffset = 0;
  if (verticalFlip) {
    // Start at the beginning of the last row of Y values
    output = reinterpret_cast<__m128i*>(YBuffer + width * (height - 1));

    // After a row has been written the pointer sits at the start of the following row, so step back over the row
    // just written AND the one before it. A row of Y values occupies widthIn128 / 4 vectors. The offset is computed
    // in signed arithmetic so that it does not depend on unsigned wrap-around.
    nextRowOffset = -2 * static_cast<int>(widthIn128 / 4);
  }

  for (uint row = 0; row < height; ++row) {
    for (uint col = 0; col < widthIn128; col += 4) {
      __m128i curr0 = _mm_load_si128(input++);
      __m128i curr1 = _mm_load_si128(input++);
      __m128i curr2 = _mm_load_si128(input++);
      __m128i curr3 = _mm_load_si128(input++);

      __m128i result = CalcY(curr0, curr1, curr2, curr3, YFactors);
      _mm_store_si128(output++, result);
    }

    output += nextRowOffset;
  }
}

//---------------------------------------------------------------------------------
// Produces eight chroma samples (U or V, depending on `factors`) for a block of 16x2 pixels. Each sample is computed
// from a 2x2 group of pixels: the four per-pixel values are summed and the sum is shifted right by two.
//
//   top0..top3       - sixteen consecutive pixels of the upper row.
//   bottom0..bottom3 - the sixteen pixels directly below them.
//   factors          - per-channel chroma weights as expected by BGRtoYUVComponent.
//
// Returns eight 16-bit lanes, one sample per lane, ordered left to right.
__m128i CalcUorV(__m128i& top0, __m128i& top1, __m128i& top2, __m128i& top3, __m128i& bottom0, __m128i& bottom1,
  __m128i& bottom2, __m128i& bottom3, __m128i& factors) {
  // Bias added to the value byte of every pixel. -127 is the bit pattern 0x81, i.e. 129.
  // NOTE(review): the usual chroma offset is 128; confirm whether 129 is intentional (for instance to compensate
  // for the truncating shift at the end) before changing it.
  const __m128i chromaBias = _mm_setr_epi8(0, -127, 0, 0, 0, -127, 0, 0, 0, -127, 0, 0, 0, -127, 0, 0);

  // Both masks widen the four result bytes to 16-bit lanes. Even-column pixels (0 and 2 of the group) go to the low
  // 64 bits and odd-column pixels (1 and 3) to the high 64 bits. `firstQuad` uses lanes 0-1 / 4-5, `secondQuad`
  // uses lanes 2-3 / 6-7, so the results for two neighbouring groups can be merged by addition.
  const __m128i firstQuad = _mm_setr_epi8(1, mz, 5, mz, mz, mz, mz, mz, 9, mz, 13, mz, mz, mz, mz, mz);
  const __m128i secondQuad = _mm_setr_epi8(mz, mz, mz, mz, 1, mz, 5, mz, mz, mz, mz, mz, 9, mz, 13, mz);

  const __m128i moveHighToLow = _mm_setr_epi8(8, 9, 10, 11, 12, 13, 14, 15, mz, mz, mz, mz, mz, mz, mz, mz);
  const __m128i moveLowToHigh = _mm_setr_epi8(mz, mz, mz, mz, mz, mz, mz, mz, 0, 1, 2, 3, 4, 5, 6, 7);
  const __m128i keepLowOnly = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, mz, mz, mz, mz, mz, mz, mz, mz);

  // Per-pixel values for eight consecutive pixels of one row, as 16-bit lanes:
  //   [ even columns: c0, c2, c4, c6 | odd columns: c1, c3, c5, c7 ]
  const auto eightPixels = [&](const __m128i& quadA, const __m128i& quadB) {
    return _mm_add_epi64(BGRtoYUVComponent(quadA, factors, chromaBias, firstQuad),
      BGRtoYUVComponent(quadB, factors, chromaBias, secondQuad));
  };

  const __m128i upperLeft = eightPixels(top0, top1);         // columns 0..7 of the upper row
  const __m128i upperRight = eightPixels(top2, top3);        // columns 8..15 of the upper row
  const __m128i lowerLeft = eightPixels(bottom0, bottom1);   // columns 0..7 of the lower row
  const __m128i lowerRight = eightPixels(bottom2, bottom3);  // columns 8..15 of the lower row

  // Vertical pairs: C_i = upper value + lower value of column i.
  const __m128i colsLeft = _mm_add_epi16(upperLeft, lowerLeft);     // C0, C2, C4, C6, C1, C3, C5, C7
  const __m128i colsRight = _mm_add_epi16(upperRight, lowerRight);  // C8, C10, C12, C14, C9, C11, C13, C15

  // Horizontal pairs: bring the odd columns down next to the even ones and add.
  const __m128i oddLeft = _mm_shuffle_epi8(colsLeft, moveHighToLow);    // C1, C3, C5, C7, 0, 0, 0, 0
  const __m128i oddRight = _mm_shuffle_epi8(colsRight, moveHighToLow);  // C9, C11, C13, C15, 0, 0, 0, 0

  // C0+C1, C2+C3, C4+C5, C6+C7, 0, 0, 0, 0
  const __m128i quadsLeft = _mm_add_epi16(_mm_shuffle_epi8(colsLeft, keepLowOnly), oddLeft);

  // C8+C9, C10+C11, C12+C13, C14+C15, followed by four leftover lanes that the shuffle below discards.
  const __m128i quadsRight = _mm_add_epi16(colsRight, oddRight);

  // 0, 0, 0, 0, C8+C9, C10+C11, C12+C13, C14+C15
  const __m128i quadsRightHigh = _mm_shuffle_epi8(quadsRight, moveLowToHigh);

  // All eight 2x2 sums, left to right.
  const __m128i quadSums = _mm_add_epi16(quadsLeft, quadsRightHigh);

  // Divide by the four contributing pixels (truncating).
  return _mm_srli_epi16(quadSums, 2);
}

//---------------------------------------------------------------------------------
void DoUV_Packed(const uint8_t* rgbBuffer, uint width, uint height, uint8_t* UVBuffer, __m128i& UFactors,
  __m128i& VFactors, bool verticalFlip) {
  const uint widthIn128 = width / 4;  // 4 32-bit BGRA pixels fit in a m128

  const __m128i* inputTop = reinterpret_cast<const __m128i*>(rgbBuffer);
  const __m128i* inputBot = inputTop + width / 4;

  __m128i* outputUV = reinterpret_cast<__m128i*>(UVBuffer);

  int nextRowOffset = 0;
  if (verticalFlip) {
    // Start at the beginning of the last row of UV values
    outputUV = reinterpret_cast<__m128i*>(UVBuffer + width * (height / 2 - 1));

    // We've processed one row of UV values, jump to the beginning of the previous row
    // (that's why we have x2 to jump back over the current row AND the previous row)
    nextRowOffset = -1 * (widthIn128 / 4) * 2;
  }

  const __m128i mask16to8ShiftRightBy1Byte = _mm_setr_epi8(mz, 0, mz, 2, mz, 4, mz, 6, mz, 8, mz, 10, mz, 12, mz, 14);

  for (uint row = 0; row < height; row += 2) {
    for (uint col = 0; col < widthIn128; col += 4) {
      __m128i top0 = _mm_load_si128(inputTop++);
      __m128i bottom0 = _mm_load_si128(inputBot++);
      __m128i top1 = _mm_load_si128(inputTop++);
      __m128i bottom1 = _mm_load_si128(inputBot++);
      __m128i top2 = _mm_load_si128(inputTop++);
      __m128i bottom2 = _mm_load_si128(inputBot++);
      __m128i top3 = _mm_load_si128(inputTop++);
      __m128i bottom3 = _mm_load_si128(inputBot++);

      __m128i resultU = CalcUorV(top0, top1, top2, top3, bottom0, bottom1, bottom2, bottom3, UFactors);
      __m128i resultV = CalcUorV(top0, top1, top2, top3, bottom0, bottom1, bottom2, bottom3, VFactors);

      __m128i resultVShifted = _mm_shuffle_epi8(resultV, mask16to8ShiftRightBy1Byte);
      __m128i resultUV = _mm_add_epi64(resultU, resultVShifted);

      _mm_store_si128(outputUV++, resultUV);
    }

    inputTop += widthIn128;
    inputBot += widthIn128;

    outputUV += nextRowOffset;
  }
}

//---------------------------------------------------------------------------------
// Fills the chroma part of the output for the requested YUV layout.
//
//   rgbBuffer         - tightly packed 32-bit pixels, `width * height` of them; read with aligned 128-bit loads.
//   width/height      - image size in pixels; rows are consumed in pairs.
//   UVBuffer          - destination for the chroma data.
//   UFactors/VFactors - per-channel chroma weights as expected by CalcUorV.
//   yuvFormat         - TTV_YUV_NV12 is forwarded to DoUV_Packed. TTV_YUV_I420 writes a U plane followed by a V
//                       plane; TTV_YUV_YV12 writes the same two planes in the opposite order. Any other value
//                       triggers an assertion.
//   verticalFlip      - when true, the first pair of source rows is written to the last row of each plane, etc.
void DoUV(const uint8_t* rgbBuffer, uint width, uint height, uint8_t* UVBuffer, __m128i& UFactors, __m128i& VFactors,
  YUVFormat yuvFormat, bool verticalFlip) {
  if (yuvFormat == YUVFormat::TTV_YUV_NV12) {
    DoUV_Packed(rgbBuffer, width, height, UVBuffer, UFactors, VFactors, verticalFlip);
  } else if (yuvFormat == YUVFormat::TTV_YUV_I420 || yuvFormat == YUVFormat::TTV_YUV_YV12) {
    // Every inner-loop iteration consumes eight 128-bit loads per row (32 pixels) and produces one aligned 16-byte
    // store per plane, so a row must be a whole number of 32-pixel groups. Otherwise the last iteration of each row
    // reads into the next row and the planes are overrun.
    TTV_ASSERT(width % 32 == 0);

    // Fewer than two rows means there is no complete row pair; returning here also keeps `height / 2 - 1` below
    // from wrapping around.
    if (width == 0 || height < 2) {
      return;
    }

    const uint widthIn128 = width / 4;  // 4 32-bit BGRA pixels fit in a m128

    const __m128i* inputTop = reinterpret_cast<const __m128i*>(rgbBuffer);
    const __m128i* inputBot = inputTop + widthIn128;

    // The first plane holds (width / 2) * (height / 2) bytes; the second plane follows it directly.
    uint8_t* uBuffer = UVBuffer;
    uint8_t* vBuffer = UVBuffer + width * height / 4;
    __m128i* outputU = reinterpret_cast<__m128i*>(uBuffer);
    __m128i* outputV = reinterpret_cast<__m128i*>(vBuffer);

    int nextRowOffset = 0;
    if (verticalFlip) {
      // Start at the beginning of the last row of U and V values
      outputU = reinterpret_cast<__m128i*>(uBuffer + (width / 2) * (height / 2 - 1));
      outputV = reinterpret_cast<__m128i*>(vBuffer + (width / 2) * (height / 2 - 1));

      // After a row has been written the pointer sits at the start of the following row, so step back over the row
      // just written AND the one before it. A row of one plane occupies widthIn128 / 8 vectors. The offset is
      // computed in signed arithmetic so that it does not depend on unsigned wrap-around.
      nextRowOffset = -2 * static_cast<int>(widthIn128 / 8);
    }

    if (yuvFormat == YUVFormat::TTV_YUV_YV12) {
      // Reverse the placement of U and V buffers
      __m128i* tmp = outputU;
      outputU = outputV;
      outputV = tmp;
    }

    // Narrow eight 16-bit lanes to eight bytes, placed in the low or the high half of the vector respectively.
    const __m128i mask16to8Lo = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, mz, mz, mz, mz, mz, mz, mz, mz);
    const __m128i mask16to8Hi = _mm_setr_epi8(mz, mz, mz, mz, mz, mz, mz, mz, 0, 2, 4, 6, 8, 10, 12, 14);

    // Only complete row pairs are processed: with an odd height the last row has no partner, and reading one would
    // run past the end of the input.
    for (uint row = 0; row + 1 < height; row += 2) {
      for (uint col = 0; col < widthIn128; col += 8) {
        // First 16 columns -> samples 0..7 of this store.
        __m128i top0 = _mm_load_si128(inputTop++);
        __m128i bottom0 = _mm_load_si128(inputBot++);
        __m128i top1 = _mm_load_si128(inputTop++);
        __m128i bottom1 = _mm_load_si128(inputBot++);
        __m128i top2 = _mm_load_si128(inputTop++);
        __m128i bottom2 = _mm_load_si128(inputBot++);
        __m128i top3 = _mm_load_si128(inputTop++);
        __m128i bottom3 = _mm_load_si128(inputBot++);

        __m128i resultU = CalcUorV(top0, top1, top2, top3, bottom0, bottom1, bottom2, bottom3, UFactors);
        __m128i resultV = CalcUorV(top0, top1, top2, top3, bottom0, bottom1, bottom2, bottom3, VFactors);

        // U samples 0..7 in bytes 0..7, bytes 8..15 zero (8-bit)
        __m128i resultULo = _mm_shuffle_epi8(resultU, mask16to8Lo);
        // V samples 0..7 in bytes 0..7, bytes 8..15 zero (8-bit)
        __m128i resultVLo = _mm_shuffle_epi8(resultV, mask16to8Lo);

        // Next 16 columns -> samples 8..15 of this store.
        top0 = _mm_load_si128(inputTop++);
        bottom0 = _mm_load_si128(inputBot++);
        top1 = _mm_load_si128(inputTop++);
        bottom1 = _mm_load_si128(inputBot++);
        top2 = _mm_load_si128(inputTop++);
        bottom2 = _mm_load_si128(inputBot++);
        top3 = _mm_load_si128(inputTop++);
        bottom3 = _mm_load_si128(inputBot++);

        resultU = CalcUorV(top0, top1, top2, top3, bottom0, bottom1, bottom2, bottom3, UFactors);
        resultV = CalcUorV(top0, top1, top2, top3, bottom0, bottom1, bottom2, bottom3, VFactors);

        // bytes 0..7 zero, U samples 8..15 in bytes 8..15 (8-bit)
        __m128i resultUHi = _mm_shuffle_epi8(resultU, mask16to8Hi);
        // bytes 0..7 zero, V samples 8..15 in bytes 8..15 (8-bit)
        __m128i resultVHi = _mm_shuffle_epi8(resultV, mask16to8Hi);

        // The halves occupy disjoint bytes, so adding them merges the sixteen samples.
        _mm_store_si128(outputU++, _mm_add_epi64(resultULo, resultUHi));
        _mm_store_si128(outputV++, _mm_add_epi64(resultVLo, resultVHi));
      }

      // Each pointer has walked over its own row; skip the row the other pointer has just consumed.
      inputTop += widthIn128;
      inputBot += widthIn128;

      outputU += nextRowOffset;
      outputV += nextRowOffset;
    }
  } else {
    TTV_ASSERT(false);
  }
}
}  // namespace

//---------------------------------------------------------------------------------
// Converts an image of 32-bit pixels to YUV using SSE intrinsics.
//
//   rgbBuffer    - source pixels, passed on unchanged to DoY and DoUV.
//   bgraMask     - describes the channel order of the source. The mask is based off of BGRA (0x00010203), so for
//                  example RGBA would be 0x02010003.
//   width/height - image size in pixels.
//   YBuffer      - destination for the Y plane.
//   UVBuffer     - destination for the chroma data.
//   yuvFormat    - chroma layout, interpreted by DoUV.
//   verticalFlip - forwarded to DoY and DoUV.
void ttv::broadcast::RGBtoYUV_SSE(const uint8_t* rgbBuffer, uint32_t bgraMask, uint width, uint height,
  uint8_t* YBuffer, uint8_t* UVBuffer, YUVFormat yuvFormat, bool verticalFlip) {
  // Weights in B, G, R, A order, repeated so that one vector covers two pixels.
  const __m128i weightsY = _mm_setr_epi16(25, 129, 66, 0, 25, 129, 66, 0);
  const __m128i weightsU = _mm_setr_epi16(112, -74, -38, 0, 112, -74, -38, 0);
  const __m128i weightsV = _mm_setr_epi16(-18, -94, 112, 0, -18, -94, 112, 0);

  // The mask is a little-endian 32-bit integer, so its four bytes are stored in reverse order: byte 3 describes the
  // first channel in memory and byte 0 the last. Each byte names the 16-bit weight that belongs at that position;
  // doubling it (and adding one) yields the two byte indices of that weight for the byte shuffle below.
  const uint8_t* maskBytes = reinterpret_cast<const uint8_t*>(&bgraMask);

  const char first0 = static_cast<char>(maskBytes[3] * 2);
  const char first1 = static_cast<char>(maskBytes[3] * 2 + 1);
  const char second0 = static_cast<char>(maskBytes[2] * 2);
  const char second1 = static_cast<char>(maskBytes[2] * 2 + 1);
  const char third0 = static_cast<char>(maskBytes[1] * 2);
  const char third1 = static_cast<char>(maskBytes[1] * 2 + 1);
  const char fourth0 = static_cast<char>(maskBytes[0] * 2);
  const char fourth1 = static_cast<char>(maskBytes[0] * 2 + 1);

  // The same four-channel pattern is used for both pixels covered by a weight vector.
  const __m128i channelOrder = _mm_setr_epi8(first0, first1, second0, second1, third0, third1, fourth0, fourth1,
    first0, first1, second0, second1, third0, third1, fourth0, fourth1);

  // Rearrange the weights to match the channel order of the source pixels.
  __m128i YFactors = _mm_shuffle_epi8(weightsY, channelOrder);
  __m128i UFactors = _mm_shuffle_epi8(weightsU, channelOrder);
  __m128i VFactors = _mm_shuffle_epi8(weightsV, channelOrder);

  DoY(rgbBuffer, width, height, YBuffer, YFactors, verticalFlip);
  DoUV(rgbBuffer, width, height, UVBuffer, UFactors, VFactors, yuvFormat, verticalFlip);
}
