/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/cfhttprequest.h"

#include "twitchsdk/core/httprequestutils.h"

#include <CFNetwork/CFNetwork.h>
#include <zlib.h>

#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

namespace {
const CFAbsoluteTime kMaximumRequestTime = 600.;
const CFAbsoluteTime kDefaultDataTimeout = 10.;
const size_t kReadStreamBufferSize = 4096;
const size_t kDecompressionBufferSize = 65536;
const CFOptionFlags kObservedCFStreamEvents =
  kCFStreamEventHasBytesAvailable | kCFStreamEventEndEncountered | kCFStreamEventErrorOccurred;

typedef struct {
  UInt8* buffer;
  CFErrorRef error;
  Boolean finished;
  Boolean succeeded;
  std::vector<char> responseBuffer;
} readStreamCallbackInfo_t;

void ReadStreamCallback(CFReadStreamRef readStream, CFStreamEventType type, void* readStreamCallbackInfo) {
  CFIndex bytesRead = 0;
  readStreamCallbackInfo_t* callbackInfo = reinterpret_cast<readStreamCallbackInfo_t*>(readStreamCallbackInfo);

  switch (type) {
    case kCFStreamEventHasBytesAvailable:
      bytesRead = CFReadStreamRead(readStream, callbackInfo->buffer, kReadStreamBufferSize);
      if (bytesRead > 0) {
        callbackInfo->responseBuffer.insert(
          callbackInfo->responseBuffer.cend(), callbackInfo->buffer, callbackInfo->buffer + bytesRead);
      } else if (bytesRead < 0) {
        callbackInfo->error = CFReadStreamCopyError(readStream);
        callbackInfo->finished = true;
      } else {
        // Read 0 bytes, assume end of stream reached
        callbackInfo->finished = true;
        callbackInfo->succeeded = true;
      }
      break;
    case kCFStreamEventErrorOccurred:
      callbackInfo->error = CFReadStreamCopyError(readStream);
      callbackInfo->finished = true;
      break;
    case kCFStreamEventEndEncountered:
      callbackInfo->finished = true;
      callbackInfo->succeeded = true;
      break;
    default:
      // Unexpected event
      callbackInfo->finished = true;
      break;
  }
}

bool ConvertToUtf8String(CFStringRef cfString, std::string& result) {
  const char* raw = CFStringGetCStringPtr(cfString, kCFStringEncodingUTF8);
  if (raw) {
    result = raw;
  } else {
    CFIndex length = CFStringGetLength(cfString);
    CFIndex convertedSize = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8);

    std::vector<char> buffer(static_cast<size_t>(convertedSize));
    if (!CFStringGetCString(cfString, buffer.data(), convertedSize, kCFStringEncodingUTF8)) {
      return false;
    }

    result = std::string(buffer.data(), 0, static_cast<size_t>(convertedSize));
  }

  return true;
}

TTV_ErrorCode DecompressGzipResponse(std::vector<char>& responseBuffer, std::vector<char>& decompressedBuffer) {
  z_stream zStream;

  // Initialize zStream for inflation
  zStream.zalloc = Z_NULL;
  zStream.zfree = Z_NULL;
  zStream.opaque = Z_NULL;
  int status = inflateInit2(&zStream,
    (15 + 32));  // 15 for base window size, plus 32 to allow for gzip decoding with automatic header detection (see
                 // http://www.zlib.net/manual.html )
  if (status != Z_OK) {
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  std::vector<unsigned char> decompressionBuffer(kDecompressionBufferSize);
  zStream.next_in = reinterpret_cast<unsigned char*>(responseBuffer.data());
  zStream.avail_in = static_cast<unsigned int>(responseBuffer.size());
  unsigned long lastTotalOut = zStream.total_out;

  do {
    zStream.avail_out = static_cast<unsigned int>(kDecompressionBufferSize);
    zStream.next_out = decompressionBuffer.data();

    status = inflate(&zStream, Z_NO_FLUSH);

    if (status == Z_OK || status == Z_STREAM_END) {
      decompressedBuffer.insert(decompressedBuffer.cend(), decompressionBuffer.begin(),
        decompressionBuffer.begin() + static_cast<long>(zStream.total_out - lastTotalOut));
      lastTotalOut = zStream.total_out;
    }
  } while (status == Z_OK);

  inflateEnd(&zStream);  // Call to avoid a memory leak
  return status == Z_STREAM_END ? TTV_EC_SUCCESS : TTV_EC_HTTPREQUEST_ERROR;
}
}  // namespace

//--------------------------------------------------------------------
TTV_ErrorCode ttv::CfHttpRequest::SendHttpRequest(const std::string& /*requestName*/, const std::string& url,
  const std::vector<HttpParam>& requestHeaders, const uint8_t* requestBody, size_t requestBodySize,
  HttpRequestType httpReqType, uint timeOutInSecs, HttpRequestHeadersCallback headersCallback,
  HttpRequestCallback responseCallback, void* userData) {
  // Create a CFString for the URL
  CFStringRef urlString = CFStringCreateWithCString(kCFAllocatorDefault, url.c_str(), kCFStringEncodingUTF8);

  // Finalize parameters based on request type
  CFStringRef requestMethod = nullptr;

  switch (httpReqType) {
    case HTTP_GET_REQUEST:
      requestMethod = CFSTR("GET");
      break;
    case HTTP_PUT_REQUEST:
      requestMethod = CFSTR("PUT");
      break;
    case HTTP_POST_REQUEST:
      requestMethod = CFSTR("POST");
      break;
    case HTTP_DELETE_REQUEST:
      requestMethod = CFSTR("DELETE");
      break;
    default:
      TTV_ASSERT(false);
  }

  // Create the request URL
  CFURLRef requestUrl = CFURLCreateWithString(kCFAllocatorDefault, urlString, nullptr);
  CFRelease(urlString);

  if (requestUrl == nullptr) {
    ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Invalid URL: %s", url.c_str());
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  // Create the request
  CFHTTPMessageRef request =
    CFHTTPMessageCreateRequest(kCFAllocatorDefault, requestMethod, requestUrl, kCFHTTPVersion1_1);
  CFRelease(requestUrl);
  if (request == nullptr) {
    ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Unable to create request");
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  // Set the body
  CFDataRef bodyData = CFDataCreateWithBytesNoCopy(kCFAllocatorDefault, reinterpret_cast<const UInt8*>(requestBody),
    static_cast<CFIndex>(requestBodySize), kCFAllocatorNull);
  CFHTTPMessageSetBody(request, bodyData);
  CFIndex contentLength = CFDataGetLength(bodyData);

  if (bodyData) {
    CFRelease(bodyData);
  }

  // Set up the header
  CFStringRef headerName;
  CFStringRef headerValue;
  for (auto it = requestHeaders.begin(); it != requestHeaders.end(); ++it) {
    headerName = CFStringCreateWithCString(kCFAllocatorDefault, it->paramName.c_str(), kCFStringEncodingUTF8);
    headerValue = CFStringCreateWithCString(kCFAllocatorDefault, it->paramValue.c_str(), kCFStringEncodingUTF8);
    CFHTTPMessageSetHeaderFieldValue(request, headerName, headerValue);
    CFRelease(headerName);
    CFRelease(headerValue);
  }
  if (httpReqType == HTTP_PUT_REQUEST || httpReqType == HTTP_POST_REQUEST || httpReqType == HTTP_DELETE_REQUEST) {
    if (contentLength) {
      headerValue = CFStringCreateWithFormat(kCFAllocatorDefault, nullptr, CFSTR("%lu"), contentLength);
      CFHTTPMessageSetHeaderFieldValue(request, CFSTR("Content-Length"), headerValue);
      CFRelease(headerValue);
    }
  }

  // Create a read stream for the request
  CFReadStreamRef readStream = CFReadStreamCreateForHTTPRequest(kCFAllocatorDefault, request);
  CFRelease(request);
  if (readStream == nullptr) {
    ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Unable to create read stream for request");
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  // Allocated read stream buffer
  readStreamCallbackInfo_t callbackInfo;
  callbackInfo.buffer = new UInt8[kReadStreamBufferSize];
  callbackInfo.finished = false;
  callbackInfo.succeeded = false;
  callbackInfo.error = nullptr;
  CFStreamClientContext streamContext = {0, &callbackInfo, nullptr, nullptr, nullptr};

  if (!CFReadStreamSetClient(readStream, kObservedCFStreamEvents, ReadStreamCallback, &streamContext)) {
    ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Unable to schedule read stream asynchronously");
    CFRelease(readStream);
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  // Schedule read stream in this thread's runloop
  CFReadStreamScheduleWithRunLoop(readStream, CFRunLoopGetCurrent(), kCFRunLoopDefaultMode);

  // Open the stream
  if (!CFReadStreamOpen(readStream)) {
    ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Unable to open read stream");
    CFRelease(readStream);
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  // Set timeouts
  CFAbsoluteTime requestStartTime = CFAbsoluteTimeGetCurrent();
  CFAbsoluteTime lastDataReceivedTime = requestStartTime;
  CFAbsoluteTime maximumRequestTime = requestStartTime + kMaximumRequestTime;
  CFAbsoluteTime requestDataTimeout =
    timeOutInSecs > 0 ? static_cast<CFAbsoluteTime>(timeOutInSecs) : kDefaultDataTimeout;
  unsigned long lastDataSize = 0;

  while (!callbackInfo.finished) {
    // Run for 250ms or until the source is handled
    SInt32 result = CFRunLoopRunInMode(kCFRunLoopDefaultMode, 0.25, true);

    CFAbsoluteTime currentTime = CFAbsoluteTimeGetCurrent();
    if (callbackInfo.responseBuffer.size() > lastDataSize) {
      lastDataSize = callbackInfo.responseBuffer.size();
      lastDataReceivedTime = currentTime;
    }

    if (result == kCFRunLoopRunStopped || result == kCFRunLoopRunFinished || currentTime >= maximumRequestTime ||
        currentTime - lastDataReceivedTime >= requestDataTimeout) {
      break;
    }
  }

  CFRunLoopStop(CFRunLoopGetCurrent());
  CFReadStreamSetClient(readStream, 0, nullptr, nullptr);
  CFReadStreamUnscheduleFromRunLoop(readStream, CFRunLoopGetCurrent(), kCFRunLoopDefaultMode);
  CFReadStreamClose(readStream);
  delete[] callbackInfo.buffer;
  callbackInfo.buffer = nullptr;

  // Handle any errors
  if (!callbackInfo.succeeded) {
    if (callbackInfo.error != nullptr) {
      // Handle errors propagated by the read stream
      CFStringRef errorDescription = CFErrorCopyDescription(callbackInfo.error);
      CFRelease(callbackInfo.error);
      ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Encountered read error: %s",
        CFStringGetCStringPtr(errorDescription, kCFStringEncodingUTF8));
      CFRelease(errorDescription);
    } else {
      // Handle other errors
      ttv::trace::Message("CFHttpRequest", MessageLevel::Error, "Encountered unspecified error during request");
    }
    CFRelease(readStream);
    return TTV_EC_HTTPREQUEST_ERROR;
  }

  CFHTTPMessageRef response =
    CFHTTPMessageRef(CFReadStreamCopyProperty(readStream, kCFStreamPropertyHTTPResponseHeader));

  CFIndex responseCode = CFHTTPMessageGetResponseStatusCode(response);
  CFRelease(readStream);

  // Notify of headers
  bool notifyResponse = true;
  if (headersCallback != nullptr) {
    std::map<std::string, std::string> responseHeaders;

    CFDictionaryRef headerDict = CFHTTPMessageCopyAllHeaderFields(response);
    if (headerDict != nullptr) {
      CFIndex count = CFDictionaryGetCount(headerDict);
      std::vector<CFStringRef> keys(static_cast<size_t>(count), nullptr);
      std::vector<CFStringRef> values(static_cast<size_t>(count), nullptr);
      CFDictionaryGetKeysAndValues(
        headerDict, reinterpret_cast<const void**>(keys.data()), reinterpret_cast<const void**>(values.data()));

      std::string name, value;
      for (size_t i = 0; i < static_cast<size_t>(count); ++i) {
        if (ConvertToUtf8String(keys[i], name) && ConvertToUtf8String(values[i], value)) {
          responseHeaders[name] = value;
        } else {
          ttv::trace::Message("CFHttpRequest", MessageLevel::Error,
            "Failed to convert header value to string - it will be missing from the header callback");
        }
      }
    }

    notifyResponse = headersCallback(static_cast<unsigned int>(responseCode), responseHeaders, userData);

    if (headerDict != nullptr) {
      CFRelease(headerDict);
    }
  }

  // Notify of response
  if (notifyResponse) {
    // Check for a GZIP'ed response
    CFStringRef contentEncoding = CFHTTPMessageCopyHeaderFieldValue(response, CFSTR("Content-Encoding"));
    if (contentEncoding != nullptr &&
        CFStringFind(contentEncoding, CFSTR("gzip"), kCFCompareCaseInsensitive).location != kCFNotFound) {
      std::vector<char> decompressedBuffer;
      if (TTV_SUCCEEDED(DecompressGzipResponse(callbackInfo.responseBuffer, decompressedBuffer))) {
        // Perform the callback with decompressed data
        responseCallback(static_cast<unsigned int>(responseCode), decompressedBuffer, userData);
      } else {
        // Error during decompression, clean up and return an error
        CFRelease(contentEncoding);
        CFRelease(response);
        return TTV_EC_HTTPREQUEST_ERROR;
      }
    } else {
      // Perform the callback
      responseCallback(static_cast<unsigned int>(responseCode), callbackInfo.responseBuffer, userData);
    }

    if (contentEncoding != nullptr)
      CFRelease(contentEncoding);
  }

  CFRelease(response);

  return TTV_EC_SUCCESS;
}
