#include "libav-decoder-helpers.hpp"

#include "../config.hpp"
#include "./avlogger.hpp"
#include "./libav-decoder.hpp"
#include "debug/log.hpp"
#include "fundamentals/helpers.hpp"

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavfilter/avfilter.h>
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavutil/opt.h>
#include <libavutil/pixdesc.h>
}

#include <mutex>

namespace Vape {

// Log tag used by the helpers in this file (functions taking their own `h` parameter shadow it).
static constexpr std::string_view h{"LibAVHelpers"};

/// Returns the colour-range name used in filter option strings (e.g. the scale filter's
/// `out_range`): "jpeg" for AVCOL_RANGE_JPEG and "mpeg" for every other value, including
/// AVCOL_RANGE_UNSPECIFIED.
///
/// All current callers are compiled out with `#if 0`, so the function is marked
/// [[maybe_unused]] to keep -Wunused-function quiet.
[[maybe_unused]] static const char *
GetColorRangeName(AVColorRange range)
{
    return range == AVCOL_RANGE_JPEG ? "jpeg" : "mpeg";
}

/// Replaces a deprecated full-range "YUVJ" pixel format with its plain YUV equivalent.
///
/// @param format        Pixel format (an AVPixelFormat value stored as int). It is rewritten
///                      in place when it is one of the YUVJ formats and left untouched otherwise.
/// @param is_full_range Optional out-parameter. When non-null, it receives true iff `format`
///                      was a YUVJ (JPEG/full-range) format.
static void
UpgradePixelFormat(int &format, bool *is_full_range)
{
    int plain = format;

    switch (format) {
        case AV_PIX_FMT_YUVJ420P: plain = AV_PIX_FMT_YUV420P; break;
        case AV_PIX_FMT_YUVJ422P: plain = AV_PIX_FMT_YUV422P; break;
        case AV_PIX_FMT_YUVJ444P: plain = AV_PIX_FMT_YUV444P; break;
        case AV_PIX_FMT_YUVJ440P: plain = AV_PIX_FMT_YUV440P; break;
        default: break;
    }

    // Only the YUVJ cases above change the value, so "changed" means "was full range".
    const bool wasJpegFormat = plain != format;
    format                   = plain;

    if (is_full_range != nullptr) {
        *is_full_range = wasJpegFormat;
    }
}

/// Reports whether `decoderContext` outputs VAAPI hardware surfaces.
/// A null context is treated as "not a hardware decoder".
bool
IsHwDecoder(AVCodecContext *decoderContext)
{
    if (decoderContext == nullptr) {
        return false;
    }

    return decoderContext->pix_fmt == AV_PIX_FMT_VAAPI;
}

/// Allocates an AVPacket holding a copy of `sample`'s payload.
///
/// PTS and DTS are converted from flicks into `rational` units. The duration is kept as raw
/// flicks (see below). Returns nullptr if either allocation fails. Otherwise the caller owns
/// the packet and must release it with av_packet_free().
static AVPacket *
AVPacketFromMediaSample(const PlayerCore::WrappedMediaSample &sample, AVRational rational)
{
    AVPacket *packet = av_packet_alloc();
    if (packet == nullptr) {
        return nullptr;
    }

    const auto &payload  = sample.sample->buffer;
    const bool allocated = av_new_packet(packet, payload.size()) == 0;
    if (!allocated) {
        av_packet_free(&packet);
        return nullptr;
    }

    std::copy(payload.begin(), payload.end(), packet->data);

    // The duration is stored in flicks rather than `rational` units to avoid conversion errors.
    packet->duration = sample.Duration().count();
    packet->pts      = AVTimeFromFlicks(sample.PTS(), rational);
    packet->dts      = AVTimeFromFlicks(sample.DTS(), rational);

    return packet;
}

/// Normalises the colour metadata of a decoded video frame in place:
///  - a deprecated YUVJ pixel format is replaced by its plain YUV counterpart;
///  - an unspecified range becomes JPEG (full) for former YUVJ formats, MPEG otherwise;
///  - an unspecified colour space defaults to SMPTE 170M;
///  - unspecified transfer/primaries become BT.709 when the (possibly defaulted) colour
///    space is BT.709, and SMPTE 170M otherwise.
void
CorrectAVFrameColorInfo(AVFrame *frame)
{
#if 0
    Log::Println("IN: Range: {}, Space: {}, Transfer: {}, Primaries: {}",
                 GetColorRangeName(frame->color_range), frame->colorspace, frame->color_trc,
                 frame->color_primaries);
#endif

    bool fromJpegFormat = false;
    UpgradePixelFormat(frame->format, &fromJpegFormat);

    if (frame->color_range == AVCOL_RANGE_UNSPECIFIED) {
        frame->color_range = fromJpegFormat ? AVCOL_RANGE_JPEG : AVCOL_RANGE_MPEG;
    }

    if (frame->colorspace == AVCOL_SPC_UNSPECIFIED) {
        frame->colorspace = AVCOL_SPC_SMPTE170M;
    }

    // Evaluated after the colour-space default above, so an unspecified space maps to 170M.
    const bool isBt709 = frame->colorspace == AVCOL_SPC_BT709;

    if (frame->color_trc == AVCOL_TRC_UNSPECIFIED) {
        frame->color_trc = isBt709 ? AVCOL_TRC_BT709 : AVCOL_TRC_SMPTE170M;
    }

    if (frame->color_primaries == AVCOL_PRI_UNSPECIFIED) {
        frame->color_primaries = isBt709 ? AVCOL_PRI_BT709 : AVCOL_PRI_SMPTE170M;
    }

#if 0
    Log::Println("OUT: Range: {}, Space: {}, Transfer: {}, Primaries: {}",
                 GetColorRangeName(frame->color_range), frame->colorspace, frame->color_trc,
                 frame->color_primaries);
#endif
}

/// Replaces an unset rational (numerator 0) with the fallback `r`.
///
/// @return true when `*rational` was overwritten. Returns false when `rational` is null or
///         already has a non-zero numerator.
bool
EnsureValidAVRational(AVRational *rational, const AVRational &r)
{
    if (rational == nullptr || rational->num != 0) {
        return false;
    }

    *rational = r;
    return true;
}

/// Chains `filters` in order by linking output pad 0 of each filter to input pad 0 of the
/// next one. Stops at the first avfilter_link() failure.
///
/// @return true when every adjacent pair was linked (trivially true for 0 or 1 filters).
static bool
LinkFilterGraph(const std::vector<AVFilterContext *> &filters)
{
    for (size_t i = 0; i + 1 < filters.size(); ++i) {
        const bool linked = avfilter_link(filters[i], 0, filters[i + 1], 0) == 0;
        if (!linked) {
            return false;
        }
    }

    return true;
}

/// Builds the (unconfigured) software video filter graph:
///
///     buffer ("src") -> scale ("scale") -> buffersink ("sink")
///
/// The source is described by the decoder's size, pixel format (YUVJ formats upgraded to
/// their plain YUV equivalent), time base and sample aspect ratio. The scaler resizes to
/// fmt.width x fmt.height only when fmt.scaled is set. The sink is restricted to fmt.format.
///
/// @return the linked graph, or nullptr (with the graph freed) if any filter could not be
///         allocated, initialised or linked. The caller configures and owns the graph.
static AVFilterGraph *
CreateSwVideoFilterGraph(AVCodecContext *decoderContext, IVideoOutputConsumer::VideoFormat fmt)
{
    auto dec         = decoderContext;
    auto dst_h       = fmt.height;
    auto dst_w       = fmt.width;
    auto dst_pix_fmt = fmt.format;
    int src_pix_fmt  = dec->pix_fmt;
    UpgradePixelFormat(src_pix_fmt, nullptr);
    const bool enableScaling = fmt.scaled;

    AVFilterGraph *graph = avfilter_graph_alloc();
    if (graph == nullptr) {
        Log::TDebug(h, "Failed to allocate filter graph");
        return nullptr;
    }

    // Logs `reason`, frees the partially built graph (including every filter allocated in
    // it) and returns nullptr, so failure paths can simply `return fail(...)`.
    auto fail = [&graph](const char *reason) -> AVFilterGraph * {
        Log::TDebug(h, "{}", reason);
        avfilter_graph_free(&graph);
        return nullptr;
    };

    // Looks up `filterName` and allocates an instance named `instanceName` in the graph.
    // Returns nullptr if the filter is missing from this FFmpeg build or allocation fails.
    // This is checked at runtime because assert() disappears in release builds.
    auto allocFilter = [graph](const char *filterName,
                               const char *instanceName) -> AVFilterContext * {
        const AVFilter *filter = avfilter_get_by_name(filterName);
        if (filter == nullptr) {
            return nullptr;
        }
        return avfilter_graph_alloc_filter(graph, filter, instanceName);
    };

    std::vector<AVFilterContext *> filters;

    {
        AVFilterContext *ctx = allocFilter("buffer", "src");
        if (ctx == nullptr) {
            return fail("Failed to allocate buffer source filter");
        }

        auto args = fS("video_size={}x{}:pix_fmt={}:time_base={}/{}:pixel_aspect={}/{}", dec->width,
                       dec->height, src_pix_fmt, dec->time_base.num, dec->time_base.den,
                       dec->sample_aspect_ratio.num, dec->sample_aspect_ratio.den);

        if (avfilter_init_str(ctx, args.c_str()) < 0) {
            return fail("Failed to initialize buffer source filter");
        }

        filters.push_back(ctx);
    }

    {
        AVFilterContext *ctx = allocFilter("scale", "scale");
        if (ctx == nullptr) {
            return fail("Failed to allocate scale filter");
        }

        std::string args;
#if 0
        if (enableScaling) {
            args = fS("width={}:height={}:sws_flags=fast_bilinear:out_color_matrix={}:out_range={}",
                      dst_w, dst_h, av_get_colorspace_name(fmt.colorSpace),
                      GetColorRangeName(fmt.colorRange));

        } else {
            args = fS("out_color_matrix={}:out_range={}", av_get_colorspace_name(fmt.colorSpace),
                      GetColorRangeName(fmt.colorRange));
        }
#else
        // With empty options the scaler keeps the input size and only converts the format.
        if (enableScaling) {
            args = fS("width={}:height={}:sws_flags=fast_bilinear", dst_w, dst_h);
        }
#endif
        if (avfilter_init_str(ctx, args.c_str()) < 0) {
            return fail("Failed to initialize scale filter");
        }

        filters.push_back(ctx);
    }

    {
        AVFilterContext *ctx = allocFilter("buffersink", "sink");
        if (ctx == nullptr) {
            return fail("Failed to allocate buffer sink filter");
        }

        if (avfilter_init_str(ctx, nullptr) < 0) {
            return fail("Failed to initialize buffer sink filter");
        }

        // Restrict the sink to the consumer's pixel format. If this silently failed, the
        // graph would hand out frames in whatever format it negotiated, so treat it as fatal.
        if (av_opt_set_bin(ctx, "pix_fmts", reinterpret_cast<const uint8_t *>(&dst_pix_fmt),
                           static_cast<int>(sizeof(dst_pix_fmt)), AV_OPT_SEARCH_CHILDREN) < 0) {
            return fail("Failed to set buffer sink pixel format");
        }

        filters.push_back(ctx);
    }

    if (!LinkFilterGraph(filters)) {
        return fail("Failed to link filter graph");
    }

    return graph;
}

/// Builds the (unconfigured) VAAPI video filter graph:
///
///     buffer ("src") -> [scale_vaapi ("scale")] -> hwdownload ("download") -> buffersink ("sink")
///
/// The source receives the decoder's hardware frames context so it accepts VAAPI surfaces.
/// The GPU scaler is only inserted when fmt.scaled is set. It first requests "fast" mode and,
/// if that initialisation fails, retries without it. The sink is restricted to fmt.format.
///
/// @return the linked graph, or nullptr (with the graph freed) if any filter could not be
///         allocated, initialised or linked. The caller configures and owns the graph.
static AVFilterGraph *
CreateHwVideoFilterGraph(AVCodecContext *decoderContext, IVideoOutputConsumer::VideoFormat fmt)
{
    auto dec         = decoderContext;
    auto dst_h       = fmt.height;
    auto dst_w       = fmt.width;
    auto dst_pix_fmt = fmt.format;

    AVFilterGraph *graph = avfilter_graph_alloc();

    if (graph == nullptr) {
        Log::TDebug(h, "Failed to allocate filter graph");
        return nullptr;
    }

    const bool enableScaling = fmt.scaled;

    // Logs `reason`, frees the partially built graph (including every filter allocated in
    // it) and returns nullptr, so failure paths can simply `return fail(...)`.
    auto fail = [&graph](const char *reason) -> AVFilterGraph * {
        Log::TDebug(h, "{}", reason);
        avfilter_graph_free(&graph);
        return nullptr;
    };

    // Looks up `filterName` and allocates an instance named `instanceName` in the graph.
    // Returns nullptr if the filter is missing from this FFmpeg build or allocation fails.
    // This is checked at runtime because assert() disappears in release builds.
    auto allocFilter = [graph](const char *filterName,
                               const char *instanceName) -> AVFilterContext * {
        const AVFilter *filter = avfilter_get_by_name(filterName);
        if (filter == nullptr) {
            return nullptr;
        }
        return avfilter_graph_alloc_filter(graph, filter, instanceName);
    };

    std::vector<AVFilterContext *> filters;

    {
        AVFilterContext *ctx = allocFilter("buffer", "src");
        if (ctx == nullptr) {
            return fail("Failed to allocate buffer source filter");
        }

        // Pass the decoder's hardware frames context to the source. Do not memset the struct:
        // av_buffersrc_parameters_alloc() fills in "unset" defaults (e.g. format =
        // AV_PIX_FMT_NONE). Zeroing would turn them into real values (0 is a valid pixel
        // format) that av_buffersrc_parameters_set() would then apply.
        AVBufferSrcParameters *par = av_buffersrc_parameters_alloc();
        if (par == nullptr) {
            return fail("Failed to allocate buffer source parameters");
        }
        par->hw_frames_ctx = dec->hw_frames_ctx;
        const int paramErr = av_buffersrc_parameters_set(ctx, par);
        av_freep(&par);
        if (paramErr < 0) {
            return fail("Failed to set buffer source parameters");
        }

        auto args = fS("video_size={}x{}:pix_fmt={}:time_base={}/{}:pixel_aspect={}/{}", dec->width,
                       dec->height, dec->pix_fmt, dec->time_base.num, dec->time_base.den,
                       dec->sample_aspect_ratio.num, dec->sample_aspect_ratio.den);

        if (avfilter_init_str(ctx, args.c_str()) < 0) {
            return fail("Failed to initialize buffer source filter");
        }

        filters.push_back(ctx);
    }

    if (enableScaling) {
        AVFilterContext *ctx = allocFilter("scale_vaapi", "scale");
        if (ctx == nullptr) {
            return fail("Failed to allocate VAAPI scale filter");
        }

        if (dec->hw_device_ctx != nullptr) {
            ctx->hw_device_ctx = av_buffer_ref(dec->hw_device_ctx);
            if (ctx->hw_device_ctx == nullptr) {
                return fail("Failed to reference the hardware device context");
            }
        }

        auto args = fS("w={}:h={}:mode=fast", dst_w, dst_h);
        int err   = avfilter_init_str(ctx, args.c_str());

        if (err != 0) {
            // Older FFmpeg releases lack scale_vaapi's "mode" option; retry without it.
            Log::TError(h, "Old version of FFmpeg, can't set VAAPI scale mode to Fast");
            args = fS("w={}:h={}", dst_w, dst_h);
            err  = avfilter_init_str(ctx, args.c_str());
        }

        if (err < 0) {
            return fail("Failed to initialize VAAPI scale filter");
        }

        filters.push_back(ctx);
    }

    {
        AVFilterContext *ctx = allocFilter("hwdownload", "download");
        if (ctx == nullptr) {
            return fail("Failed to allocate hwdownload filter");
        }

        if (avfilter_init_str(ctx, nullptr) < 0) {
            return fail("Failed to initialize hwdownload filter");
        }

        filters.push_back(ctx);
    }

#if 0
    {
        const AVFilter *format = avfilter_get_by_name("format");
        AVFilterContext *ctx   = avfilter_graph_alloc_filter(graph, format, "format");
        assert(ctx);

        std::string args = fS("pix_fmts=nv12");//, av_get_pix_fmt_name(dst_pix_fmt));

        avfilter_init_str(ctx, args.c_str());

        filters.push_back(ctx);
    }

    {
        const AVFilter *scale = avfilter_get_by_name("scale");
        AVFilterContext *ctx  = avfilter_graph_alloc_filter(graph, scale, "scale");
        assert(ctx);

        // std::string args = fS("out_color_matrix={}:out_range={}", av_get_colorspace_name(fmt.colorSpace),
        //             GetColorRangeName(fmt.colorRange));
        std::string args = fS("out_color_matrix={}", av_get_colorspace_name(fmt.colorSpace));

        avfilter_init_str(ctx, args.c_str());

        filters.push_back(ctx);
    }
#elif 0
    {
        const AVFilter *color = avfilter_get_by_name("colorspace");
        AVFilterContext *ctx = avfilter_graph_alloc_filter(graph, color, "colorspace");
        assert(ctx);

        auto args = fS("all={}:range={}:fast=1", av_get_colorspace_name(fmt.colorSpace),
                       GetColorRangeName(fmt.colorRange));
        Log::Println(args);
        avfilter_init_str(ctx, args.c_str());

        filters.push_back(ctx);
    }
#endif

    {
        AVFilterContext *ctx = allocFilter("buffersink", "sink");
        if (ctx == nullptr) {
            return fail("Failed to allocate buffer sink filter");
        }

        if (avfilter_init_str(ctx, nullptr) < 0) {
            return fail("Failed to initialize buffer sink filter");
        }

        // Restrict the sink to the consumer's pixel format. If this silently failed, the
        // graph would hand out frames in whatever format it negotiated, so treat it as fatal.
        if (av_opt_set_bin(ctx, "pix_fmts", reinterpret_cast<const uint8_t *>(&dst_pix_fmt),
                           static_cast<int>(sizeof(dst_pix_fmt)), AV_OPT_SEARCH_CHILDREN) < 0) {
            return fail("Failed to set buffer sink pixel format");
        }

        filters.push_back(ctx);
    }

    if (!LinkFilterGraph(filters)) {
        return fail("Failed to link filter graph");
    }

    return graph;
}

/// Creates and configures the video filter graph that converts decoder output into `fmt`.
///
/// VAAPI decoders (see IsHwDecoder()) get the hardware graph, which optionally scales on the
/// GPU and then downloads the surfaces. All other decoders get the software buffer -> scale ->
/// buffersink graph.
///
/// @return a configured graph owned by the caller (free with avfilter_graph_free()), or
///         nullptr on failure (including a null `decoderContext`).
AVFilterGraph *
CreateVideoFilterGraph(AVCodecContext *decoderContext, IVideoOutputConsumer::VideoFormat fmt)
{
    if (decoderContext == nullptr) {
        Log::TDebug(h, "Failed to create filter graph: no decoder context");
        return nullptr;
    }

    const char *pixFmtname = av_get_pix_fmt_name(decoderContext->pix_fmt);
    if (pixFmtname) {
        Log::TDebug(h, "Creating filter graph for pixel format {}", pixFmtname);
    }

    AVFilterGraph *graph = IsHwDecoder(decoderContext)
                               ? CreateHwVideoFilterGraph(decoderContext, fmt)
                               : CreateSwVideoFilterGraph(decoderContext, fmt);

    if (!graph) {
        Log::TDebug(h, "Failed to create filter graph");
        return nullptr;
    }

    // avfilter_graph_config() documents ">= 0 on success, negative AVERROR on failure".
    if (avfilter_graph_config(graph, nullptr) < 0) {
        Log::TDebug(h, "Failed to config filter graph");
        avfilter_graph_free(&graph);
        return nullptr;
    }

    // avfilter_graph_dump() returns nullptr on allocation failure. Passing a null C string
    // to the formatter is not safe, so only log when a dump was produced.
    if (char *dump = avfilter_graph_dump(graph, nullptr)) {
        Log::TDebug(h, "Video filter graph:\n{}", dump);
        av_free(dump);
    }

    return graph;
}

/// Builds and configures the audio filter graph:
///
///     abuffer ("src") -> aformat ("format") -> abuffersink ("sink")
///
/// The source is declared as AV_SAMPLE_FMT_FLT at the decoder's sample rate and channel
/// layout. aformat converts to fmt.format / fmt.sample_rate / fmt.layout.
///
/// @return a configured graph owned by the caller (free with avfilter_graph_free()), or
///         nullptr if any filter could not be allocated, initialised, linked or configured.
AVFilterGraph *
CreateAudioFilterGraph(AVCodecContext *decoderContext, IVideoOutputConsumer::AudioFormat fmt)
{
    AVFilterGraph *graph = avfilter_graph_alloc();
    if (graph == nullptr) {
        Log::TDebug(h, "Failed to allocate filter graph");
        return nullptr;
    }

    // Logs `reason`, frees the partially built graph (including every filter allocated in
    // it) and returns nullptr, so failure paths can simply `return fail(...)`.
    auto fail = [&graph](const char *reason) -> AVFilterGraph * {
        Log::TDebug(h, "{}", reason);
        avfilter_graph_free(&graph);
        return nullptr;
    };

    // Looks up `filterName` and allocates an instance named `instanceName` in the graph.
    // Returns nullptr if the filter is missing or allocation fails. This is checked at
    // runtime because assert() disappears in release builds.
    auto allocFilter = [graph](const char *filterName,
                               const char *instanceName) -> AVFilterContext * {
        const AVFilter *filter = avfilter_get_by_name(filterName);
        if (filter == nullptr) {
            return nullptr;
        }
        return avfilter_graph_alloc_filter(graph, filter, instanceName);
    };

    std::vector<AVFilterContext *> filters;

    {
        AVFilterContext *ctx = allocFilter("abuffer", "src");
        if (ctx == nullptr) {
            return fail("Failed to allocate audio buffer source filter");
        }

        // The decoder may leave channel_layout at 0 (unknown). In that case the layout string
        // would be "0 channels", which abuffer cannot parse, so fall back to the default
        // layout for the decoder's channel count.
        auto srcLayout = decoderContext->channel_layout;
        if (srcLayout == 0) {
            srcLayout = static_cast<decltype(srcLayout)>(
                av_get_default_channel_layout(decoderContext->channels));
        }

        char ch_layout[64];
        av_get_channel_layout_string(ch_layout, sizeof(ch_layout), 0, srcLayout);

        // AV_SAMPLE_FMT_FLT replaces dec->sample_fmt to avoid one conversion in SoundTouchWrapper
        auto options_str =
            fS("sample_fmt={}:sample_rate={}:channel_layout={}",
               av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT), decoderContext->sample_rate, ch_layout);

        if (avfilter_init_str(ctx, options_str.c_str()) < 0) {
            return fail("Failed to initialize audio buffer source filter");
        }

        filters.push_back(ctx);
    }

    {
        AVFilterContext *ctx = allocFilter("aformat", "format");
        if (ctx == nullptr) {
            return fail("Failed to allocate aformat filter");
        }

        char ch_layout[64];
        av_get_channel_layout_string(ch_layout, sizeof(ch_layout), 0, fmt.layout);

        auto options_str = fS("sample_fmts={}:sample_rates={}:channel_layouts={}",
                              av_get_sample_fmt_name(fmt.format), fmt.sample_rate, ch_layout);

        if (avfilter_init_str(ctx, options_str.c_str()) < 0) {
            return fail("Failed to initialize aformat filter");
        }

        filters.push_back(ctx);
    }

    {
        AVFilterContext *ctx = allocFilter("abuffersink", "sink");
        if (ctx == nullptr) {
            return fail("Failed to allocate audio buffer sink filter");
        }

        if (avfilter_init_str(ctx, nullptr) < 0) {
            return fail("Failed to initialize audio buffer sink filter");
        }

        filters.push_back(ctx);
    }

    if (!LinkFilterGraph(filters)) {
        return fail("Failed to link filter graph");
    }

    if (avfilter_graph_config(graph, nullptr) < 0) {
        return fail("Failed to config filter graph");
    }

    return graph;
}

/// Computes the absolute peak sample value of each channel of a planar-float audio frame.
///
/// @return one peak per channel, in channel order, or an empty vector (after logging) if the
///         frame is not AV_SAMPLE_FMT_FLTP.
std::vector<float>
GetPeaksForFrame(AVFrame *frame)
{
    if (frame->format != AV_SAMPLE_FMT_FLTP) {
        Log::TError(h, "GetPeakForFrame called with wrong format");
        return {};
    }

    // auto rms = [](const float *data, int count) {
    //     float sq = 0;

    //     for (int i = 0; i < count; ++i) {
    //         float f = data[i];
    //         sq += f * f;
    //     }

    //     return std::sqrt(sq / count);
    // };

    auto peak = [](const float *data, int count) {
        float peak = 0;

        for (int i = 0; i < count; ++i) {
            float f = std::abs(data[i]);
            peak    = std::max(f, peak);
        }

        return peak;
    };

    std::vector<float> peaks;
    peaks.reserve(static_cast<size_t>(std::max(frame->channels, 0)));

    // Planar audio has one plane per channel. frame->data only has AV_NUM_DATA_POINTERS
    // slots, so channels beyond that are reachable only through extended_data (which
    // FFmpeg sets for every valid frame).
    for (int i = 0; i < frame->channels; ++i) {
        const auto *plane = reinterpret_cast<const float *>(frame->extended_data[i]);
        peaks.push_back(peak(plane, frame->nb_samples));
    }

    return peaks;
}

/// Wraps `sample` in an AVPacket (timestamps in the decoder's time base) and submits it to
/// `context`. The packet is freed before returning in every path.
///
/// @return true if the decoder accepted the packet; false (after logging) otherwise.
bool
SendSampleToDecoder(AVCodecContext *context, PlayerCore::WrappedMediaSample &sample)
{
    AVPacket *packet = AVPacketFromMediaSample(sample, context->time_base);
    if (packet == nullptr) {
        Log::TError(h, "AVPacketFromMediaSample did not return a packet");
        return false;
    }

    const int sendResult = avcodec_send_packet(context, packet);
    av_packet_free(&packet);

    if (sendResult >= 0) {
        return true;
    }

    char message[64];
    av_strerror(sendResult, message, sizeof(message));
    Log::TError(h, "Failed to send package: {}", message);
    return false;
}

/// Drains up to `maxCount` decoded frames from `context`. Each frame is passed to `cb`
/// together with its PTS converted from the decoder time base to flicks.
///
/// The same AVFrame is reused for every receive (avcodec_receive_frame() unreferences it
/// first) and is freed on return, so `cb` must take its own reference if it keeps the data.
/// `cb` returns false to stop early. A frame rejected that way is not counted.
///
/// @return the number of frames accepted by `cb`; 0 for a null context or allocation failure.
int
GetFramesFromDecoder(AVCodecContext *context, int maxCount,
                     const std::function<bool(AVFrame *, flicks)> &cb)
{
    if (!context) {
        return 0;
    }

    AVFrame *frame = av_frame_alloc();
    if (frame == nullptr) {
        Log::TError(h, "Failed to allocate frame");
        return 0;
    }

    int count = 0;

    while (count < maxCount && avcodec_receive_frame(context, frame) == 0) {
        auto tb  = context->time_base;
        auto tc  = frame->pts;
        auto pts = FlicksFromAVTime(tc, tb);

        if (!cb(frame, pts)) {
            break;
        }

        count++;
    }

    av_frame_free(&frame);

    return count;
}

/// Pushes `frame` into the graph's "src" buffer source filter via av_buffersrc_add_frame().
///
/// @return false if the graph is null, has no "src" filter, or rejects the frame (the last
///         two cases are logged); true otherwise.
bool
SendFrameToFilterGraph(AVFilterGraph *filterGraph, AVFrame *frame)
{
    if (filterGraph == nullptr) {
        return false;
    }

    AVFilterContext *source = avfilter_graph_get_filter(filterGraph, "src");
    const bool accepted     = source != nullptr && av_buffersrc_add_frame(source, frame) == 0;

    if (!accepted) {
        Log::TError(h, "Error submitting the frame to the filtergraph");
    }

    return accepted;
}

/// Pulls up to `maxCount` filtered frames from the graph's "sink" filter. Each frame is
/// passed to `cb` with a PTS in flicks, taken from the sink input link's current_pts and
/// time base.
///
/// Each frame is unreferenced right after `cb` returns. `cb` returns false to stop early.
/// The frame that stopped iteration is still counted.
///
/// @return the number of frames handed to `cb`; 0 for a null graph, a missing "sink" filter
///         or frame allocation failure.
int
GetFramesFromFilterGraph(AVFilterGraph *filterGraph, int maxCount,
                         const std::function<bool(AVFrame *, flicks)> &cb)
{
    if (!filterGraph) {
        return 0;
    }

    auto vbuffersink_ctx = avfilter_graph_get_filter(filterGraph, "sink");
    if (!vbuffersink_ctx) {
        return 0;
    }

    AVFrame *frame = av_frame_alloc();
    if (frame == nullptr) {
        Log::TError(h, "Failed to allocate frame");
        return 0;
    }

    int count = 0;

    while (count < maxCount && av_buffersink_get_frame(vbuffersink_ctx, frame) >= 0) {
        auto link = vbuffersink_ctx->inputs[0];
        auto tb   = link->time_base;
        auto pts  = FlicksFromAVTime(link->current_pts, tb);

        bool res = cb(frame, pts);

        // Release this frame's buffers before the sink refills `frame`.
        av_frame_unref(frame);

        count++;

        if (!res) {
            break;
        }
    }

    av_frame_free(&frame);

    return count;
}

/// AVCodecContext::get_format callback that picks VAAPI surfaces when the decoder offers
/// them. Otherwise it prints an error to stderr and returns AV_PIX_FMT_NONE.
static enum AVPixelFormat
get_hw_format(AVCodecContext * /*ctx*/, const enum AVPixelFormat *pix_fmts)
{
    // The candidate list is terminated by AV_PIX_FMT_NONE (-1).
    for (const enum AVPixelFormat *candidate = pix_fmts; *candidate != AV_PIX_FMT_NONE;
         ++candidate) {
        if (*candidate == AV_PIX_FMT_VAAPI) {
            return AV_PIX_FMT_VAAPI;
        }
    }

    fprintf(stderr, "Failed to get HW surface format.\n");
    return AV_PIX_FMT_NONE;
}

/// Creates a hardware device context of `type` and stores it in `ctx->hw_device_ctx`.
///
/// @return the av_hwdevice_ctx_create() result: >= 0 on success, or a negative AVERROR code
///         (after printing a message to stderr) on failure.
static int
hw_decoder_init(AVCodecContext *ctx, const enum AVHWDeviceType type)
{
    const int result = av_hwdevice_ctx_create(&ctx->hw_device_ctx, type, nullptr, nullptr, 0);

    if (result < 0) {
        fprintf(stderr, "Failed to create specified HW device.\n");
    }

    return result;
}

/// Opens an H.264 decoder configured from `cfg` (coded size and AVCC extradata).
///
/// When Config::Get().enableHardwareDecoder is set and `forceSwDecode` is false, a VAAPI
/// device is attached and get_hw_format() is installed. If the device cannot be created,
/// this is logged and the decoder is opened without it.
///
/// @param h             Log tag for messages (shadows the file-level tag).
/// @param cfg           Stream configuration providing width, height and AVCC data.
/// @param forceSwDecode Skip hardware decoding even if it is enabled in the config.
/// @return an opened codec context (release with ReleaseDecoder()), or nullptr on failure.
AVCodecContext *
CreateVideoDecoder(std::string_view h, VideoConfiguration &cfg, bool forceSwDecode)
{
    // const: newer FFmpeg returns const AVCodec *, and older versions convert implicitly.
    const AVCodec *video_dec = avcodec_find_decoder(AV_CODEC_ID_H264);
    if (!video_dec) {
        Log::TError(h, "Codec not found");
        return nullptr;
    }

    AVCodecContext *ctx = avcodec_alloc_context3(video_dec);
    if (!ctx) {
        Log::TError(h, "Could not allocate video codec context");
        return nullptr;
    }

    AVLogger::Register(ctx, std::string(h));

    // Capability bits and context flags are different bit sets. AV_CODEC_CAP_TRUNCATED is
    // bit 3, which in ctx->flags means AV_CODEC_FLAG_OUTPUT_CORRUPT, so the matching
    // AV_CODEC_FLAG_TRUNCATED must be used here.
    if (video_dec->capabilities & AV_CODEC_CAP_TRUNCATED) {
        ctx->flags |= AV_CODEC_FLAG_TRUNCATED;
    }

    // AV_CODEC_FLAG2_CHUNKS is a flags2 option with no corresponding capability bit, so it
    // cannot be derived from video_dec->capabilities. It is intentionally not set.

    ctx->width  = cfg.width;
    ctx->height = cfg.height;

    // Copy the AVCC data into the context before opening the codec. FFmpeg expects
    // AV_INPUT_BUFFER_PADDING_SIZE extra bytes after the payload. av_mallocz zeroes them, so
    // bitstream readers that read past the end never see uninitialised memory.
    const size_t avccSize = cfg.avcc.size();
    ctx->extradata = static_cast<uint8_t *>(av_mallocz(avccSize + AV_INPUT_BUFFER_PADDING_SIZE));
    if (!ctx->extradata) {
        Log::TError(h, "Could not allocate codec extradata");
        ReleaseDecoder(&ctx);
        return nullptr;
    }
    ctx->extradata_size = static_cast<int>(avccSize);
    std::copy(cfg.avcc.begin(), cfg.avcc.end(), ctx->extradata);

    if (Config::Get().enableHardwareDecoder && !forceSwDecode) {
        if (hw_decoder_init(ctx, AV_HWDEVICE_TYPE_VAAPI) >= 0) {
            Log::TDebug(h, "HW Decoder init successful");
            ctx->get_format = get_hw_format;
        } else {
            Log::TError(h, "Could not enable HW Decoder");
        }
    }

    if (avcodec_open2(ctx, video_dec, nullptr) < 0) {
        Log::TError(h, "Could not open codec");
        ReleaseDecoder(&ctx);
        return nullptr;
    }

    Log::TDebug(h, "Input - width: {}, height: {}", cfg.width, cfg.height);

    return ctx;
}

/// Opens an AAC decoder using the sample rate, frame size and channel count from `cfg`.
///
/// @param h   Log tag for messages (shadows the file-level tag).
/// @param cfg Audio configuration supplied by player-core.
/// @return an opened codec context (release with ReleaseDecoder()), or nullptr on failure.
AVCodecContext *
CreateAudioDecoder(std::string_view h, AudioConfiguration &cfg)
{
    const AVCodec *codec = avcodec_find_decoder(AV_CODEC_ID_AAC);
    if (codec == nullptr) {
        Log::TError(h, "Codec not found");
        return nullptr;
    }

    AVCodecContext *ctx = avcodec_alloc_context3(codec);
    if (ctx == nullptr) {
        Log::TError(h, "Could not allocate audio codec context");
        return nullptr;
    }

    AVLogger::Register(ctx, std::string(h));

    // Parameters come straight from player-core's audio configuration.
    // NOTE(review): FFmpeg defines frame_size as samples per channel; confirm that
    // max_sample_size is expressed in those units.
    ctx->sample_rate = cfg.sample_rate;
    ctx->frame_size  = cfg.max_sample_size;
    ctx->channels    = cfg.channel_count;

    const bool opened = avcodec_open2(ctx, codec, nullptr) >= 0;
    if (!opened) {
        Log::TError(h, "Could not open codec");
        ReleaseDecoder(&ctx);
        return nullptr;
    }

    return ctx;
}

/// Deregisters the codec context from AVLogger and frees it. avcodec_free_context() sets
/// `*decoderContext` to nullptr. A null pointer, or a pointer to a null context, is a no-op.
void
ReleaseDecoder(AVCodecContext **decoderContext)
{
    if (decoderContext == nullptr || *decoderContext == nullptr) {
        return;
    }

    AVLogger::Deregister(*decoderContext);
    avcodec_free_context(decoderContext);
}

}  // namespace Vape
