/**
 * GroundGrid.wgsl
 * 
 * Renders an infinite grid on the ground plane (y=0) with axis highlighting.
 * Uses ray-plane intersection to compute grid lines with distance-based fading.
 * This shader is an adaptation of one from a personal project of mine.
 * It was originally an OpenGL shader, which in turn was based on another source.
 * Unfortunately, since I have had it for a few years now,
 * I can no longer find the website that this version was originally based on.
 */

// Camera data read by the shader stages in this file.
// NOTE(review): field order and types presumably mirror a host-side buffer
// layout — confirm before reordering or inserting fields.
struct Uniforms
{
    // Presumably the forward projection matrix (the inverse is stored below).
    projMatrix: mat4x4f,
    // Presumably the forward view matrix (the inverse is stored below).
    viewMatrix: mat4x4f,
    // Applied to clip-space points in fs_main to bring them into view space.
    invProjMatrix: mat4x4f,
    // Applied to view-space points in fs_main to bring them into world space.
    invViewMatrix: mat4x4f,
    // Camera position (presumably in world space — confirm against host code).
    camPos: vec3f,
    // Distance beyond which grid fragments are discarded; the grid's opacity
    // fades linearly to zero as this distance is approached.
    farPlane: f32,
}

// The single uniform buffer of this shader, bound at group 0 / binding 0.
@group(0) @binding(0) var<uniform> uniforms: Uniforms;

// Input of the vertex stage. Only the built-in vertex index is consumed;
// vs_main uses it to look up a corner in a 6-entry table and reads no other
// vertex attributes.
struct VertexInput
{
    @builtin(vertex_index) vertexIndex: u32,
}

// Data handed from the vertex stage to the fragment stage.
struct VertexOutput
{
    // Clip-space position of the vertex.
    @builtin(position) position: vec4f,
    // The same xy coordinates as `position`, passed through location 0 so
    // that fs_main receives them interpolated per fragment.
    @location(0) ndcCoord: vec2f,
}

// Vertex stage: emits one corner of a quad spanning [-1, 1] in x and y,
// selected purely by the vertex index (two triangles, six vertices).
// The xy coordinates are also forwarded so the fragment stage can rebuild
// a view ray for each pixel.
@vertex
fn vs_main(input: VertexInput) -> VertexOutput
{
    // Triangle 1: bottom-left, bottom-right, top-left.
    // Triangle 2: top-left, bottom-right, top-right.
    var quadCorners = array<vec2f, 6>(
        vec2f(-1.0, -1.0), vec2f( 1.0, -1.0), vec2f(-1.0,  1.0),
        vec2f(-1.0,  1.0), vec2f( 1.0, -1.0), vec2f( 1.0,  1.0)
    );

    let corner = quadCorners[input.vertexIndex];

    // Fields in declaration order: position, ndcCoord.
    return VertexOutput(vec4f(corner, 0.0, 1.0), corner);
}

// Computes the color contribution of one grid layer at a point on the
// ground plane.
//
//   fragPos3D - position of the fragment on the plane; only x and z are used
//   scale     - grid frequency; lines fall where x * scale or z * scale is
//               an integer, i.e. every 1 / scale units
//   drawAxis  - when true, tint the lines through x = 0 and z = 0
//   fade      - multiplier applied to the resulting alpha
//
// Returns a dark grey color whose alpha is high on a grid line and falls to
// zero away from it. Uses fwidth, so it must be called from the fragment
// stage.
fn grid(fragPos3D: vec3f, scale: f32, drawAxis: bool, fade: f32) -> vec4f
{
    let gridCoord = fragPos3D.xz * scale;

    // Screen-space rate of change of the grid coordinate; dividing by it
    // expresses distances in (approximately) pixels, which keeps the line
    // thickness constant on screen.
    let pixelFootprint = fwidth(gridCoord);

    // Distance to the nearest integer grid coordinate per axis, in pixels.
    let cellPos = fract(gridCoord);
    let lineDist2D = min(cellPos, vec2f(1.0) - cellPos) / pixelFootprint;
    let lineDist = min(lineDist2D.x, lineDist2D.y);

    // 1 exactly on a line, 0 once a full pixel (or more) away from it.
    let coverage = 1.0 - min(lineDist, 1.0);

    // |x| small means the fragment sits on the Z axis (tinted blue);
    // |z| small means it sits on the X axis (tinted red).
    let onZAxis = drawAxis && (abs(fragPos3D.x) < pixelFootprint.x * 2.0);
    let onXAxis = drawAxis && (abs(fragPos3D.z) < pixelFootprint.y * 2.0);

    let red = select(0.15, 1.0, onXAxis);
    let blue = select(0.15, 1.0, onZAxis);

    return vec4f(red, 0.15, blue, coverage * fade);
}

// Fragment stage: reconstructs the view ray through the current pixel,
// intersects it with the ground plane (y = 0) and shades the hit point with
// two summed grid layers.
//
// Fragments are discarded when the ray is (nearly) parallel to the plane,
// when the intersection lies behind the ray origin, or when the hit point is
// farther than `uniforms.farPlane` from the camera (measured horizontally).
@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4f
{
    // Two clip-space points on the pixel's view ray.
    // NOTE(review): z = -1 matches the OpenGL near plane; WebGPU NDC depth
    // spans [0, 1] — confirm against the convention of the host's
    // projection matrix.
    let clipNear = vec4f(input.ndcCoord, -1.0, 1.0);
    let clipFar = vec4f(input.ndcCoord, 1.0, 1.0);

    let projInv = uniforms.invProjMatrix;
    let viewInv = uniforms.invViewMatrix;

    // Clip space -> view space (with perspective divide) -> world space.
    let viewNear = projInv * clipNear;
    let viewFar = projInv * clipFar;

    let worldNear = viewInv * (viewNear / viewNear.w);
    let worldFar = viewInv * (viewFar / viewFar.w);

    let rayDir = normalize(worldFar.xyz - worldNear.xyz);

    // Ray / plane intersection with plane normal (0, 1, 0) through the origin.
    let denom = dot(rayDir, vec3f(0.0, 1.0, 0.0));
    if (abs(denom) < 1e-6)
    {
        // Ray runs parallel to the ground: no usable intersection.
        discard;
    }

    let t = -worldNear.y / denom;
    if (t < 0.0)
    {
        // The plane is hit behind the ray origin.
        discard;
    }

    let fragPos = worldNear.xyz + t * rayDir;

    // Horizontal distance from the camera to the hit point. Measuring from
    // the camera (instead of from the world origin) keeps the visible part of
    // the grid centered on the viewer, so the grid appears unbounded no
    // matter where the camera moves.
    let dist = length(fragPos.xz - uniforms.camPos.xz);
    if (dist > uniforms.farPlane)
    {
        discard;
    }

    // Linear fade: fully opaque at the camera, transparent at farPlane.
    let fade = clamp(1.0 - dist / uniforms.farPlane, 0.0, 1.0);

    // Fine layer (a line every 5 units) plus coarse layer (every 50 units).
    // The layers are summed, so color and alpha add up where lines coincide.
    return grid(fragPos, 0.2, true, fade) + grid(fragPos, 0.02, true, fade);
}