#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <format>
#include <iostream>
#include <limits>
#include <memory>
#include <print>
#include <string>
#include <string_view>
#include <tuple>
#include <vector>

#include <opencv2/opencv.hpp>

// https://johncostella.com/magic/

// TODO:
// [ ] fix bugs
// [ ] replace kernel by a LUT
// [ ] try upscale
// [ ] verify kernel size vs original
// [ ] change from double to float
// [x] implement gamma
// [ ] implement post-sharpening (not needed for mks2021)
// [ ] make test images larger, test resampling on them rather than downloaded images

// Original Magic Kernel: an even, piecewise-quadratic kernel
// (0.75 - x^2 in the centre, 0.5 * (|x| - 1.5)^2 in the tails) that is
// non-zero only for |x| < 1.5.
static inline double magic_kernel(double x)
{
    const double r = (x < 0.0) ? -x : x;  // kernel is symmetric about 0
    return (r <= 0.5) ? 0.75 - (r * r)
         : (r < 1.5)  ? 0.5 * (r - 1.5) * (r - 1.5)
         : 0.0;
}

// Magic Kernel Sharp 2013: even, piecewise-quadratic kernel with a negative
// outer lobe; it evaluates to zero for |x| >= 2.5.
static inline double mks2013_kernel(double x)
{
    const double r = (x < 0.0) ? -x : x;  // kernel is symmetric about 0
    return (r <= 0.5) ? 17.0/16.0 - 1.75*r*r
         : (r <= 1.5) ? (1.0 - r)*(1.75 - r)
         : (r <= 2.5) ? -0.125*(r - 2.5)*(r - 2.5)
         : 0.0;
}

// Magic Kernel Sharp 2021: even, piecewise-quadratic kernel spanning five
// segments; it evaluates to zero for |x| >= 4.5.
static inline double mks2021_kernel(double x)
{
    const double r = (x < 0) ? -x : x;  // kernel is symmetric about 0
    return (r <= 0.5) ? 577.0/576.0 - 239.0/144.0*r*r
         : (r <= 1.5) ? 35.0/36.0*(r - 1.0)*(r - 239.0/140.0)
         : (r <= 2.5) ? 1.0/6.0*(r - 2.0)*(65.0/24.0 - r)
         : (r <= 3.5) ? 1.0/36.0*(r - 3.0)*(r - 15.0/4.0)
         : (r <= 4.5) ? -1.0/288.0*(r - 4.5)*(r - 4.5)
         : 0.0;
}


// Wrap-around indexing for edges
static inline int wrap_index(int i, int size)
{
    if (i < 0) 
        return -i;
    if (i >= size) 
        return 2 * (size - 1) - i;
    return i;
}


// Decode an sRGB-encoded image to linear light.
//
// src : image with any channel count. Every sample is multiplied by 1/255
//       first, so an 8-bit image maps to [0, 1].
// dst : receives a CV_32F image of the same size and channel count, holding
//       the decoded values. Decoding uses the piecewise sRGB curve: linear
//       (v / 12.92) up to 0.04045, ((v + 0.055) / 1.055)^2.4 above.
void srgbToLinear(const cv::Mat& src, cv::Mat& dst)
{
    src.convertTo(dst, CV_32F, 1.0/255.0);

    // convertTo may write into a caller-supplied dst that is a non-continuous
    // view (e.g. an ROI). In that case walk it row by row. Otherwise treat
    // the whole buffer as one packed span.
    const bool packed = dst.isContinuous();
    const int spans = packed ? 1 : dst.rows;
    const size_t span_len = packed
        ? dst.total() * dst.channels()
        : static_cast<size_t>(dst.cols) * dst.channels();

    for (int s = 0; s < spans; s++) {
        float* ptr = dst.ptr<float>(s);
        for (size_t i = 0; i < span_len; i++) {
            const float val = ptr[i];
            ptr[i] = (val <= 0.04045f)
                ? val / 12.92f
                : std::pow((val + 0.055f) / 1.055f, 2.4f);
        }
    }
}


// Encode a linear-light image as 8-bit sRGB.
// src is read through ptr<float>, so it is expected to be CV_32F. Each sample
// is mapped with the sRGB curve: v * 12.92 up to 0.0031308, and
// 1.055 * v^(1/2.4) - 0.055 above. The result is then scaled by 255 into an
// 8-bit dst; convertTo saturates anything outside [0, 255].
void linearToSRGB(const cv::Mat& src, cv::Mat& dst) {
    const auto encode = [](float v) -> float {
        return (v <= 0.0031308f)
            ? v * 12.92f
            : 1.055f * std::pow(v, 1.0f/2.4f) - 0.055f;
    };

    // Work on a private copy. A cloned Mat is a single packed buffer, so a
    // flat pointer walk over it is safe.
    cv::Mat work = src.clone();
    float* const first = work.ptr<float>();
    float* const last = first + work.total() * work.channels();
    for (float* p = first; p != last; ++p)
        *p = encode(*p);

    work.convertTo(dst, CV_8U, 255.0);
}


// Horizontal pass: src[rows][src_cols] -> dst[rows][dst_cols]
//
// Both buffers are row-major with `channels` interleaved floats per pixel.
// Template parameters:
//   support    - distance at/after which KernelFunc is treated as zero.
//   KernelFunc - kernel evaluated on a non-negative distance.
// When shrinking, distances are divided by the scale factor, which stretches
// the kernel over scale_x source pixels. When enlarging, the kernel is used
// as is. Taps that fall outside a row are mirrored back in via wrap_index().
// Each output sample is normalized by the sum of the weights actually used.
// Any channel count >= 1 is supported.
template<double support, auto KernelFunc>
void downsample_horizontally(
    const float *src, float *dst,
    int rows, int src_cols, int dst_cols, int channels
)
{
    if (rows <= 0 || src_cols <= 0 || dst_cols <= 0 || channels <= 0)
        return;

    const double fx_ratio = static_cast<double>(src_cols) / static_cast<double>(dst_cols);
    const double scale_x = (fx_ratio > 1.0) ? fx_ratio : 1.0;
    // Taps are offsets from floor(src_x_f). Every tap closer than
    // support * scale_x lies within ceil(support * scale_x) of it, whether we
    // shrink or enlarge. A fixed radius of 3 would truncate wide kernels such
    // as mks2021 (support 4.5).
    const int radius = static_cast<int>(std::ceil(support * scale_x));

    // size_t strides: int products overflow for very large images.
    const size_t src_stride = static_cast<size_t>(src_cols) * static_cast<size_t>(channels);
    const size_t dst_stride = static_cast<size_t>(dst_cols) * static_cast<size_t>(channels);

    // Per-channel accumulators, allocated once and reset for every pixel.
    std::vector<double> sum(static_cast<size_t>(channels));

    for (int row = 0; row < rows; row++) {
        const float *src_row = src + row * src_stride;
        float       *dst_row = dst + row * dst_stride;

        for (int dx = 0; dx < dst_cols; dx++) {
            // Map destination pixel center to source coordinate.
            const double src_x_f = ((dx + 0.5) * fx_ratio) - 0.5;
            const int ix = static_cast<int>(std::floor(src_x_f));
            const double frac_x = src_x_f - ix;

            std::fill(sum.begin(), sum.end(), 0.0);
            double wsum = 0.0;

            for (int kx = -radius; kx <= radius; kx++) {
                const double dist = std::fabs(frac_x - kx) / scale_x;
                if (dist >= support)
                    continue;

                const double w = KernelFunc(dist);
                const float *src_px =
                    src_row + static_cast<size_t>(wrap_index(ix + kx, src_cols)) * channels;
                for (int c = 0; c < channels; c++)
                    sum[c] += src_px[c] * w;
                wsum += w;
            }

            float *dst_px = dst_row + static_cast<size_t>(dx) * channels;
            if (wsum > 0.0) {
                const double inv_wsum = 1.0 / wsum;
                for (int c = 0; c < channels; c++)
                    dst_px[c] = static_cast<float>(sum[c] * inv_wsum);
            }
            else {
                std::fill(dst_px, dst_px + channels, 0.0f);
            }
        }
    }
}

// Vertical pass: src[src_rows][cols] -> dst[dst_rows][cols]
//
// Both buffers are row-major with `channels` interleaved floats per pixel.
// Template parameters:
//   support    - distance at/after which KernelFunc is treated as zero.
//   KernelFunc - kernel evaluated on a non-negative distance.
// When shrinking, distances are divided by the scale factor. Rows outside
// the image are mirrored back in via wrap_index(). Each output sample is
// normalized by the sum of the weights actually used. Any channel count
// >= 1 is supported.
//
// The kernel weights depend only on the output row. So each tap's weight is
// computed once per output row and applied to a whole source row at a time
// (contiguous access). Per sample, the additions happen in the same tap
// order as a per-pixel loop would use.
template<double support, auto KernelFunc>
void downsample_vertically(
    const float *src, float *dst,
    int src_rows, int dst_rows, int cols, int channels)
{
    if (src_rows <= 0 || dst_rows <= 0 || cols <= 0 || channels <= 0)
        return;

    const double fy_ratio = static_cast<double>(src_rows) / static_cast<double>(dst_rows);
    const double scale_y = (fy_ratio > 1.0) ? fy_ratio : 1.0;
    // ceil(support * scale_y) covers every non-zero tap whether we shrink or
    // enlarge. A fixed radius of 3 would truncate mks2021 (support 4.5).
    const int radius = static_cast<int>(std::ceil(support * scale_y));

    // Floats per row, in size_t: int products overflow for very large images.
    const size_t row_len = static_cast<size_t>(cols) * static_cast<size_t>(channels);

    // Accumulator for one whole output row, reused for every row.
    std::vector<double> acc(row_len);

    for (int dy = 0; dy < dst_rows; dy++) {
        // Map destination pixel center to source coordinate (same mapping as
        // the horizontal pass).
        const double src_y_f = ((dy + 0.5) * fy_ratio) - 0.5;
        const int iy = static_cast<int>(std::floor(src_y_f));
        const double frac_y = src_y_f - iy;

        std::fill(acc.begin(), acc.end(), 0.0);
        double wsum = 0.0;

        for (int ky = -radius; ky <= radius; ky++) {
            const double dist = std::fabs(frac_y - ky) / scale_y;
            if (dist >= support)
                continue;

            const double w = KernelFunc(dist);
            const float *src_row =
                src + static_cast<size_t>(wrap_index(iy + ky, src_rows)) * row_len;
            for (size_t i = 0; i < row_len; i++)
                acc[i] += src_row[i] * w;
            wsum += w;
        }

        float *dst_row = dst + static_cast<size_t>(dy) * row_len;
        if (wsum > 0.0) {
            const double inv_wsum = 1.0 / wsum;
            for (size_t i = 0; i < row_len; i++)
                dst_row[i] = static_cast<float>(acc[i] * inv_wsum);
        }
        else {
            std::fill(dst_row, dst_row + row_len, 0.0f);
        }
    }
}


using DownsampleFunc = float *(*)(const float*, int, int, int, int, int);


// Full separable downsample: a horizontal pass
// (src_rows x src_cols -> src_rows x dst_cols) into a scratch buffer, then a
// vertical pass (src_rows -> dst_rows).
//
// src      : src_rows x src_cols x channels floats, row-major, interleaved.
// channels : samples per pixel; must be >= 1.
// Returns a dst_rows x dst_cols x channels float buffer allocated with
// malloc(). The caller owns it and must release it with free(). Returns
// nullptr on invalid arguments, size overflow or allocation failure.
template<double support, auto KernelFunc>
float* downsample_impl(
    const float *src,
    int src_rows, int src_cols,
    int dst_rows, int dst_cols,
    int channels)
{
    if (!src || src_rows <= 0 || src_cols <= 0 || dst_rows <= 0 || dst_cols <= 0 || channels <= 0)
        return nullptr;

    // Byte size of a rows x cols x channels float buffer, computed in size_t.
    // (An int product overflows for large images, which is undefined
    // behaviour.) Yields 0 if the size does not fit in size_t.
    const auto buffer_bytes = [channels](int rows, int cols) -> size_t {
        const size_t max_elems = std::numeric_limits<size_t>::max() / sizeof(float);
        const size_t r = static_cast<size_t>(rows);
        const size_t c = static_cast<size_t>(cols);
        const size_t ch = static_cast<size_t>(channels);
        if (c > max_elems / r || ch > max_elems / (r * c))
            return 0;
        return r * c * ch * sizeof(float);
    };

    const size_t temp_bytes = buffer_bytes(src_rows, dst_cols);  // horizontal pass output
    const size_t dst_bytes = buffer_bytes(dst_rows, dst_cols);   // final output
    if (temp_bytes == 0 || dst_bytes == 0)
        return nullptr;

    // Own both buffers until success, so nothing leaks on an early return or
    // if a pass throws. The result is still malloc'd, so callers keep using
    // free().
    struct FreeDeleter {
        void operator()(float *p) const noexcept { std::free(p); }
    };
    std::unique_ptr<float, FreeDeleter> temp(static_cast<float*>(std::malloc(temp_bytes)));
    if (!temp)
        return nullptr;
    std::unique_ptr<float, FreeDeleter> dst(static_cast<float*>(std::malloc(dst_bytes)));
    if (!dst)
        return nullptr;

    // Pass 1: Horizontal resample
    downsample_horizontally<support, KernelFunc>(src, temp.get(), src_rows, src_cols, dst_cols, channels);

    // Pass 2: Vertical resample
    downsample_vertically<support, KernelFunc>(temp.get(), dst.get(), src_rows, dst_rows, dst_cols, channels);

    return dst.release();
}


// Resample an interleaved float image with the original Magic Kernel
// (support 1.5). Arguments: s = source buffer, sr/sc = source rows/cols,
// dr/dc = destination rows/cols, ch = channels. Forwards to downsample_impl,
// which returns a malloc'd buffer the caller must free(), or NULL on failure.
float* downsample_magic_kernel(
    const float* s,
    int sr, int sc, int dr, int dc,
    int ch)
{
    constexpr double kSupport = 1.5;
    float* const resampled = downsample_impl<kSupport, magic_kernel>(s, sr, sc, dr, dc, ch);
    return resampled;
}


// Resample an interleaved float image with Magic Kernel Sharp 2021
// (support 4.5). Arguments: s = source buffer, sr/sc = source rows/cols,
// dr/dc = destination rows/cols, ch = channels. Forwards to downsample_impl,
// which returns a malloc'd buffer the caller must free(), or NULL on failure.
float* downsample_mks2021(
    const float* s,
    int sr, int sc, int dr, int dc,
    int ch)
{
    constexpr double kSupport = 4.5;
    float* const resampled = downsample_impl<kSupport, mks2021_kernel>(s, sr, sc, dr, dc, ch);
    return resampled;
}

// Resample an interleaved float image with Magic Kernel Sharp 2013
// (support 2.5). Arguments: s = source buffer, sr/sc = source rows/cols,
// dr/dc = destination rows/cols, ch = channels. Forwards to downsample_impl,
// which returns a malloc'd buffer the caller must free(), or NULL on failure.
float* downsample_mks2013(
    const float* s,
    int sr, int sc, int dr, int dc,
    int ch)
{
    constexpr double kSupport = 2.5;
    float* const resampled = downsample_impl<kSupport, mks2013_kernel>(s, sr, sc, dr, dc, ch);
    return resampled;
}


// Example usage:
// float *input = ...; // src_rows × src_cols × channels, interleaved floats
// float *output = downsample_mks2021(input, 1080, 1920, 540, 960, 3);
// ... use output ...
// free(output);


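// Demo: resample a test image in linear light with each kernel and save the results.
// (Conversion to linear RGB, formerly a TODO, is done with srgbToLinear()/linearToSRGB().)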
int main()
{
    // Load the input as a single-channel 8-bit image.
    //cv::Mat img = cv::imread("/home/pavel/devel/mks/lena512.png", cv::IMREAD_GRAYSCALE);
    const std::string input_path = "Square-wedges-80grey.png";
    cv::Mat img = cv::imread(input_path, cv::IMREAD_GRAYSCALE);
    if (img.empty()) {
        std::println(std::cerr, "Error: Could not load {}", input_path);
        return -1;
    }

    std::println("Loaded image: {}x{} channels: {}", img.rows, img.cols, img.channels());

    // Decode sRGB to linear float values before filtering.
    cv::Mat img_float;
    //img.convertTo(img_float, CV_32F, 1.0/255.0);
    srgbToLinear(img, img_float);
    // The resamplers read the image as one packed row-major buffer.
    if (!img_float.isContinuous())
        img_float = img_float.clone();

    // Get image properties
    const int src_cols = img_float.cols;
    const int src_rows = img_float.rows;
    const int channels = img_float.channels();
    const int dst_cols = 512;
    const int dst_rows = 410;

    constexpr auto funcsTable = std::to_array<std::tuple<DownsampleFunc, std::string_view>> ({
        // ~/bin/resize_image -w 300 -g -m MAGIC_KERNEL lena512.png lena-no-gamma-no-sharp.png
        { downsample_magic_kernel, "mk" },
        { downsample_mks2013, "mks2013" },
        // ~/bin/resize_image -w 150 -g ../test_zoneplate.png ref-g.png
        { downsample_mks2021, "mks2021" }
    });

    // The resamplers return malloc'd buffers. Own each one so it is freed
    // even if an OpenCV call below throws.
    struct FreeDeleter {
        void operator()(float *p) const noexcept { std::free(p); }
    };

    for (const auto& [downsampleFunc, name] : funcsTable) {
        const std::unique_ptr<float, FreeDeleter> output(downsampleFunc(
            img_float.ptr<float>(),
            src_rows, src_cols,
            dst_rows, dst_cols,
            channels
        ));

        if (!output) {
            std::println(std::cerr, "Error: Downsample failed ({})", name);
            return -1;
        }

        // Wrap (without copying) the output buffer in a cv::Mat header.
        cv::Mat result(dst_rows, dst_cols, CV_32FC(channels), output.get());

        // Encode back to 8-bit sRGB [0, 255]
        cv::Mat result_8u;
        //result.convertTo(result_8u, CV_8U, 255.0);
        linearToSRGB(result, result_8u);

        // Save result
        const std::string filename = std::format("out_{}.png", name);
        if (cv::imwrite(filename, result_8u)) {
            std::println("Saved {} ({}x{})", filename, dst_cols, dst_rows);
        }
        else {
            std::println(std::cerr, "Error: Could not save {}", filename);
        }
    }
    return 0;
}

