/**
 * TubeGenerator.wgsl
 * 
 * Renders tubes by ray casting rounded cones (cone segments capped by spheres at both ends)
 * from rasterized proxy geometry.
 * Inserts the shaded, transparent fragments into a K-buffer for order-independent transparency.
 * Supports optional clipping against per-segment end planes and Blinn-Phong shading.
 */

// Per-frame uniform block bound at @group(0) @binding(0).
// Field order and types define the uniform buffer layout, so they presumably must stay in
// sync with a host-side struct that fills this buffer (not visible here — confirm).
struct UniformData
{
    // World -> view transform; vs_main applies it before projMatrix.
    viewMatrix: mat4x4f,
    projMatrix: mat4x4f,
    // Eye position in world space (only .xyz is read).
    camPos: vec4f,
    // Not read by the code shown here.
    camDir: vec4f,
    // Not read by the code shown here.
    clearColor: vec4f,
    // Not read by the code shown here.
    helperLineColor: vec4f,
    // x = image width, y = image height (used by listPos), z = K, the number of
    // K-buffer slots per pixel (used by insertIntoKBuffer). w is not read here.
    kBufferInfo: vec4f,
    
    // Direction the light travels; negated to obtain the to-light vector.
    // It is not normalized in this shader, so it is presumably supplied unit length — confirm.
    dirLightDirection: vec4f,
    dirLightColor: vec4f,
    ambLightColor: vec4f,
    // x = ambient weight, y = diffuse weight, z = specular weight, w = specular shininess.
    materialLightResponse: vec4f,
    
    // The vertex colour/alpha/radius mapping fields below are not read by the code shown here.
    vertexColorMin: vec4f,
    vertexColorMax: vec4f,
    vertexAlphaBounds: vec4f,
    vertexRadiusBounds: vec4f,
    
    // Nonzero: fs_main discards ray misses / hits behind the eye and clips against end planes.
    billboardClippingEnabled: u32,
    // Nonzero: fs_main applies Blinn-Phong lighting; zero: the vertex colour is used unlit.
    billboardShadingEnabled: u32,
    vertexColorMode: u32,
    vertexAlphaMode: u32,
    vertexRadiusMode: u32,
    
    vertexAlphaInvert: u32,
    vertexRadiusInvert: u32,
    dataMaxLineLength: f32,
    dataMaxVertexAdjacentLineLength: f32,
}

@group(0) @binding(0) var<uniform> uniforms: UniformData;
// K-buffer storage. For slot i of pixel p, word listPos(i, p) * 2 holds the fragment depth
// (f32 bit pattern) and the next word holds its packed RGBA8 colour. A depth word equal to
// 0xFFFFFFFF marks an empty slot; the buffer is presumably reset to that value by another
// pass before this shader runs — confirm.
@group(0) @binding(1) var<storage, read_write> kBuffer: array<atomic<u32>>;

// Per-vertex attributes of the tube proxy geometry. Every vertex carries the full description
// of its segment so that the fragment stage can ray-cast the exact rounded cone.
struct VertexInput
{
    // World-space position of this proxy vertex (what gets rasterized).
    @location(0) posWS: vec3f,
    // Radius of the segment at endpoint posA.
    @location(1) radiusA: f32,
    // World-space first endpoint of the segment.
    @location(2) posA: vec3f,
    // Radius of the segment at endpoint posB.
    @location(3) radiusB: f32,
    // World-space second endpoint of the segment.
    @location(4) posB: vec3f,
    @location(5) color: vec4f,
    // Normal of the clip plane through posA (used by fs_main when clipping is enabled).
    @location(6) n0: vec3f,
    // Not read by vs_main; presumably padding to match the host vertex layout.
    @location(7) pad2: f32,
    // Normal of the clip plane through posB (used by fs_main when clipping is enabled).
    @location(8) n1: vec3f,
    // Not read by vs_main; presumably padding to match the host vertex layout.
    @location(9) pad3: f32,
}

// Data passed from vs_main to fs_main.
struct VertexOutput
{
    @builtin(position) position: vec4f,
    // Eye-to-vertex vector in world space; not normalized here (fs_main normalizes it).
    @location(0) viewRay: vec3f,
    @location(1) color: vec4f,
    // Segment endpoints and radii, forwarded unchanged for the per-fragment ray cast.
    @location(2) posA: vec3f,
    @location(3) posB: vec3f,
    @location(4) radiusA: f32,
    @location(5) radiusB: f32,
    // Clip-plane normals at posA / posB.
    @location(6) n0: vec3f,
    @location(7) n1: vec3f,
    // World-space position of the proxy vertex (not read by fs_main).
    @location(8) posWS: vec3f,
}

// Vertex stage: projects the proxy vertex to clip space and forwards the segment description
// (endpoints, radii, clip-plane normals, colour) unchanged for the per-fragment ray cast.
// The eye-to-vertex vector is passed un-normalized; fs_main normalizes it after interpolation.
@vertex
fn vs_main(input: VertexInput) -> VertexOutput
{
    let viewProj = uniforms.projMatrix * uniforms.viewMatrix;
    let clipPos = viewProj * vec4f(input.posWS, 1.0);
    let eyeToVertex = input.posWS - uniforms.camPos.xyz;

    return VertexOutput(
        clipPos,
        eyeToVertex,
        input.color,
        input.posA,
        input.posB,
        input.radiusA,
        input.radiusB,
        input.n0,
        input.n1,
        input.posWS
    );
}

// Ray vs. rounded cone: the cone frustum joining a sphere of radius ra at pa to a sphere of
// radius rb at pb, including both spherical caps.
// Adapted from Inigo Quilez:
//   https://www.shadertoy.com/view/MlKfzm
//   https://iquilezles.org/articles/intersectors/
//
// ro: ray origin, rd: ray direction (the caller passes it normalized).
// Returns vec4f(t, n) for the nearest hit, where t is the ray parameter and n the surface
// normal at the hit, or vec4f(-1.0) when the ray misses.
// The returned t may be <= 0 (hit behind or at the origin, e.g. when ro is inside the shape);
// callers must reject such hits themselves (fs_main does when clipping is enabled).
fn iRoundedCone(ro: vec3f, rd: vec3f, pa: vec3f, pb: vec3f, ra: f32, rb: f32) -> vec4f
{
    let ba = pb - pa;
    let oa = ro - pa;
    let ob = ro - pb;
    let rr = ra - rb;
    let m0 = dot(ba, ba);
    let m1 = dot(ba, oa);
    let m2 = dot(ba, rd);
    let m3 = dot(rd, oa);
    let m5 = dot(oa, oa);
    let m6 = dot(ob, rd);
    let m7 = dot(ob, ob);
    
    // Body: intersect the infinite cone through both spheres' tangent lines.
    let d2 = m0 - rr * rr;
    let k2 = d2 - m2 * m2;
    let k1 = d2 * m3 - m1 * m2 + m2 * rr * ra;
    let k0 = d2 * m5 - m1 * m1 + m1 * rr * ra * 2.0 - m0 * ra * ra;
    
    let h = k1 * k1 - k0 * k2;
    if (h < 0.0)
    {
        return vec4f(-1.0);
    }
    
    // Near root of the body quadratic. Do NOT early-out when it is negative: that happens
    // whenever the origin lies inside the *infinite* cone extension (e.g. looking down the
    // tube axis from beyond an end), and the ray can still hit a spherical cap in front.
    var t = (-sqrt(h) - k1) / k2;
    
    // Accept the body hit only if it lies between the two cap spheres' tangent circles.
    let y = m1 - ra * rr + t * m2;
    if (y > 0.0 && y < d2)
    {
        return vec4f(t, normalize(d2 * (oa + t * rd) - ba * y));
    }
    
    // Caps: nearest intersection with either end sphere.
    let h1 = m3 * m3 - m5 + ra * ra;
    let h2 = m6 * m6 - m7 + rb * rb;
    if (max(h1, h2) < 0.0)
    {
        return vec4f(-1.0);
    }
    
    var result = vec4f(1e20, 0.0, 0.0, 0.0);
    if (h1 > 0.0)
    {
        t = -m3 - sqrt(h1);
        result = vec4f(t, (oa + t * rd) / ra);
    }
    if (h2 > 0.0)
    {
        t = -m6 - sqrt(h2);
        if (t < result.x)
        {
            result = vec4f(t, (ob + t * rd) / rb);
        }
    }
    
    return result;
}

// Signed distance of `point` from the plane dot(n, x) = d, encoded as plane = (n, d).
// The value is a true Euclidean distance only when n has unit length.
fn getDistanceFromPlane(point: vec3f, plane: vec4f) -> f32
{
    let normal = plane.xyz;
    let offset = plane.w;
    let projected = dot(normal, point);
    return projected - offset;
}

// Blinn-Phong diffuse + specular term for one light.
// toLight / toEye / normal are assumed to be unit length (not normalized here except for
// the half vector). Returns diffFactor * max(n.l, 0) + specFactor * max(n.h, 0)^specShininess.
fn calcBlinnPhong(toLight: vec3f, toEye: vec3f, normal: vec3f, 
                  diffFactor: vec3f, specFactor: vec3f, specShininess: f32) -> vec3f
{
    let halfway = normalize(toLight + toEye);
    let lambert = max(0.0, dot(normal, toLight));
    let highlight = pow(max(0.0, dot(normal, halfway)), specShininess);
    return diffFactor * lambert + specFactor * highlight;
}

// Ambient + single directional-light Blinn-Phong shading of a surface point.
// uniforms.materialLightResponse supplies (ambient, diffuse, specular, shininess): ambient and
// diffuse are tinted by color.rgb, specular is not. The result is not clamped.
fn calculateIllumination(eyePos: vec3f, fragPos: vec3f, 
                        fragNorm: vec3f, color: vec4f) -> vec3f
{
    let weights = uniforms.materialLightResponse;
    let baseRgb = color.rgb;

    let ambientPart = (weights.x * baseRgb) * uniforms.ambLightColor.rgb;

    // dirLightDirection is the direction light travels, so flip it to point at the light.
    let lightVec = -uniforms.dirLightDirection.xyz;
    let eyeVec = normalize(eyePos - fragPos);
    let shaded = calcBlinnPhong(lightVec, eyeVec, fragNorm,
                                weights.y * baseRgb, vec3f(weights.z), weights.w);
    let directPart = uniforms.dirLightColor.rgb * shaded;

    return ambientPart + directPart;
}

// Packs an RGBA colour into one u32 as four 8-bit unorm channels:
// v.x in bits 0-7, v.y in bits 8-15, v.z in bits 16-23, v.w in bits 24-31.
// Each channel is clamped to [0, 1], scaled by 255 and rounded to the nearest integer.
// Delegates to the WGSL builtin pack4x8unorm, which implements exactly this layout.
fn packUnorm4x8(v: vec4f) -> u32
{
    return pack4x8unorm(v);
}

// Entry index of K-buffer slot `i` for pixel `coord` (callers multiply by 2 to reach the
// interleaved depth/colour words). Slots are stored layer after layer, and each layer is a
// row-major image of kBufferInfo.x by kBufferInfo.y entries.
fn listPos(i: u32, coord: vec2i) -> u32
{
    let dims = vec2i(uniforms.kBufferInfo.xy);
    let pixelIndex = coord.y * dims.x + coord.x;
    let layerStride = dims.x * dims.y;
    return u32(pixelIndex) + i * u32(layerStride);
}

// Inserts one fragment (depth + colour) into the K-buffer list of pixel `coord`.
//
// Slots are claimed in arrival order, so a pixel's list is unsorted. The insertion has two
// passes:
//   1. Claim the first empty slot (depth word == EMPTY_SLOT) with a compare-exchange.
//   2. If all K slots are occupied, locate the farthest stored fragment and replace it, but
//      only when the new fragment is closer. This keeps the K nearest fragments rather than
//      overwriting whichever fragment happened to arrive K-th.
//
// depth is compared through its f32 bit pattern, which preserves ordering for non-negative
// depths (as produced by @builtin(position).z in [0, 1]).
//
// NOTE(review): depth and colour are separate 32-bit words, so the pair is not written
// atomically. Between our depth compare-exchange and our colour store, a concurrent
// invocation can still replace the same slot, which can pair its depth with our colour.
// The compare-exchange narrows that window compared with an unconditional atomicMin, but does
// not close it. Fixing it fully needs 64-bit atomics or a per-pixel lock.
fn insertIntoKBuffer(coord: vec2i, depth: f32, color: vec4f)
{
    const EMPTY_SLOT = 0xFFFFFFFFu;
    // atomicCompareExchangeWeak may fail spuriously; retry a bounded number of times.
    const MAX_CAS_ATTEMPTS = 4u;

    let depthBits = bitcast<u32>(depth);
    let colorBits = packUnorm4x8(color);
    let K = u32(uniforms.kBufferInfo.z);
    if (K == 0u)
    {
        return;
    }
    
    // Pass 1: claim the first empty slot.
    for (var i = 0u; i < K; i++)
    {
        let depthIdx = listPos(i, coord) * 2u;
        for (var attempt = 0u; attempt < MAX_CAS_ATTEMPTS; attempt++)
        {
            let res = atomicCompareExchangeWeak(&kBuffer[depthIdx], EMPTY_SLOT, depthBits);
            if (res.exchanged)
            {
                atomicStore(&kBuffer[depthIdx + 1u], colorBits);
                return;
            }
            if (res.old_value != EMPTY_SLOT)
            {
                // Slot genuinely occupied: move on to the next one.
                break;
            }
            // Otherwise the failure was spurious and the slot is still empty: retry it.
        }
    }
    
    // Pass 2: every slot is occupied. Find the farthest stored fragment.
    var farthestIdx = listPos(0u, coord) * 2u;
    var farthestDepth = atomicLoad(&kBuffer[farthestIdx]);
    for (var i = 1u; i < K; i++)
    {
        let depthIdx = listPos(i, coord) * 2u;
        let storedDepth = atomicLoad(&kBuffer[depthIdx]);
        if (storedDepth > farthestDepth)
        {
            farthestDepth = storedDepth;
            farthestIdx = depthIdx;
        }
    }
    
    if (depthBits >= farthestDepth)
    {
        // Farther than every kept fragment: drop it.
        return;
    }
    
    // Replace the farthest slot only if it still holds the depth we measured.
    for (var attempt = 0u; attempt < MAX_CAS_ATTEMPTS; attempt++)
    {
        let res = atomicCompareExchangeWeak(&kBuffer[farthestIdx], farthestDepth, depthBits);
        if (res.exchanged)
        {
            atomicStore(&kBuffer[farthestIdx + 1u], colorBits);
            return;
        }
        if (res.old_value != farthestDepth)
        {
            // Another invocation replaced this slot first; drop this fragment (best effort).
            return;
        }
    }
}

// Fragment stage: ray-casts the rounded cone described by the interpolated segment data,
// optionally rejects misses and hits outside the segment's end planes, shades the hit, and
// inserts the result into the K-buffer. The colour attachment always receives zero; the
// K-buffer is the real output.
@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4f
{
    let pixel = vec2i(input.position.xy);
    let eyeWS = uniforms.camPos.xyz;
    let rayDirWS = normalize(input.viewRay);

    let hit = iRoundedCone(eyeWS, rayDirWS, input.posA, input.posB,
                           input.radiusA, input.radiusB);
    let hitT = hit.x;
    let hitNormal = hit.yzw;
    let hitPosWS = eyeWS + rayDirWS * hitT;

    if (uniforms.billboardClippingEnabled != 0u)
    {
        // Miss (t = -1) or hit at/behind the eye.
        if (hitT <= 0.0)
        {
            discard;
        }

        // Planes through posA / posB with normals n0 / n1. A positive signed distance to
        // either plane discards the fragment (presumably the side beyond the segment's end,
        // depending on how the host orients n0 / n1).
        let startPlane = vec4f(input.n0, dot(input.posA, input.n0));
        let endPlane = vec4f(input.n1, dot(input.posB, input.n1));
        let outsideStart = getDistanceFromPlane(hitPosWS, startPlane) > 0.0;
        let outsideEnd = getDistanceFromPlane(hitPosWS, endPlane) > 0.0;
        if (outsideStart || outsideEnd)
        {
            discard;
        }
    }

    var rgb: vec3f;
    if (uniforms.billboardShadingEnabled != 0u)
    {
        rgb = calculateIllumination(eyeWS, hitPosWS, hitNormal, input.color);
    }
    else
    {
        rgb = input.color.rgb;
    }

    // Alpha-premultiplied colour in rgb, (1 - alpha) in a.
    let alpha = input.color.a;
    let packed = vec4f(rgb * alpha, 1.0 - alpha);
    insertIntoKBuffer(pixel, input.position.z, packed);

    return vec4f(0.0);
}