/****************************************************************************
 * Twitch SDK
 *
 * This software is supplied under the terms of a license agreement with
 * Twitch Interactive, Inc. and may not be copied or used except in accordance
 * with the terms of that agreement
 *
 * Copyright (c) 2012-2016 Twitch Interactive, Inc.
 ***************************************************************************/

#include "twitchsdk/core/internal/pch.h"

#include "twitchsdk/core/standardopensslsocket.h"

#include "twitchsdk/core/assertion.h"
#include "twitchsdk/core/httprequestutils.h"
#include "twitchsdk/core/socket.h"
#include "twitchsdk/core/stringutilities.h"
#include "twitchsdk/core/systemclock.h"
#include "twitchsdk/core/thread.h"

#include <openssl/bio.h>
#include <openssl/buffer.h>
#include <openssl/conf.h>
#include <openssl/err.h>
#include <openssl/opensslconf.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>

#include <algorithm>
#include <atomic>
#include <climits>

// TODO: Silly Windows, fix the includes
#ifdef SetPort
#undef SetPort
#endif
#ifdef min
#undef min
#endif

namespace {
uint64_t kConnectionTimeoutMilliseconds = 10000000;
const size_t kScratchBufferSize = 1024 * 4;

std::atomic<int> gOpenSslInitializations(0);

void LogError() {
  unsigned long ec = ERR_get_error();
  const char* str = ERR_reason_error_string(ec);
  ttv::trace::Message("Socket", ttv::MessageLevel::Error, "%s", str);
}

// TODO: Abstract TTV_ErrorCode LoadSystemTrustStore(SSL_CTX* context) in platform-independent way
#if WIN32

#include <Windows.h>
#pragma comment(lib, "crypt32.lib")
#include <wincrypt.h>

TTV_ErrorCode LoadSystemTrustStore(SSL_CTX* context) {
  // return TTV_EC_SUCCESS;

  ttv::trace::Message("Socket", ttv::MessageLevel::Debug,
    "OpenSslSocket: Loading Windows trust store certs into the OpenSSL trust store");

  // http://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store

  // Get the system cert store
  HCERTSTORE windowsTrustStore = CertOpenSystemStoreA(NULL, "ROOT");
  if (windowsTrustStore == nullptr) {
    ttv::trace::Message("Socket", ttv::MessageLevel::Error, "OpenSslSocket: Could not access the Windows trust store");
    return TTV_EC_API_REQUEST_FAILED;
  }

  int numCertsLoaded = 0;

  X509_STORE* openSslTrustStore = SSL_CTX_get_cert_store(context);
  if (openSslTrustStore != nullptr) {
    // Load each cert in the store
    PCCERT_CONTEXT data = nullptr;
    for (;;) {
      data = CertEnumCertificatesInStore(windowsTrustStore, data);
      if (data == nullptr) {
        break;
      }

      const unsigned char* cert = reinterpret_cast<const unsigned char*>(data->pbCertEncoded);
      X509* x509 = d2i_X509(nullptr, &cert, data->cbCertEncoded);

      if (x509 != nullptr) {
        int ret = X509_STORE_add_cert(openSslTrustStore, x509);
        if (ret == 1) {
          numCertsLoaded++;
        } else {
          ttv::trace::Message(
            "Socket", ttv::MessageLevel::Error, "OpenSslSocket: Error adding cert to the OpenSSL trust store");
        }
      } else {
        ttv::trace::Message(
          "Socket", ttv::MessageLevel::Error, "OpenSslSocket: Failed to convert Windows encoded cert to X509");
      }
    }
  } else {
    ttv::trace::Message("Socket", ttv::MessageLevel::Error, "OpenSslSocket: OpenSSL trust store is null");
    return TTV_EC_API_REQUEST_FAILED;
  }

  ttv::trace::Message("Socket", ttv::MessageLevel::Debug,
    "OpenSslSocket: Loaded %d Windows certs into the OpenSSL trust store", numCertsLoaded);

  // Close the cert store
  BOOL ret = CertCloseStore(windowsTrustStore, 0);
  if (ret == TRUE) {
    return TTV_EC_SUCCESS;
  } else {
    ttv::trace::Message("Socket", ttv::MessageLevel::Error, "OpenSslSocket: Error closing Windows trust store");
    return TTV_EC_API_REQUEST_FAILED;
  }
}

#else

TTV_ErrorCode LoadSystemTrustStore(SSL_CTX* context) {
  return TTV_EC_UNIMPLEMENTED;
}

#endif

// After the automatic validation is done on a certification this is called
// to notify the app and give it a chance to change the result.  Don't change it,
// but it's nice to have it for informational purposes.
int FilterVerifyCertificate(int ok, X509_STORE_CTX* /*store*/) {
  return ok;
}

}  // namespace

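/*
  Data flow through the SSL object. Both BIOs hold encrypted bytes:
  OutputBio (SSL's write BIO) collects records to send, and InputBio (SSL's read BIO)
  collects records received from the raw socket until SSL decrypts them.

                                    +-----------+
                                    |  encrypt  |
              <app write bytes> ==> | ========> | ==> OutputBio ==> <Send over raw socket>
                                    |           |
                                    |    SSL    |
                                    |           |
               <app read bytes> <== | <======== | <== InputBio <== <Receive from raw socket>
                                    |  decrypt  |
                                    +-----------+
 */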

namespace ttv {
/**
 * Holds the OpenSSL state for a single OpenSslSocket connection.
 *
 * Declared in the .cpp so OpenSSL types stay out of the public headers.
 *
 * Ownership: the destructor frees every handle it was given. Once an SSL object has been set, the
 * destructor assumes both BIOs were attached to it with SSL_set_bio(), so SSL_free() releases them.
 * BIOs are freed directly only when no SSL object was ever set. The class is non-copyable because a
 * copy would free the same handles twice.
 */
class OpenSslSocket::SocketData {
 public:
  SocketData() : mScratchBuffer(kScratchBufferSize) {}

  ~SocketData() {
    if (mSsl != nullptr) {
      // SSL_free() also frees the BIOs attached with SSL_set_bio().
      SSL_free(mSsl);
      mSsl = nullptr;
      mInputBio = nullptr;
      mOutputBio = nullptr;
    }

    // Still non-null only if the BIOs were never handed to an SSL object.
    if (mInputBio != nullptr) {
      BIO_free(mInputBio);
      mInputBio = nullptr;
    }

    if (mOutputBio != nullptr) {
      BIO_free(mOutputBio);
      mOutputBio = nullptr;
    }

    if (mContext != nullptr) {
      SSL_CTX_free(mContext);
      mContext = nullptr;
    }
  }

  // Owns raw OpenSSL handles: copying would lead to a double free.
  SocketData(const SocketData&) = delete;
  SocketData& operator=(const SocketData&) = delete;

  SSL_CTX* GetContext() { return mContext; }
  /// Takes ownership of @p context; may only be set once.
  void SetContext(SSL_CTX* context) {
    TTV_ASSERT(mContext == nullptr);

    mContext = context;
  }

  SSL* GetSsl() { return mSsl; }
  /// Takes ownership of @p ssl; may only be set once.
  void SetSsl(SSL* ssl) {
    TTV_ASSERT(mSsl == nullptr);

    mSsl = ssl;
  }

  BIO* GetInputBio() { return mInputBio; }
  /// Takes ownership of @p bio; may only be set once.
  void SetInputBio(BIO* bio) {
    TTV_ASSERT(mInputBio == nullptr);

    mInputBio = bio;
  }

  BIO* GetOutputBio() { return mOutputBio; }
  /// Takes ownership of @p bio; may only be set once.
  void SetOutputBio(BIO* bio) {
    TTV_ASSERT(mOutputBio == nullptr);

    mOutputBio = bio;
  }

  std::vector<uint8_t>& GetScratchBuffer() { return mScratchBuffer; }

 private:
  SSL* mSsl = nullptr;                  //!< The SSL connection; owns both BIOs once they are attached.
  SSL_CTX* mContext = nullptr;          //!< The SSL context for this connection; freed last.
  BIO* mInputBio = nullptr;             //!< Encrypted bytes received from the base socket, awaiting decryption by SSL.
  BIO* mOutputBio = nullptr;            //!< Encrypted bytes produced by SSL, awaiting sending over the base socket.
  std::vector<uint8_t> mScratchBuffer;  //!< Temporary buffer for moving encrypted bytes between the base socket and the BIOs.
};
}  // namespace ttv

/**
 * Reference-counted global OpenSSL initialization. Only the call that takes the count from 0 to 1
 * performs the library setup. Every call must be balanced by ShutdownOpenSslSockets().
 *
 * NOTE(review): a concurrent caller that receives TTV_EC_ALREADY_INITIALIZED may return before the
 * first caller has finished the setup below.
 *
 * @return TTV_EC_SUCCESS for the call that initialized OpenSSL, TTV_EC_ALREADY_INITIALIZED otherwise.
 */
TTV_ErrorCode ttv::OpenSslSocket::InitializeOpenSslSockets() {
  // Test the value produced by the increment itself. Re-reading the atomic afterwards would let two
  // concurrent callers both observe 2, and then nobody would initialize OpenSSL.
  if (++gOpenSslInitializations != 1) {
    return TTV_EC_ALREADY_INITIALIZED;
  }

  (void)SSL_library_init();
  SSL_load_error_strings();
  ERR_load_BIO_strings();
  // TODO: Make sure we can't just load specific ones and reduce footprint
  OpenSSL_add_all_algorithms();

  return TTV_EC_SUCCESS;
}

/**
 * Balances one InitializeOpenSslSockets() call. OpenSSL is cleaned up when the last outstanding
 * initialization is released.
 *
 * @return TTV_EC_SUCCESS when this call released OpenSSL. TTV_EC_NOT_INITIALIZED when other
 *         initializations remain, or when there was nothing to shut down (the count is left at zero).
 */
TTV_ErrorCode ttv::OpenSslSocket::ShutdownOpenSslSockets() {
  // Decrement atomically but never below zero. An unbalanced shutdown must not drive the counter
  // negative, or the next InitializeOpenSslSockets() would not see 1 and would skip setup.
  int previous = gOpenSslInitializations.load();
  do {
    if (previous <= 0) {
      return TTV_EC_NOT_INITIALIZED;
    }
  } while (!gOpenSslInitializations.compare_exchange_weak(previous, previous - 1));

  if (previous != 1) {
    // Other users are still active; keep OpenSSL loaded.
    return TTV_EC_NOT_INITIALIZED;
  }

  ERR_free_strings();
  EVP_cleanup();

  return TTV_EC_SUCCESS;
}

// Process-wide list of extra host names that PrepareConnection() accepts in peer certificates.
std::vector<std::string> ttv::OpenSslSocket::sTrustedHosts;

// Replaces (does not append to) the trusted host list used by connections prepared afterwards.
void ttv::OpenSslSocket::SetTrustedHosts(const std::vector<std::string>& hosts) {
  std::vector<std::string> replacement(hosts);
  sTrustedHosts.swap(replacement);
}

// A new socket starts disconnected; Initialize() and Connect() must run before any I/O.
ttv::OpenSslSocket::OpenSslSocket()
    : mConnected(false) {}

// Drops the SSL state and disconnects the base socket; the result is irrelevant during destruction.
ttv::OpenSslSocket::~OpenSslSocket() {
  static_cast<void>(Disconnect());
}

/**
 * Records the target endpoint and creates the plain "tcp" base socket that will carry the TLS records.
 *
 * @param host  Host name or IP address to connect to; also used for certificate host-name checks.
 * @param port  Port to connect to.
 * @return TTV_EC_SUCCESS, or the error from ttv::CreateSocket() when no base socket can be created.
 */
TTV_ErrorCode ttv::OpenSslSocket::Initialize(const std::string& host, const std::string& port) {
  ttv::trace::Message(
    "Socket", MessageLevel::Debug, "OpenSslSocket: Creating openssl socket for %s:%s", host.c_str(), port.c_str());

  TTV_ASSERT(gOpenSslInitializations > 0);

  mHost = host;
  mPort = port;

  Uri url;
  url.SetProtocol("tcp");
  url.SetHostName(host);
  url.SetPort(port);

  TTV_ErrorCode ec = ttv::CreateSocket(url.GetUrl(), mBaseSocket);
  if (TTV_FAILED(ec)) {
    // A failure here leaves the socket unusable, so report it as an error under this class's own prefix.
    ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: Failed to create base socket with protocol: %s",
      url.GetProtocol().c_str());
  }

  return ec;
}

/**
 * Prepares the SSL state, connects the base socket, and performs the TLS handshake.
 *
 * @return TTV_EC_SOCKET_EALREADY if already connected (the live connection is left untouched).
 *         TTV_EC_NOT_INITIALIZED if Initialize() has not produced a base socket.
 *         Otherwise the first failure from preparation, base connect or handshake, in which case the
 *         socket is disconnected.
 */
TTV_ErrorCode ttv::OpenSslSocket::Connect() {
  // Report an existing connection without tearing it down.
  if (Connected()) {
    return TTV_EC_SOCKET_EALREADY;
  }

  // Without a base socket there is nothing to connect; avoid dereferencing null below.
  if (mBaseSocket == nullptr) {
    ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: Connect called without a base socket");
    return TTV_EC_NOT_INITIALIZED;
  }

  // Setup SSL for translating data
  TTV_ErrorCode ec = PrepareConnection();

  // Now connect to the base socket
  if (SOCKET_SUCCEEDED(ec)) {
    ec = mBaseSocket->Connect();
  }

  // Perform the handshake
  if (SOCKET_SUCCEEDED(ec)) {
    ec = Handshake();
  }

  mConnected = SOCKET_SUCCEEDED(ec);

  // Failure so disconnect
  if (SOCKET_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

/**
 * Marks the socket disconnected, disconnects the base socket (if any) and frees the SSL state.
 * Safe to call repeatedly and before Initialize() has succeeded.
 *
 * @return Always TTV_EC_SUCCESS.
 */
TTV_ErrorCode ttv::OpenSslSocket::Disconnect() {
  mConnected = false;

  // mBaseSocket is null when Initialize() never ran or failed. The destructor still calls Disconnect(),
  // e.g. when OpenSslSocketFactory::CreateSocket discards a socket whose Initialize() failed.
  if (mBaseSocket != nullptr) {
    (void)mBaseSocket->Disconnect();
  }

  mSocketData.reset();

  return TTV_EC_SUCCESS;
}

/**
 * Builds the OpenSSL state for a new client connection:
 * - an SSL_CTX restricted to TLS 1.1 or higher and loaded with the system trust store;
 * - two memory BIOs through which ciphertext is exchanged with the base socket;
 * - an SSL object configured for peer-certificate and host-name verification, with SNI.
 *
 * On success the state is stored in mSocketData. On failure nothing is stored and the OpenSSL error
 * queue is logged.
 *
 * @return TTV_EC_SUCCESS, TTV_EC_SOCKET_CREATE_FAILED if an OpenSSL object cannot be created or
 *         configured, or the error returned by LoadSystemTrustStore().
 */
TTV_ErrorCode ttv::OpenSslSocket::PrepareConnection() {
  // https://wiki.openssl.org/index.php/SSL/TLS_Client
  // http://www.ibm.com/developerworks/library/l-openssl/
  // http://www.roxlu.com/2014/042/using-openssl-with-memory-bios
  // http://stackoverflow.com/questions/7698488/turn-a-simple-socket-into-an-ssl-socket

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  // Each OpenSSL object is handed to socketData as soon as it exists. On any failure below, everything
  // allocated so far is released when socketData goes out of scope.
  std::unique_ptr<SocketData> socketData = std::make_unique<SocketData>();

  // Get the method for creating contexts
  const SSL_METHOD* method = SSLv23_client_method();
  TTV_ASSERT(method != nullptr);
  if (method == nullptr) {
    ec = TTV_EC_SOCKET_CREATE_FAILED;
    ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSLv23_client_method failed");
  }

  // Create a context used for this connection
  SSL_CTX* context = nullptr;
  if (TTV_SUCCEEDED(ec)) {
    context = SSL_CTX_new(method);
    TTV_ASSERT(context != nullptr);
    if (context != nullptr) {
      socketData->SetContext(context);
    } else {
      ec = TTV_EC_SOCKET_CREATE_FAILED;
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSL_CTX_new failed");
    }
  }

  if (TTV_SUCCEEDED(ec)) {
    // Exclude SSLv2, SSLv3 and TLS 1.0 from negotiation, leaving TLS 1.1 or higher.
    const long flags = SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1;
    (void)SSL_CTX_set_options(context, flags);

    // TODO: We probably want to load the trust store from a memory buffer
    // http://openssl.6102.n7.nabble.com/how-to-load-a-certs-chain-from-memory-thanks-td42881.html
    ec = LoadSystemTrustStore(context);
  }

  // Create input BIO object (ciphertext received from the base socket)
  BIO* inputBio = nullptr;
  if (TTV_SUCCEEDED(ec)) {
    inputBio = BIO_new(BIO_s_mem());
    if (inputBio != nullptr) {
      socketData->SetInputBio(inputBio);

      BIO_set_mem_eof_return(inputBio, -1);
    } else {
      ec = TTV_EC_SOCKET_CREATE_FAILED;
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: BIO_new failed");
    }
  }

  // Create output BIO object (ciphertext to send over the base socket)
  BIO* outputBio = nullptr;
  if (TTV_SUCCEEDED(ec)) {
    outputBio = BIO_new(BIO_s_mem());
    if (outputBio != nullptr) {
      socketData->SetOutputBio(outputBio);

      BIO_set_mem_eof_return(outputBio, -1);
    } else {
      ec = TTV_EC_SOCKET_CREATE_FAILED;
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: BIO_new failed");
    }
  }

  // Create the SSL object
  SSL* ssl = nullptr;
  if (TTV_SUCCEEDED(ec)) {
    ssl = SSL_new(context);
    if (ssl != nullptr) {
      socketData->SetSsl(ssl);

      // Configure SSL to use the memory buffers we've allocated; SSL now owns both BIOs.
      SSL_set_bio(ssl, inputBio, outputBio);

      // Use client protocol
      SSL_set_connect_state(ssl);
    } else {
      ec = TTV_EC_SOCKET_CREATE_FAILED;
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSL_new failed");
    }
  }

  // Setup certificate verification
  if (TTV_SUCCEEDED(ec)) {
    X509_VERIFY_PARAM* param = SSL_get0_param(ssl);
    X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_MULTI_LABEL_WILDCARDS);

    // A failed add1_host is fatal: if no names get registered, OpenSSL skips the host-name check
    // entirely and would accept any chain that verifies.
    // Add trusted hosts
    for (const auto& host : sTrustedHosts) {
      if (X509_VERIFY_PARAM_add1_host(param, host.c_str(), 0) != 1) {
        ec = TTV_EC_SOCKET_CREATE_FAILED;
        ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: X509_VERIFY_PARAM_add1_host failed");
      }
    }

    if (!ttv::IsHostAnIpAddress(mHost)) {
      std::vector<std::string> hosts;
      GenerateSslVerificationHosts(mHost, hosts);

      // Enable automatic hostname checks
      for (const auto& host : hosts) {
        if (X509_VERIFY_PARAM_add1_host(param, host.c_str(), 0) != 1) {
          ec = TTV_EC_SOCKET_CREATE_FAILED;
          ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: X509_VERIFY_PARAM_add1_host failed");
        }
      }

      // Send SNI so that servers hosting several certificates present the one for mHost. IP literals
      // are excluded by the enclosing check because RFC 6066 does not allow them in SNI.
      // Best effort: without SNI the handshake can still succeed against single-certificate servers.
      if (!mHost.empty() && SSL_set_tlsext_host_name(ssl, mHost.c_str()) != 1) {
        ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSL_set_tlsext_host_name failed");
      }
    }

    // SSL_VERIFY_NONE: Failure to verify the certificate does not fail the handshake
    // SSL_VERIFY_PEER: Failure to verify the certificate fails the handshake
    SSL_set_verify(ssl, SSL_VERIFY_PEER, FilterVerifyCertificate);
  }

  if (TTV_SUCCEEDED(ec)) {
    // Publish the fully configured state; the connection itself is established later.
    mSocketData = std::move(socketData);
  } else {
    LogError();
  }

  return ec;
}

/**
 * Drives the client TLS handshake over the memory BIOs. It shuttles bytes through the base socket until
 * SSL_do_handshake() completes, fails, or kConnectionTimeoutMilliseconds elapses. On completion it
 * also confirms that the server presented a certificate whose chain verified.
 *
 * @return TTV_EC_SUCCESS on success.
 *         TTV_EC_SOCKET_ETIMEDOUT on timeout.
 *         TTV_EC_AUTHENTICATION when the certificate is missing or fails verification.
 *         TTV_EC_SOCKET_ECONNRESET for other handshake failures.
 *         Otherwise the error reported while moving bytes over the base socket.
 */
TTV_ErrorCode ttv::OpenSslSocket::Handshake() {
  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  ttv::trace::Message("Socket", MessageLevel::Debug, "OpenSslSocket: Begining handshake");

  SSL* ssl = mSocketData->GetSsl();

  const uint64_t start = GetSystemTimeMilliseconds();
  while (SOCKET_SUCCEEDED(ec)) {
    if ((GetSystemTimeMilliseconds() - start) >= kConnectionTimeoutMilliseconds) {
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: Timed out trying to connect");
      ec = TTV_EC_SOCKET_ETIMEDOUT;
      break;
    }

    // Send handshake bytes over the socket
    ec = FlushOutgoing();

    // Wait for some data to arrive on the socket
    if (SOCKET_SUCCEEDED(ec)) {
      Sleep(10);
    }

    // Receive handshake bytes from the socket
    if (SOCKET_SUCCEEDED(ec)) {
      ec = FlushIncoming();
    }

    const int ret = SSL_do_handshake(ssl);
    if (ret == 1) {
      ttv::trace::Message("Socket", MessageLevel::Debug, "OpenSslSocket: Handshake successful");

      // The last handshake step may have written a final flight (e.g. the client Finished under TLS 1.3)
      // to the output BIO. Send it now instead of waiting for the next Send()/Recv().
      if (SOCKET_SUCCEEDED(ec)) {
        ec = FlushOutgoing();
      }
      break;
    }

    const int err = SSL_get_error(ssl, ret);
    if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
      // More bytes need to travel to or from the peer; keep pumping.
      continue;
    }

    // err is an SSL_ERROR_* value, not an error-queue code, so it must not go through ERR_error_string().
    // Log it as a number, and describe the queued error with a local (thread-safe) buffer.
    char detail[256];
    ERR_error_string_n(ERR_get_error(), detail, sizeof(detail));
    ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSL_do_handshake failed: SSL error %d", err);
    ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: Detailed error: %s", detail);

    // With SSL_VERIFY_PEER a rejected certificate aborts the handshake. Report that as an
    // authentication failure rather than a generic connection error.
    const long verifyResult = SSL_get_verify_result(ssl);
    if (verifyResult != X509_V_OK) {
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: Certificate chain verification failed: %s",
        X509_verify_cert_error_string(verifyResult));
      ec = TTV_EC_AUTHENTICATION;
    } else {
      ec = TTV_EC_SOCKET_ECONNRESET;
    }
    break;
  }

  // A would-block from the last flush is not a failure once the loop has ended.
  if (ec == TTV_EC_SOCKET_EWOULDBLOCK) {
    ec = TTV_EC_SUCCESS;
  }

  // Verify that the server presented a proper certificate
  if (TTV_SUCCEEDED(ec)) {
    X509* cert = SSL_get_peer_certificate(ssl);
    if (cert != nullptr) {
      X509_free(cert);
    } else {
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: No certificate presented");
      ec = TTV_EC_AUTHENTICATION;
    }
  }

  // Verify the result of chain verification
  if (TTV_SUCCEEDED(ec)) {
    long res = SSL_get_verify_result(ssl);
    if (res != X509_V_OK) {
      ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: Certificate chain verification failed: %s",
        X509_verify_cert_error_string(res));
      ec = TTV_EC_AUTHENTICATION;
    }
  }

  // NOTE: Hostname verification is done automatically by OpenSSL
  // (names registered via X509_VERIFY_PARAM_add1_host in PrepareConnection).

  return ec;
}

/**
 * Pumps ciphertext in both directions: first sends whatever SSL has queued in the output BIO, then
 * reads one chunk from the base socket into the input BIO. The incoming step is skipped if sending fails.
 *
 * @return The first failing step's error, otherwise the incoming step's result.
 */
TTV_ErrorCode ttv::OpenSslSocket::Flush() {
  TTV_ErrorCode outgoingResult = FlushOutgoing();
  if (SOCKET_SUCCEEDED(outgoingResult)) {
    return FlushIncoming();
  }
  return outgoingResult;
}

/**
 * Drains the encrypted bytes queued in the SSL output BIO and sends them over the base socket, one
 * scratch-buffer-sized chunk at a time.
 *
 * NOTE(review): each chunk is removed from the BIO before Send() runs. If the base socket does not
 * accept it, those bytes are not put back. Confirm that the two-argument ISocket::Send either sends
 * everything or fails.
 *
 * @return TTV_EC_SUCCESS, the base socket's Send() error, or TTV_EC_SOCKET_ECONNABORTED if the BIO
 *         yields fewer bytes than it reported pending.
 */
TTV_ErrorCode ttv::OpenSslSocket::FlushOutgoing() {
  std::vector<uint8_t>& scratchBuffer = mSocketData->GetScratchBuffer();
  uint8_t* chunk = scratchBuffer.data();
  const int chunkCapacity = static_cast<int>(scratchBuffer.size());

  BIO* pendingCiphertext = mSocketData->GetOutputBio();
  int bytesLeft = static_cast<int>(BIO_ctrl_pending(pendingCiphertext));

  TTV_ErrorCode ec = TTV_EC_SUCCESS;
  while (SOCKET_SUCCEEDED(ec) && bytesLeft > 0) {
    const int chunkSize = std::min(bytesLeft, chunkCapacity);
    const int drained = BIO_read(pendingCiphertext, static_cast<void*>(chunk), chunkSize);

    if (drained != chunkSize) {
      // A memory BIO should hand back everything it reported as pending; give up otherwise.
      ec = TTV_EC_SOCKET_ECONNABORTED;
      ttv::trace::Message(
        "Socket", MessageLevel::Error, "OpenSslSocket: Not all outgoing encrypted bytes were read from output bio");
      continue;
    }

    ttv::trace::Message(
      "Socket", MessageLevel::Debug, "OpenSslSocket: Sending %d outgoing encrypted bytes over base socket", drained);
    ec = mBaseSocket->Send(chunk, drained);
    bytesLeft -= drained;
  }

  return ec;
}

/**
 * Reads at most one scratch-buffer-sized chunk of encrypted bytes from the base socket and appends it
 * to the SSL input BIO, where SSL will later decrypt it.
 *
 * @return The base socket's Recv() result, or TTV_EC_SOCKET_ECONNABORTED if the input BIO does not
 *         accept every byte.
 */
TTV_ErrorCode ttv::OpenSslSocket::FlushIncoming() {
  std::vector<uint8_t>& scratchBuffer = mSocketData->GetScratchBuffer();
  uint8_t* scratch = scratchBuffer.data();
  int scratchSize = static_cast<int>(scratchBuffer.size());

  // Read some encrypted bytes from the base socket
  size_t encryptedReceived = 0;
  TTV_ErrorCode ec = mBaseSocket->Recv(scratch, scratchSize, encryptedReceived);

  // Feed the encrypted bytes into the input bio for decryption by SSL
  if (SOCKET_SUCCEEDED(ec) && encryptedReceived > 0) {
    // %u expects an unsigned int. Passing a size_t through varargs is undefined where size_t is wider.
    // The count fits because it is bounded by scratchSize (assuming the base socket honours the
    // requested length).
    ttv::trace::Message("Socket", MessageLevel::Debug,
      "OpenSslSocket: Received %u incoming encrypted bytes from base socket",
      static_cast<unsigned int>(encryptedReceived));

    BIO* inputBio = mSocketData->GetInputBio();
    int written = BIO_write(inputBio, scratch, static_cast<int>(encryptedReceived));
    if (written != static_cast<int>(encryptedReceived)) {
      ec = TTV_EC_SOCKET_ECONNABORTED;
      ttv::trace::Message(
        "Socket", MessageLevel::Error, "OpenSslSocket: Not all incoming encrypted bytes were accepted by input bio");
    }
  }

  return ec;
}

/**
 * Encrypts up to @p length bytes of @p buffer and sends the resulting records over the base socket.
 *
 * @param buffer  Plaintext to send; must not be null.
 * @param length  Number of bytes available in @p buffer.
 * @param sent    Set to the number of plaintext bytes SSL accepted. It is 0 if none were accepted,
 *                including on every error path; it may be less than @p length.
 * @return TTV_EC_SOCKET_ENOTCONN when not connected, TTV_EC_SOCKET_ECONNABORTED when SSL_write fails,
 *         otherwise the result of moving bytes over the base socket. The socket is disconnected on failure.
 */
TTV_ErrorCode ttv::OpenSslSocket::Send(const uint8_t* buffer, size_t length, size_t& sent) {
  TTV_ASSERT(buffer != nullptr);

  // Never leave the out-parameter indeterminate.
  sent = 0;

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  if (!Connected()) {
    ec = TTV_EC_SOCKET_ENOTCONN;
  }

  // Handle socket activity
  if (SOCKET_SUCCEEDED(ec)) {
    ec = Flush();
  }

  // Encrypt the data to be sent
  if (SOCKET_SUCCEEDED(ec) && length > 0) {
    // SSL_write() takes an int length. Clamp so that an oversized buffer becomes a partial write
    // (reported through `sent`) instead of a negative length.
    const size_t toWrite = std::min(length, static_cast<size_t>(INT_MAX));
    ttv::trace::Message("Socket", MessageLevel::Debug, "OpenSslSocket: Feeding %u unencrypted bytes into SSL_write",
      static_cast<unsigned int>(toWrite));

    SSL* ssl = mSocketData->GetSsl();
    const int written = SSL_write(ssl, static_cast<const void*>(buffer), static_cast<int>(toWrite));
    if (written > 0) {
      sent = static_cast<size_t>(written);
    } else {
      // A non-positive result must never be cast to size_t (it would become a huge byte count).
      const int err = SSL_get_error(ssl, written);
      if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
        ec = TTV_EC_SOCKET_ECONNABORTED;
        ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSL_write failed: SSL error %d", err);
      }
      // WANT_READ / WANT_WRITE: nothing accepted yet; `sent` stays 0 so the caller can retry.
    }
  }

  // Send pending encrypted stuff over the socket
  if (SOCKET_SUCCEEDED(ec)) {
    ec = FlushOutgoing();
  }

  // Failure so disconnect
  if (SOCKET_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

/**
 * Pumps the base socket, then decrypts as many bytes as SSL can currently deliver, up to @p length.
 *
 * @param buffer    Destination for plaintext; must not be null.
 * @param length    Capacity of @p buffer.
 * @param received  Set to the number of plaintext bytes written to @p buffer. It is 0 when nothing is
 *                  available yet, for example while only part of a TLS record has arrived.
 * @return TTV_EC_SOCKET_ENOTCONN when not connected, TTV_EC_SOCKET_ECONNABORTED when SSL_read fails,
 *         otherwise the result of moving bytes over the base socket. The socket is disconnected on failure.
 */
TTV_ErrorCode ttv::OpenSslSocket::Recv(uint8_t* buffer, size_t length, size_t& received) {
  TTV_ASSERT(buffer != nullptr);

  received = 0;

  TTV_ErrorCode ec = TTV_EC_SUCCESS;

  if (!Connected()) {
    ec = TTV_EC_SOCKET_ENOTCONN;
  }

  if (SOCKET_SUCCEEDED(ec)) {
    // Send and receive bytes over the base socket
    ec = Flush();
  }

  if (SOCKET_SUCCEEDED(ec) && length > 0) {
    SSL* ssl = mSocketData->GetSsl();
    BIO* inputBio = mSocketData->GetInputBio();

    // There is something to read if ciphertext waits in the input BIO, or if SSL already holds
    // decrypted bytes. The latter happens when an earlier record was larger than the caller's buffer:
    // SSL consumes the whole record from the BIO at once.
    if (BIO_ctrl_pending(inputBio) > 0 || SSL_pending(ssl) > 0) {
      // SSL_read() takes an int length.
      const int maxRead = static_cast<int>(std::min(length, static_cast<size_t>(INT_MAX)));
      const int numDecrypted = SSL_read(ssl, static_cast<void*>(buffer), maxRead);
      if (numDecrypted > 0) {
        ttv::trace::Message(
          "Socket", MessageLevel::Debug, "OpenSslSocket: Reading %d unencrypted bytes from SSL_read", numDecrypted);

        received = static_cast<size_t>(numDecrypted);
      } else if (numDecrypted < 0) {
        const int err = SSL_get_error(ssl, numDecrypted);
        // WANT_READ: only part of a record has arrived so far. WANT_WRITE: SSL queued protocol data to
        // send (it goes out on the next flush). Neither is a failure; report no data for now.
        if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
          ec = TTV_EC_SOCKET_ECONNABORTED;
          ttv::trace::Message("Socket", MessageLevel::Error, "OpenSslSocket: SSL_read failed: SSL error %d", err);
        }
      }
      // NOTE: numDecrypted == 0 (e.g. close_notify from the peer) is reported as success with nothing received.
    }
  }

  // Failure so disconnect
  if (SOCKET_FAILED(ec)) {
    (void)Disconnect();
  }

  return ec;
}

// Byte counters are delegated to the base socket. They therefore count what crossed the transport,
// which is encrypted TLS records (presumably including handshake traffic), not application plaintext.
// Both return 0 while no base socket exists.
uint64_t ttv::OpenSslSocket::TotalSent() {
  if (mBaseSocket == nullptr) {
    return 0;
  }
  return mBaseSocket->TotalSent();
}

uint64_t ttv::OpenSslSocket::TotalReceived() {
  if (mBaseSocket == nullptr) {
    return 0;
  }
  return mBaseSocket->TotalReceived();
}

// True from a successful Connect() until the next Disconnect(); Send()/Recv() also disconnect on failure.
bool ttv::OpenSslSocket::Connected() {
  return static_cast<bool>(mConnected);
}

// Stores certificate data for later use.
// TODO: Use this; nothing in this file reads mCertificateData yet.
TTV_ErrorCode ttv::OpenSslSocket::SetCertificateData(const std::string& data) {
  mCertificateData = data;
  return TTV_EC_SUCCESS;
}

ttv::OpenSslSocketFactory::~OpenSslSocketFactory() = default;

// Reports whether this factory handles the URI scheme; "tls" and "ssl" are treated as synonyms.
bool ttv::OpenSslSocketFactory::IsProtocolSupported(const std::string& protocol) {
  const bool isTls = (protocol == "tls");
  const bool isSsl = (protocol == "ssl");
  return isTls || isSsl;
}

/**
 * Creates and initializes an OpenSslSocket for "tls" or "ssl" URIs.
 *
 * @param uri     URI whose scheme selects this factory and whose host name and port configure the socket.
 * @param result  Receives the initialized socket on success; reset to null otherwise.
 * @return TTV_EC_SUCCESS when a socket was created. Both an unsupported scheme and a failed
 *         OpenSslSocket::Initialize() are reported as TTV_EC_UNIMPLEMENTED.
 */
TTV_ErrorCode ttv::OpenSslSocketFactory::CreateSocket(const std::string& uri, std::shared_ptr<ttv::ISocket>& result) {
  result.reset();

  Uri url(uri);
  const std::string protocol = url.GetProtocol();
  const bool supported = (protocol == "tls") || (protocol == "ssl");
  if (!supported) {
    return TTV_EC_UNIMPLEMENTED;
  }

  std::shared_ptr<OpenSslSocket> socket = std::make_shared<OpenSslSocket>();
  if (TTV_SUCCEEDED(socket->Initialize(url.GetHostName(), url.GetPort()))) {
    result = socket;
    return TTV_EC_SUCCESS;
  }

  // Initialization failed; the half-built socket is discarded and `result` stays null.
  return TTV_EC_UNIMPLEMENTED;
}
