#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_session.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/generators/catch_generators_random.hpp>
#include <catch2/generators/catch_generators_adapters.hpp>
#include <algorithm>
#include <array>
#include <cmath>
#include <functional>
#include <limits>
#include <utility>

#include "../assets/shader/utils/glsl_adapter.h"
#include <glm/gtx/transform.hpp>
// glsl shader code to test
#include "../assets/shader/utils/glsl/random.glsl"
#include "../assets/shader/utils/glsl/ray_tracing_utils.glsl"
#include "../assets/shader/utils/glsl/bsdf/microfacet.glsl"

TEST_CASE("TEA Random Generator")
{
    // Catch2 re-enters the test case from the top for every SECTION, so each
    // section below starts from this same freshly initialised seed.
    uint32_t seed = tea_init(5, 19);
    constexpr double delta = 0.001;
    constexpr uint32_t drawCount = 1000000;

    SECTION("draw 1000000 random single precision numbers [0..1[") {
        // The mean of a uniform variable on [0, 1) is 0.5.
        double mean = 0;
        for (uint32_t n = 0; n != drawCount; ++n) {
            mean += static_cast<double>(tea_nextFloat(seed));
        }
        mean /= static_cast<double>(drawCount);

        REQUIRE(abs(mean - 0.5) < delta);
    }

    SECTION("draw 1000000 random single precision numbers [0..1[ (fast version)") {
        double mean = 0;
        for (uint32_t n = 0; n != drawCount; ++n) {
            mean += static_cast<double>(tea_nextFloatFast(seed));
        }
        mean /= static_cast<double>(drawCount);

        REQUIRE(abs(mean - 0.5) < delta);
    }
}

TEST_CASE("PCG32 Random Generator") {
    // Every SECTION re-runs this prologue, so both sections use the same seed state.
    u64vec2 seed = pcg32_init(5, 19);
    constexpr double delta = 0.001;
    constexpr uint32_t drawCount = 1000000;

    SECTION("draw 1000000 random single precision numbers [0..1[") {
        // Sample mean of U[0, 1) should be close to 0.5.
        double mean = 0;
        for (uint32_t n = 0; n != drawCount; ++n) {
            mean += static_cast<double>(pcg32_nextFloat(seed));
        }
        mean /= static_cast<double>(drawCount);

        REQUIRE(abs(mean - 0.5) < delta);
    }

    SECTION("draw 1000000 random double precision numbers [0..1[") {
        double mean = 0;
        for (uint32_t n = 0; n != drawCount; ++n) {
            mean += static_cast<double>(pcg32_nextDouble(seed));
        }
        mean /= static_cast<double>(drawCount);

        REQUIRE(abs(mean - 0.5) < delta);
    }
}

TEST_CASE("Random range") {
    // One random pair of 32-bit seed words (drawn from Catch2's generator RNG).
    const auto init_rands = GENERATE(take(1, chunk(2, random(0u, std::numeric_limits<uint32_t>::max()))));
    uint32_t seed = tea_init(init_rands[0], init_rands[1]);
    constexpr float delta = 0.05f;

    SECTION("draw 1000000 random number between [1..10]") {
        constexpr uint32_t rounds = 1000000;
        //const auto bucket_count = GENERATE(5, 10);
        constexpr uint32_t bucket_count = 10;
        constexpr float pdf = PDF_RANGE_UNIFORM(bucket_count);
        std::array<int, bucket_count> buckets{};

        for (uint32_t i = 0; i < rounds; i++) {
            // Samples are requested in [1, bucket_count]; shift to a 0-based bucket index.
            const int idx = sampleRangeUniform(1, bucket_count, tea_nextFloat(seed)) - 1;
            // Check BOTH bounds before indexing: a sample of 0 would otherwise
            // silently write to buckets[-1].
            const bool valid_index = (idx >= 0) && (idx < static_cast<int>(bucket_count));
            if (!valid_index) {
                INFO("sampled bucket index out of range: " << idx);
                REQUIRE(valid_index);
            }
            buckets[static_cast<std::size_t>(idx)]++;
        }
        // Every bucket's empirical frequency should match the uniform probability.
        for (const int bucket : buckets) {
            const float cpdf = static_cast<float>(bucket) / static_cast<float>(rounds);
            REQUIRE(abs(cpdf - pdf) < delta);
        }
    }
}

// https://en.wikipedia.org/wiki/Spherical_coordinate_system
// physics convention, orientation in +z direction, phi == 0 in +x direction
TEST_CASE("Spherical <> Cartesian Convertion Test") {
    constexpr float delta = 0.001f;
    const vec3 sphericalInitDir(1, 0, 0);   // r = 1 | theta = 0 | phi = 0
    const vec3 cartesianInitDir(0, 0, 1);   // x = 0 | y = 0 | z = 1

    // Component-wise |a - b| < delta. Used for every comparison (including the
    // "simple" checks) because both conversions go through trig functions, so
    // exact float equality is not a reliable expectation.
    const auto nearlyEqual = [&](const vec3& a, const vec3& b) -> bool {
        return all(lessThan(abs(a - b), vec3(delta)));
    };

    SECTION("Simple Check : Spherical to Cartesian") {
        const vec3 cartesianTransformedDir = sphericalToCartesian(sphericalInitDir);
        REQUIRE(nearlyEqual(cartesianInitDir, cartesianTransformedDir));
    }

    SECTION("Simple Check : Cartesian to Spherical") {
        const vec3 sphericalTransformedDir = cartesianToSpherical(cartesianInitDir);
        REQUIRE(nearlyEqual(sphericalInitDir, sphericalTransformedDir));
    }

    SECTION("Spherical to Cartesian") {
        constexpr uint32_t steps = 10;
        for (uint32_t stepTheta = 0; stepTheta < steps; stepTheta++) {
            for (uint32_t stepPhi = 0; stepPhi < steps; stepPhi++) {
                // theta sweeps [0, pi/2], phi sweeps [0, 2pi], endpoints included.
                const float angleTheta = M_PI_DIV_2 * (static_cast<float>(stepTheta) / static_cast<float>(steps - 1));
                const float anglePhi = M_2PI * (static_cast<float>(stepPhi) / static_cast<float>(steps - 1));

                // Reference direction: tilt +z towards +x by theta (rotation about +y),
                // then rotate by phi about +z.
                const glm::mat3 rmatTheta = glm::rotate(angleTheta, glm::vec3(0.0f, 1.0f, 0.0f));
                const glm::mat3 rmatPhi = glm::rotate(anglePhi, glm::vec3(0.0f, 0.0f, 1.0f));
                const vec3 transformedDir = rmatPhi * rmatTheta * cartesianInitDir;

                const vec3 cartesianTransformedDir = sphericalToCartesian(vec3(1.0f, angleTheta, anglePhi));

                REQUIRE(nearlyEqual(cartesianTransformedDir, transformedDir));
            }
        }
    }

    SECTION("Cartesian to Spherical") {
        constexpr uint32_t steps = 10;
        for (uint32_t stepTheta = 0; stepTheta < steps; stepTheta++) {
            for (uint32_t stepPhi = 0; stepPhi < steps; stepPhi++) {
                // theta sweeps [0, pi/2]; phi only sweeps [0, pi] here (see wrap handling below).
                const float angleTheta = M_PI_DIV_2 * (static_cast<float>(stepTheta) / static_cast<float>(steps - 1));
                const float anglePhi = M_PI * (static_cast<float>(stepPhi) / static_cast<float>(steps - 1));

                const glm::mat3 rmatTheta = glm::rotate(angleTheta, glm::vec3(0.0f, 1.0f, 0.0f));
                const glm::mat3 rmatPhi = glm::rotate(anglePhi, glm::vec3(0.0f, 0.0f, 1.0f));
                const vec3 transformedDir = rmatPhi * rmatTheta * cartesianInitDir;

                vec3 sphericalTransformedDir = cartesianToSpherical(transformedDir);
                // Wrap handling: at theta == 0 the direction is the pole and phi is
                // undefined, so the expected phi is substituted; at phi == pi the result
                // can come back with either sign, so only its magnitude is compared.
                if (stepTheta == 0) sphericalTransformedDir.z = anglePhi;
                if (stepPhi == (steps - 1)) sphericalTransformedDir.z = abs(sphericalTransformedDir.z);

                REQUIRE(nearlyEqual(sphericalTransformedDir, vec3(1, angleTheta, anglePhi)));
            }
        }
    }
}

TEST_CASE("Weak White Furnace Test") {
    constexpr bool heitz = false;
    // Quadrature grid over the full sphere of outgoing directions. These step counts
    // give ~0.05 rad spacing and tile [0, pi] x [0, 2pi] exactly. Integer counters
    // avoid the rounding-dependent trip counts of a float-accumulated loop bound,
    // and the ~0.017 rad overshoot past 2pi that 126 steps of 0.05 would produce.
    constexpr uint32_t thetaSteps = 63;
    constexpr uint32_t phiSteps = 126;
    const float dtheta = static_cast<float>(M_PI) / static_cast<float>(thetaSteps);
    const float dphi = static_cast<float>(M_2PI) / static_cast<float>(phiSteps);

    float theta_i = GENERATE(0.2f, 0.5f, 0.7f) * M_PI_DIV_2;
    float alpha = GENERATE(0.2f, 0.5f, 0.7f);
    float alpha2 = alpha * alpha;

    // n == [0, 0, 1]; wi lies in the xz-plane at angle theta_i from n
    const vec3 wi(sin(theta_i), 0, cos(theta_i));
    const float nDotWi = wi[2];

    // Masking term for the fixed incident direction (the heitz branch spells out
    // Smith's G1 for GGX explicitly).
    float g1_ggx = 0.0f;
    if constexpr (!heitz) g1_ggx = G1_GGX(1, nDotWi, alpha);
    else {
        const float theta_o = acos(nDotWi);
        const float a = 1.0f / (alpha * tan(theta_o));
        const float lambda = (-1.0f + sqrt(1.0f + 1.0f / (a * a))) / 2.0f;
        g1_ggx = 1.0f / (1.0f + lambda);
    }

    // Weak white furnace: the integral over all wo of D(m) * G1(wi) / (4 |n.wi|),
    // with m the half vector of wi and wo, is expected to be 1. wo is parametrised
    // with +y as the pole; because the whole sphere is integrated, the choice of
    // pole axis does not change the result.
    float integral = 0.0f;
    for (uint32_t it = 0; it < thetaSteps; it++) {
        const float theta = (static_cast<float>(it) + 0.5f) * dtheta;   // midpoint rule
        for (uint32_t ip = 0; ip < phiSteps; ip++) {
            const float phi = (static_cast<float>(ip) + 0.5f) * dphi;
            const vec3 wo(cos(phi) * sin(theta), cos(theta), sin(phi) * sin(theta));
            const vec3 m = normalize(wi + wo);
            const float nDotM = m[2];   // dot(n, m)

            // Microfacets facing away from n contribute nothing.
            float d_ggx = 0.0f;
            if (nDotM > 0) {
                if constexpr (!heitz) d_ggx = D_GTR2(nDotM, alpha);
                else {
                    const float theta_h = acos(nDotM);
                    d_ggx = 1.0f / pow(1.0f + pow(tan(theta_h) / alpha, 2.0f), 2.0f);
                    d_ggx = d_ggx / (M_PI * alpha2 * pow(nDotM, 4.0f));
                }
            }

            // sin(theta) is the solid-angle Jacobian of the (theta, phi) parametrisation.
            integral += sin(theta) * (d_ggx * g1_ggx) / (4.0f * abs(nDotWi));
        }
    }
    integral = integral * dphi * dtheta;
    REQUIRE(abs(integral - 1.0f) < 0.01f);
}

namespace {
    // Fixed-size col x row table of T, accessed as m[column][row].
    // Storage is value-initialised, so arithmetic T starts out as all zeros even
    // when the matrix is default-initialised.
    template<uint32_t col, uint32_t row, typename T>
    class matrix {
        std::array<std::array<T, row>, col> m{};
    public:
        matrix() = default;
        // Bounds-checked column access (std::array::at throws std::out_of_range).
        std::array<T, row>& operator[](uint32_t x) {
            return m.at(x);
        }
        const std::array<T, row>& operator[](uint32_t x) const {
            return m.at(x);
        }
    };

    // Maps a hemisphere direction to an equal-solid-angle cell index (i, j):
    //   theta in [0, pi/2], split into bucketCount equal intervals of cos(theta)
    //     (row i covers cos(theta) in [i/N, (i+1)/N]);
    //   phi in [0, 2pi], split into bucketCount equal intervals
    //     (column j covers phi in [j/N, (j+1)/N] * 2pi).
    // Returns (-1, -1) when theta/phi lie outside those ranges (including NaN)
    // or bucketCount <= 0.
    std::pair<int, int> getHemisphereSolidAngleBucketIndex(const float theta, const float phi, const int bucketCount)
    {
        // Outer bounds identical to the cell edges: theta in [acos(1), acos(0)].
        const float theta_max = std::acos(0.0f);
        const bool in_domain = bucketCount > 0
            && theta >= 0.0f && theta <= theta_max
            && phi >= 0.0f && phi <= M_2PI;
        if (!in_domain) return {-1, -1};

        const auto fbuckets = static_cast<float>(bucketCount);
        // Direct binning instead of a linear search over all cells. Clamping places the
        // closed upper edges (cos(theta) == 1, phi == 2pi) in the last cell; the tiny
        // negative cos(acos(0.0f)) truncates towards zero into row 0.
        const int i = std::clamp(static_cast<int>(std::cos(theta) * fbuckets), 0, bucketCount - 1);
        const int j = std::clamp(static_cast<int>(phi / M_2PI * fbuckets), 0, bucketCount - 1);
        return {i, j};
    }

    // Trapezoid rule on [a, b] given the endpoint values f(a) and f(b).
    float trapezoidalIntegration(float a, float b, float f_a, float f_b)
    { return (b - a) * 0.5f * (f_a + f_b); }

    // Composite 2-D trapezoid rule over [x0, x1] x [y0, y1] on a samples x samples grid.
    // Each cell contributes its area times the mean of f at its four corners.
    // Returns 0 when samples == 0.
    float trapezoidalIntegration2D(const std::function<float(float, float)>& f, float x0, float y0, float x1, float y1, uint32_t samples)
    {
        const float dx = (x1 - x0) / static_cast<float>(samples);
        const float dy = (y1 - y0) / static_cast<float>(samples);
        float integral = 0.0f;
        for (uint32_t i = 0; i < samples; i++) {
            const float f_i = static_cast<float>(i);
            const float x_low = x0 + f_i * dx;
            const float x_high = x0 + (f_i + 1.0f) * dx;
            for (uint32_t j = 0; j < samples; j++) {
                const float f_j = static_cast<float>(j);
                const float y_low = y0 + f_j * dy;
                const float y_high = y0 + (f_j + 1.0f) * dy;

                // Integrate along x on both cell edges, then along y between them.
                const float integral_x_low = trapezoidalIntegration(x_low, x_high, f(x_low, y_low), f(x_high, y_low));
                const float integral_x_high = trapezoidalIntegration(x_low, x_high, f(x_low, y_high), f(x_high, y_high));
                integral += trapezoidalIntegration(y_low, y_high, integral_x_low, integral_x_high);
            }
        }
        return integral;
    }

    // Pearson's chi-square statistic: sum over cells of (observed - expected)^2 / expected.
    // Cells where both counts are zero are skipped. Every other cell must have an
    // expected frequency of at least 5; merging small cells is not implemented, so a
    // smaller cell fails the test.
    template<uint32_t col, uint32_t row, typename T>
    float chiSquareTest(const matrix<col, row, T>& observation, const matrix<col, row, T>& expected)
    {
        T chi = T(0);
        for (uint32_t i = 0; i < col; i++) {
            for (uint32_t j = 0; j < row; j++) {
                const T observed_ij = observation[i][j];
                const T expected_ij = expected[i][j];
                if (observed_ij == T(0) && expected_ij == T(0)) continue;   // empty cell

                INFO("cell frequency less than 5 not handled");
                REQUIRE(expected_ij >= 5);

                const T diff = observed_ij - expected_ij;
                chi += diff * diff / expected_ij;
            }
        }
        return static_cast<float>(chi);
    }
}

TEST_CASE("Microfacet Important Sampling Test") {
    // One random pair of 64-bit seed words (drawn from Catch2's generator RNG).
    const auto init_rands = GENERATE(take(1, chunk(2, random(0ull, std::numeric_limits<uint64_t>::max()))));
    u64vec2 seed = pcg32_init(init_rands[0], init_rands[1]);

    constexpr float eps = 0.05f;
    constexpr int samples = 1000000;
    constexpr int bucketCountAxis = 10;
    constexpr int bucketCountTotal = bucketCountAxis * bucketCountAxis;
    // since the integral of each bucket is 2pi/bucketCountTotal we multiply by bucketCountTotal to get 2pi for each bucket
    constexpr float weight = static_cast<float>(bucketCountTotal) / static_cast<float>(samples);

    const float alpha = 0.5f;
    // Density of the sampled directions; only cos(theta) is used, phi is ignored.
    auto pdf = [&](float cosTheta, float /*phi*/) -> float {
        return D_GGX_PDF(cosTheta, alpha);
        //return PDF_HEMISPHERE_UNIFORM(1);
    };
    // Maps two uniform numbers to a direction (theta, phi).
    auto sampler = [&](const vec2 random) -> vec2 {
        return D_GGX_Importance_Sample(random, alpha);
        //return cartesianToSpherical(sampleUnitHemisphereUniform(random)).yz + vec2(0, M_PI);
    };

    // Monte Carlo estimate of the hemisphere's solid angle (2pi), accumulated per bucket.
    // NOTE(review): if D_GGX_PDF tends to 0 towards the horizon, 1/pdf is unbounded and
    // this estimator can be heavy-tailed; confirm the tolerance is robust across seeds.
    matrix<bucketCountAxis, bucketCountAxis, float> buckets = {};
    for (int s = 0; s < samples; s++) {
        const vec2 sc = sampler(pcg32_nextFloat2(seed));   // (theta, phi)
        const auto [i, j] = getHemisphereSolidAngleBucketIndex(sc.x, sc.y, bucketCountAxis);

        const bool valid_index = (i >= 0) && (i < bucketCountAxis) && (j >= 0) && (j < bucketCountAxis);
        if (!valid_index) {
            INFO("sample outside the hemisphere: theta = " << sc.x << ", phi = " << sc.y);
            REQUIRE(valid_index);
        }
        buckets[i][j] += weight / pdf(cos(sc.x), 0);
    }

    float final_avg = 0.0f;
    for (int i = 0; i < bucketCountAxis; i++) {
        for (int j = 0; j < bucketCountAxis; j++) {
            final_avg += buckets[i][j] / static_cast<float>(bucketCountTotal);
        }
    }

    const float error = final_avg - M_2PI;
    INFO("average : " << final_avg << " (error: " << error << ")");
    REQUIRE(abs(error) < eps);
}

TEST_CASE("Microfacet Important Sampling Chi2 Test") {
    // One random pair of 64-bit seed words (drawn from Catch2's generator RNG).
    const auto init_rands = GENERATE(take(1, chunk(2, random(0ull, std::numeric_limits<uint64_t>::max()))));
    u64vec2 seed = pcg32_init(init_rands[0], init_rands[1]);

    constexpr int samples = 1000000;
    constexpr int bucketCountAxis = 10;
    // Pearson's goodness-of-fit test against a fully specified distribution over
    // k cells has k - 1 degrees of freedom (here 10 * 10 - 1 = 99).
    constexpr int degreesOfFreedom = bucketCountAxis * bucketCountAxis - 1;

    const float alpha = 0.3f;
    // Density the sampler is checked against; only cos(theta) is used, phi is ignored.
    auto integrand = [&](float cosTheta, float /*phi*/) -> float {
        return D_GGX_PDF(cosTheta, alpha);
        //return PDF_HEMISPHERE_UNIFORM(1);
    };
    // Maps two uniform numbers to a direction (theta, phi).
    auto sampler = [&](const vec2 random) -> vec2 {
        return D_GGX_Importance_Sample(random, alpha);
        //return cartesianToSpherical(sampleUnitHemisphereUniform(random)).yz + vec2(0, M_PI);
    };

    // Cell probabilities: integrate the density over each cell of the
    // (cos theta, phi) grid that the sample binning below uses.
    float pdf_sum = 0.0f;
    matrix<bucketCountAxis, bucketCountAxis, float> expected = {};
    const auto f_buckets = static_cast<float>(bucketCountAxis);
    for (int i = 0; i < bucketCountAxis; i++) {
        const float cos_theta_low = static_cast<float>(i) / f_buckets;
        const float cos_theta_high = static_cast<float>(i + 1) / f_buckets;
        for (int j = 0; j < bucketCountAxis; j++) {
            const float phi_low = (static_cast<float>(j) / f_buckets) * M_2PI;
            const float phi_high = (static_cast<float>(j + 1) / f_buckets) * M_2PI;

            //double integral = adaptiveSimpson2D(integrand, cos_theta_low, phi_low, cos_theta_high, phi_high);
            const float integral = trapezoidalIntegration2D(integrand, cos_theta_low, phi_low, cos_theta_high, phi_high, 20);
            expected[i][j] = integral;
            pdf_sum += integral;
        }
    }
    // The density itself must integrate to ~1 over the hemisphere.
    REQUIRE(abs(1.0f - pdf_sum) < 0.01f);

    // Turn probabilities into expected counts that sum to exactly `samples`. A
    // residual relative error of 1% in the total would otherwise add on the order
    // of 0.01^2 * samples = 100 to the statistic and swamp the test.
    const float count_scale = static_cast<float>(samples) / pdf_sum;
    for (int i = 0; i < bucketCountAxis; i++) {
        for (int j = 0; j < bucketCountAxis; j++) {
            expected[i][j] *= count_scale;
        }
    }

    matrix<bucketCountAxis, bucketCountAxis, float> observed = {};
    for (int s = 0; s < samples; s++) {
        const vec2 sc = sampler(pcg32_nextFloat2(seed));   // (theta, phi)
        const auto [i, j] = getHemisphereSolidAngleBucketIndex(sc.x, sc.y, bucketCountAxis);

        const bool valid_index = (i >= 0) && (i < bucketCountAxis) && (j >= 0) && (j < bucketCountAxis);
        if (!valid_index) {
            INFO("sample outside the hemisphere: theta = " << sc.x << ", phi = " << sc.y);
            REQUIRE(valid_index);
        }
        observed[i][j]++;
    }

    // Chi-square critical value at significance 0.05 for `degreesOfFreedom`, via the
    // Wilson-Hilferty approximation (within ~0.01 of tabulated values near 100 dof;
    // 99 dof gives ~123.22, 100 dof gives ~124.34).
    constexpr double z_95 = 1.6448536269514722;   // standard normal quantile at 0.95
    const double wh_h = 2.0 / (9.0 * static_cast<double>(degreesOfFreedom));
    const double wh_base = 1.0 - wh_h + z_95 * std::sqrt(wh_h);
    const float significance_level_value =
        static_cast<float>(static_cast<double>(degreesOfFreedom) * wh_base * wh_base * wh_base);

    const float chi_test_statistic = chiSquareTest<bucketCountAxis, bucketCountAxis, float>(observed, expected);
    INFO("Rejected null hypothesis (dof " << degreesOfFreedom << ", significance 0.05) : "
         << chi_test_statistic << " > " << significance_level_value);
    REQUIRE(chi_test_statistic < significance_level_value);
}

//int main(int argc, char *argv[])
//{
//	const int result = Catch::Session().run(argc, argv);
//    return result;
//}
