#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

#define CHANNELS 3

void processImage(unsigned char* input, unsigned char* output, int width, int height) {
    float* localStdDev = (float*)malloc(width * height * sizeof(float));
    float* error = (float*)malloc(width * height * sizeof(float));

    for (int c = 0; c < CHANNELS; c++) {
        // Calculate local standard deviation
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int halfWindow = 5;
                float sum = 0, sumSq = 0;
                int count = 0;
                for (int wy = -halfWindow; wy <= halfWindow; wy++) {
                    for (int wx = -halfWindow; wx <= halfWindow; wx++) {
                        int ny = y + wy, nx = x + wx;
                        if ((unsigned)ny < height && (unsigned)nx < width) {
                            float val = input[(ny*width + nx) * CHANNELS + c];
                            sum += val;
                            sumSq += val * val;
                            count++;
                        }
                    }
                }
                float mean = sum / count;
                localStdDev[y*width + x] = sqrtf((sumSq / count) - (mean * mean));
            }
        }

        // Calculate global standard deviation
        float sum = 0, sumSq = 0;
        for (int i = 0; i < width * height; i++) {
            sum += input[i * CHANNELS + c];
            sumSq += (float)input[i * CHANNELS + c] * input[i * CHANNELS + c];
        }
        float mean = sum / (width * height);
        float globalStdDev = sqrtf((sumSq / (width * height)) - (mean * mean));

        // Find min and max local standard deviation
        float minStdDev = localStdDev[0], maxStdDev = localStdDev[0];
        for (int i = 1; i < width * height; i++) {
            minStdDev = fminf(minStdDev, localStdDev[i]);
            maxStdDev = fmaxf(maxStdDev, localStdDev[i]);
        }

        memset(error, 0, width * height * sizeof(float));

        // Error diffusion
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                static const int kernel[3][3] = {{0, 1, 0}, {1, -4, 1}, {0, 1, 0}};
                int sum = 0;
                if((unsigned)(y-1) < height-2 && (unsigned)(x-1) < width-2) {
                    for (int ky = -1; ky <= 1; ky++) {
                        for (int kx = -1; kx <= 1; kx++) {
                            sum += kernel[ky+1][kx+1] * input[((y+ky)*width + (x+kx)) * CHANNELS + c];
                        }
                    }
                }
                float laplacian = fmaxf(-128, fminf(128, sum));
                
                float C = 5;
                float K = C / globalStdDev * ((maxStdDev - localStdDev[y*width + x]) / (maxStdDev - minStdDev)) + C;
                float threshold = K * laplacian + ((float)rand() / RAND_MAX * 2 - 1) * (255 * 0.1);
                
                float pixel = input[(y*width + x) * CHANNELS + c] + error[y*width + x];
                output[(y*width + x) * CHANNELS + c] = (pixel > threshold) ? 255 : 0;
                float quant_error = pixel - output[(y*width + x) * CHANNELS + c];
                if (x < width - 1) error[y*width + x + 1] += quant_error * 7 / 16.0;
                if (y < height - 1) {
                    if (x > 0) error[(y+1)*width + x - 1] += quant_error * 3 / 16.0;
                    error[(y+1)*width + x] += quant_error * 5 / 16.0;
                    if (x < width - 1) error[(y+1)*width + x + 1] += quant_error * 1 / 16.0;
                }
            }
        }
    }

    free(localStdDev);
    free(error);
}

/*
 * Command-line entry point: halftone <input_image> with processImage and save
 * the result as a PNG at <output_image>. Returns 0 on success and 1 on any
 * failure, with a diagnostic written to stderr.
 */
int main(int argc, char* argv[]) {
    if (argc != 3) {
        /* argv[0] may be NULL when argc == 0; never pass NULL to %s. */
        fprintf(stderr, "Usage: %s <input_image> <output_image>\n",
                (argc > 0 && argv[0] != NULL) ? argv[0] : "program");
        return 1;
    }

    int width, height, channels;
    /* Request exactly CHANNELS components per pixel; `channels` receives the file's own count. */
    unsigned char* input = stbi_load(argv[1], &width, &height, &channels, CHANNELS);
    if (!input) {
        const char* reason = stbi_failure_reason();
        fprintf(stderr, "Failed to read input image.\n");
        if (reason != NULL) {
            fprintf(stderr, "  reason: %s\n", reason);
        }
        return 1;
    }

    /* Compute the size in size_t: width * height * CHANNELS in int could overflow. */
    unsigned char* output = malloc((size_t)width * (size_t)height * CHANNELS);
    if (!output) {
        fprintf(stderr, "Failed to allocate output image.\n");
        stbi_image_free(input);
        return 1;
    }

    processImage(input, output, width, height);

    /* stbi_write_png returns 0 on failure; the stride is one packed row. */
    int written = stbi_write_png(argv[2], width, height, CHANNELS, output, width * CHANNELS);

    stbi_image_free(input);
    free(output);

    if (!written) {
        fprintf(stderr, "Failed to write output image.\n");
        return 1;
    }
    return 0;
}