#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#include <vector>

// Decimates an interleaved 8-bit image by keeping every n-th pixel of every
// n-th row (the top-left sample of each n x n cell).
//
//   in  - source pixels, w * h pixels of c interleaved channels each
//   out - destination, receives (w / n) * (h / n) pixels of c channels
//   w,h - source dimensions in pixels
//   c   - number of interleaved channels per pixel
//   n   - decimation step; must be non-zero (it is used as a divisor)
//
// Rows/columns left over when w or h is not a multiple of n are dropped.
void subsample(const unsigned char* in, unsigned char* out, int w, int h, int c, int n) {
    const int out_w = w / n;
    const int out_h = h / n;

    // Output pixels are produced in raster order, so a running write cursor
    // visits exactly the same destination slots as explicit indexing would.
    unsigned char* dst = out;

    for (int row = 0; row < out_h; ++row) {
        for (int col = 0; col < out_w; ++col) {
            // First channel of the source pixel that represents this cell.
            const unsigned char* src = in + (row * n * w + col * n) * c;
            for (int ch = 0; ch < c; ++ch) {
                *dst++ = src[ch];
            }
        }
    }
}

// Enlarges an interleaved image by an integer factor using bilinear
// interpolation.
//
//   in  - source samples, w * h pixels of c interleaved channels each
//   out - destination, receives (w * n) * (h * n) pixels of c channels
//   w,h - source dimensions in pixels
//   c   - number of interleaved channels per pixel
//   n   - magnification factor
//
// Output pixel (ox, oy) samples the source at (ox / n, oy / n). Neighbours
// that would fall past the right or bottom edge are replaced by the edge
// pixel itself. The blended value is computed in float and converted to T
// (for integer T the fractional part is discarded by the conversion).
template <typename T>
void upsample(const T* in, T* out, int w, int h, int c, int n) {
    const int out_w = w * n;
    const int out_h = h * n;

    for (int oy = 0; oy < out_h; ++oy) {
        // Vertical source position: integer row plus fractional offset.
        const float sy = static_cast<float>(oy) / n;
        const int y0 = static_cast<int>(sy);
        const int y1 = (y0 + 1 < h) ? y0 + 1 : y0;
        const float fy = sy - y0;

        for (int ox = 0; ox < out_w; ++ox) {
            // Horizontal source position, same decomposition.
            const float sx = static_cast<float>(ox) / n;
            const int x0 = static_cast<int>(sx);
            const int x1 = (x0 + 1 < w) ? x0 + 1 : x0;
            const float fx = sx - x0;

            // The four surrounding source pixels and the destination pixel.
            const T* top_left = in + (y0 * w + x0) * c;
            const T* top_right = in + (y0 * w + x1) * c;
            const T* bottom_left = in + (y1 * w + x0) * c;
            const T* bottom_right = in + (y1 * w + x1) * c;
            T* dst = out + (oy * out_w + ox) * c;

            for (int ch = 0; ch < c; ++ch) {
                const float tl = top_left[ch];
                const float tr = top_right[ch];
                const float bl = bottom_left[ch];
                const float br = bottom_right[ch];

                const float blended = (1 - fx) * (1 - fy) * tl + fx * (1 - fy) * tr +
                                      (1 - fx) * fy * bl + fx * fy * br;
                dst[ch] = static_cast<T>(blended);
            }
        }
    }
}

// Zooms an interleaved 8-bit image by an integer factor, correcting a plain
// bilinear zoom with an estimate of the interpolation error.
//
// The estimate is obtained by shrinking the input by n, zooming it back with
// the same interpolator, and measuring how far that reconstruction is from the
// real input. The per-pixel error is then itself zoomed by n and added to the
// bilinear zoom of the input.
//
//   input    - source pixels, width * height pixels of `channels` bytes each
//   output   - caller-allocated, (width * n) * (height * n) * channels bytes
//   width    - source width in pixels
//   height   - source height in pixels
//   channels - interleaved channels per pixel
//   n        - zoom factor (>= 1)
//
// Alpha is passed through from the bilinear zoom without error correction.
// It is taken to be the last channel of 2- and 4-channel images (grey+alpha
// and RGBA, the layouts stb_image produces).
//
// The function returns without touching `output` when a pointer is null or
// any of width, height, channels, n is not positive. Working buffers are
// std::vector, so an allocation failure surfaces as std::bad_alloc instead of
// a null-pointer write.
void ediz(const unsigned char* input, unsigned char* output, int width, int height, int channels, int n) {
    if (!input || !output || width <= 0 || height <= 0 || channels <= 0 || n <= 0) {
        return;  // nothing sensible to compute; n == 0 would also divide by zero
    }

    const int down_w = width / n;
    const int down_h = height / n;
    if (down_w == 0 || down_h == 0) {
        // The image is smaller than one n x n cell along some axis, so there is
        // no downsampled image to reconstruct from and no error estimate.
        // Fall back to the uncorrected bilinear zoom.
        upsample(input, output, width, height, channels, n);
        return;
    }

    // The reconstruction only covers whole n x n cells; it is narrower and/or
    // shorter than the input when the dimensions are not multiples of n.
    const int rec_w = down_w * n;
    const int rec_h = down_h * n;

    // Byte/element counts in size_t so large images do not overflow int.
    const size_t chan = static_cast<size_t>(channels);
    const size_t size = static_cast<size_t>(width) * static_cast<size_t>(height) * chan;
    const size_t up_size = size * static_cast<size_t>(n) * static_cast<size_t>(n);

    std::vector<unsigned char> I_down(static_cast<size_t>(down_w) * static_cast<size_t>(down_h) * chan);
    std::vector<unsigned char> I_rec(static_cast<size_t>(rec_w) * static_cast<size_t>(rec_h) * chan);
    std::vector<float> error(size);
    std::vector<float> E_e(up_size);
    std::vector<unsigned char> I_in_zoom(up_size);

    // Step 1 & 2: Subsample and then upsample
    subsample(input, I_down.data(), width, height, channels, n);
    upsample(I_down.data(), I_rec.data(), down_w, down_h, channels, n);

    // Step 3: Calculate reconstruction error.
    // The reconstruction has its own row stride (rec_w), so it is addressed by
    // coordinates rather than by the input's flat index. Input pixels beyond
    // the reconstructed area are compared against its nearest edge pixel.
    for (int y = 0; y < height; y++) {
        const int ry = (y < rec_h) ? y : rec_h - 1;
        for (int x = 0; x < width; x++) {
            const int rx = (x < rec_w) ? x : rec_w - 1;
            const size_t src = (static_cast<size_t>(y) * static_cast<size_t>(width) + static_cast<size_t>(x)) * chan;
            const size_t rec = (static_cast<size_t>(ry) * static_cast<size_t>(rec_w) + static_cast<size_t>(rx)) * chan;
            for (size_t k = 0; k < chan; k++) {
                error[src + k] = static_cast<float>(input[src + k] - I_rec[rec + k]);
            }
        }
    }

    // Step 4 & 5: Upsample error and input
    upsample(error.data(), E_e.data(), width, height, channels, n);
    upsample(input, I_in_zoom.data(), width, height, channels, n);

    // Step 6: Add estimated error to zoomed input image
    const bool has_alpha = (channels == 2 || channels == 4);
    const size_t alpha_index = chan - 1;
    for (size_t i = 0; i < up_size; i++) {
        if (has_alpha && i % chan == alpha_index) {  // Alpha channel
            output[i] = I_in_zoom[i];
        } else {
            // Float sum converted to int, then clamped to the 8-bit range.
            const int value = static_cast<int>(I_in_zoom[i] + E_e[i]);
            output[i] = static_cast<unsigned char>(value > 255 ? 255 : value < 0 ? 0 : value);
        }
    }
}

// Command-line entry point.
//
//   argv[1] - path of the image to load
//   argv[2] - path of the PNG file to write
//   argv[3] - zoom factor, a positive decimal integer
//
// Loads the input with stb_image, zooms it with ediz() and writes the result
// as a PNG. Returns 0 on success and 1 on any usage, load, size, allocation or
// write error (each reported on stderr).
int main(int argc, char* argv[]) {
    if (argc != 4) {
        fprintf(stderr, "Usage: %s <input_image> <output_image> <zoom_factor>\n", argv[0]);
        return 1;
    }

    // Validate the zoom factor before doing any work: a zero factor would be
    // used as a divisor and a negative one would yield negative dimensions.
    char* end = NULL;
    const long parsed = strtol(argv[3], &end, 10);
    if (end == argv[3] || *end != '\0' || parsed < 1 || parsed > INT_MAX) {
        fprintf(stderr, "Invalid zoom factor '%s': expected a positive integer\n", argv[3]);
        return 1;
    }
    const int n = static_cast<int>(parsed);

    int width, height, channels;
    unsigned char* input = stbi_load(argv[1], &width, &height, &channels, 0);
    if (!input) {
        fprintf(stderr, "Error loading image %s\n", argv[1]);
        return 1;
    }

    // The writer takes width, height and row stride as int, so all three must
    // fit; the total byte count must additionally fit in size_t for malloc.
    // Division-based checks avoid overflowing while testing for overflow.
    bool size_ok = width > 0 && height > 0 && channels > 0 &&
                   n <= INT_MAX / channels / width && n <= INT_MAX / height;
    if (size_ok) {
        const size_t stride_bytes = static_cast<size_t>(width) * static_cast<size_t>(n) * static_cast<size_t>(channels);
        const size_t rows = static_cast<size_t>(height) * static_cast<size_t>(n);
        size_ok = rows <= static_cast<size_t>(-1) / stride_bytes;
    }
    if (!size_ok) {
        fprintf(stderr, "Cannot zoom %s by a factor of %d: image dimensions are invalid or the result is too large\n",
                argv[1], n);
        stbi_image_free(input);
        return 1;
    }

    const int new_width = width * n, new_height = height * n;
    const int stride = new_width * channels;
    const size_t out_bytes = static_cast<size_t>(stride) * static_cast<size_t>(new_height);

    unsigned char* output = static_cast<unsigned char*>(malloc(out_bytes));
    if (!output) {
        fprintf(stderr, "Out of memory allocating %dx%d output image\n", new_width, new_height);
        stbi_image_free(input);
        return 1;
    }

    ediz(input, output, width, height, channels, n);

    // Single exit path below so both buffers are released on failure as well.
    int status = 0;
    if (!stbi_write_png(argv[2], new_width, new_height, channels, output, stride)) {
        fprintf(stderr, "Error writing image %s\n", argv[2]);
        status = 1;
    } else {
        printf("Processed %dx%d image with %d channels, saved as %dx%d\n", 
               width, height, channels, new_width, new_height);
    }

    stbi_image_free(input);
    free(output);
    return status;
}