/**
 * BillboardGenerator.wgsl
 * 
 * Compute shader that generates camera-facing billboard geometry from line segments.
 * Converts line vertices into quad billboards with radius, color, and clipping normals.
 * Uses atomic operations to track vertex count for indirect drawing.
 */

// Per-frame shader parameters, bound at @group(0) @binding(0).
// Field order and types define the uniform-buffer byte layout, so the host-side
// struct must mirror them exactly. Fields not mentioned below are not read by
// this compute shader; presumably they are shared with other passes (TODO confirm).
struct UniformData
{
    viewMatrix: mat4x4f,
    projMatrix: mat4x4f,
    // .xyz is passed to constructBillboard as the eye position.
    camPos: vec4f,
    // .xyz is passed to constructBillboard as the camera direction.
    camDir: vec4f,
    clearColor: vec4f,
    helperLineColor: vec4f,
    kBufferInfo: vec4f,
    
    dirLightDirection: vec4f,
    dirLightColor: vec4f,
    ambLightColor: vec4f,
    materialLightResponse: vec4f,
    
    // .rgb endpoints of the color ramp; color mode 0 uses vertexColorMin only.
    vertexColorMin: vec4f,
    vertexColorMax: vec4f,
    // .x / .y = alpha at mapping factor 0 / 1; alpha mode 0 uses .x only.
    vertexAlphaBounds: vec4f,
    // .x / .y = radius at mapping factor 0 / 1; .x is also the default radius.
    vertexRadiusBounds: vec4f,
    
    billboardClippingEnabled: u32,
    billboardShadingEnabled: u32,
    // Mapping source for color / alpha / radius in main:
    // 0 = constant, 1 = attr.x, 2 = segment length / dataMaxLineLength, 3 = attr.y.
    // (For radius, mode 0 and unknown values keep vertexRadiusBounds.x.)
    vertexColorMode: u32,
    vertexAlphaMode: u32,
    vertexRadiusMode: u32,
    
    // Non-zero: the alpha / radius mapping factor f is replaced by 1 - f
    // (there is no color equivalent).
    vertexAlphaInvert: u32,
    vertexRadiusInvert: u32,
    // Divisor that normalizes segment length for mapping mode 2.
    dataMaxLineLength: f32,
    dataMaxVertexAdjacentLineLength: f32,
}

// One input polyline vertex (binding 1). main treats consecutive entries as
// segment endpoints.
struct LineVertex
{
    // .xyz is the vertex position used to build the billboard geometry;
    // the role of .w is not established in this shader.
    position: vec4f,
    // .x and .y are scalar attributes that main feeds into mix() for the
    // color/alpha/radius mappings; presumably normalized to [0, 1], and .y is
    // presumably curvature (TODO confirm both).
    attr: vec4f,
}

// One output billboard vertex (binding 2); constructBillboard writes six per segment.
// In WGSL storage layout a vec3f is 16-byte aligned but only 12 bytes wide, so each
// vec3f is followed by an f32 (a radius or an explicit pad) that fills its 16-byte
// slot. Total stride: 96 bytes, which the host-side vertex layout must match.
struct BillboardVertex
{
    // Quad corner position (same space as the camera position it was built from).
    posWS: vec3f,
    // Radius at segment endpoint A.
    radiusA: f32,
    // Segment endpoint A, the first vertex of the segment (not reordered by radius).
    posA: vec3f,
    // Radius at segment endpoint B.
    radiusB: f32,
    // Segment endpoint B, the second vertex of the segment.
    posB: vec3f,
    pad1: f32,
    // RGBA color of this corner.
    color: vec4f,
    // Clipping normal at the segment start; zeroed when no previous vertex is used.
    n0: vec3f,
    pad2: f32,
    // Clipping normal at the segment end; zeroed when no following vertex is used.
    n1: vec3f,
    pad3: f32,
}

// Arguments for a non-indexed indirect draw, in the order WebGPU's drawIndirect
// expects: vertexCount, instanceCount, firstVertex, firstInstance. Written by main.
// vertexCount is declared atomic so it can be updated with atomic builtins.
struct DrawIndirect
{
    vertexCount: atomic<u32>,
    instanceCount: u32,
    firstVertex: u32,
    firstInstance: u32,
}

// Resource bindings (group 0). Binding numbers, address spaces and access modes
// must match the host-side bind group layout.

// 0: per-frame parameters (camera, color/alpha/radius mapping configuration).
@group(0) @binding(0) var<uniform> uniforms: UniformData;
// 1: input polyline; consecutive entries are the endpoints of one segment.
@group(0) @binding(1) var<storage, read> lineVertices: array<LineVertex>;
// 2: output geometry, six vertices (two triangles) per segment.
@group(0) @binding(2) var<storage, read_write> billboardVertices: array<BillboardVertex>;
// 3: indirect draw arguments for rendering billboardVertices.
@group(0) @binding(3) var<storage, read_write> drawIndirect: DrawIndirect;

// Builds one camera-facing quad for the segment posA -> posB and writes its six
// vertices (two triangles: (c0, c1, c2) and (c2, c1, c3), presumably drawn as a
// triangle list) to billboardVertices[vertexIndex * 6u .. vertexIndex * 6u + 5u].
//
// The quad is sized from the silhouettes of spheres of radius radA / radB
// centred on the endpoints, as seen from eyePos.
//
// posA, posB     segment endpoints (same space as eyePos)
// radA, radB     radii at posA / posB
// eyePos, camDir camera position and direction
// posAPre        vertex before posA, or vec3f(-1.0, 0.0, 0.0) for "none"
// posBSuc        vertex after posB,  or vec3f(-1.0, 0.0, 0.0) for "none"
// colA, colB     colors for the start-side corners (c0, c1) / end-side corners (c2, c3)
// vertexIndex    segment slot in billboardVertices
//
// NOTE(review): if eyePos lies inside an end sphere (distance < radius) the sqrt
// arguments go negative and the quad becomes NaN; callers presumably keep the
// camera outside the tubes.
fn constructBillboard(
    posA: vec3f, posB: vec3f,
    radA: f32, radB: f32,
    eyePos: vec3f, camDir: vec3f,
    posAPre: vec3f, posBSuc: vec3f,
    colA: vec4f, colB: vec4f,
    vertexIndex: u32
)
{
    // Order the endpoints so that x0 carries the smaller radius (geometry only).
    let swapEnds = radA > radB;
    let x0 = select(posA, posB, swapEnds);
    let x1 = select(posB, posA, swapEnds);
    let r0 = select(radA, radB, swapEnds);
    let r1 = select(radB, radA, swapEnds);

    let d = x1 - x0;
    let d0 = eyePos - x0;
    let d1 = eyePos - x1;

    // u is perpendicular to the segment and to the view ray towards x0;
    // v0 / v1 are perpendicular to u and to the view rays towards x0 / x1.
    let u = normalize(cross(d, d0));
    let v0 = normalize(cross(u, d0));
    let v1 = normalize(cross(u, d1));

    let len0 = length(d0);
    let len1 = length(d1);

    // t = eye-to-sphere tangent length, s = r / t = tan(half-angle of the
    // sphere's silhouette cone); len * s is that cone's radius at the sphere centre.
    let t0 = sqrt(len0 * len0 - r0 * r0);
    let s0 = r0 / t0;
    let t1 = sqrt(len1 * len1 - r1 * r1);
    let s1 = r1 / t1;

    let r0s = len0 * s0;
    let r1s = len1 * s1;

    // Silhouette points of both end spheres along v0 / v1.
    let p0 = x0 + r0s * v0;
    let p1 = x0 - r0s * v0;
    let p2 = x1 + r1s * v1;
    let p3 = x1 - r1s * v1;

    // Half-width along u: the larger angular radius is used at both ends.
    let sm = max(s0, s1);
    let r0ss = len0 * sm;
    let r1ss = len1 * sm;

    // Project the silhouette directions onto w (perpendicular to u and camDir)
    // and take the point with the smaller projection on each side.
    let w = normalize(cross(u, camDir));
    let a0 = dot(normalize(p0 - eyePos), w);
    let a1 = dot(normalize(p1 - eyePos), w);
    let a2 = dot(normalize(p2 - eyePos), w);
    let a3 = dot(normalize(p3 - eyePos), w);

    let startFromP0 = a0 <= a2;
    let ps = select(p2, p0, startFromP0);
    let rs = select(r1ss, r0ss, startFromP0);

    let endFromP3 = a1 <= a3;
    let pe = select(p1, p3, endFromP3);
    let re = select(r0ss, r1ss, endFromP3);

    let c0 = ps - rs * u;
    let c1 = ps + rs * u;
    let c2 = pe - re * u;
    let c3 = pe + re * u;

    // Clipping normals at the joints: central-difference tangents from the
    // neighbour to the opposite endpoint. They use the original posA / posB (not
    // the radius-ordered x0 / x1), so they do not depend on which end is thicker.
    // Only the exact sentinel passed by main means "no neighbour"; a real vertex
    // with a negative x coordinate keeps its normal. An absent neighbour yields
    // a zero normal without normalizing anything (avoids NaN).
    const noNeighbor = vec3f(-1.0, 0.0, 0.0);
    var n0 = vec3f(0.0);
    if (any(posAPre != noNeighbor))
    {
        n0 = -normalize(posB - posAPre);
    }
    var n1 = vec3f(0.0);
    if (any(posBSuc != noNeighbor))
    {
        n1 = normalize(posBSuc - posA);
    }

    let baseIdx = vertexIndex * 6u;
    var corners = array<vec3f, 6>(c0, c1, c2, c2, c1, c3);
    var colors = array<vec4f, 6>(colA, colA, colB, colB, colA, colB);
    for (var i = 0u; i < 6u; i++)
    {
        billboardVertices[baseIdx + i] = BillboardVertex(
            corners[i], radA, posA, radB, posB, 0.0,
            colors[i], n0, 0.0, n1, 0.0
        );
    }
}

// Returns 1 - x when invert is set, x otherwise.
fn invertIf(x: f32, invert: bool) -> f32
{
    return select(x, 1.0 - x, invert);
}

// Entry point: one invocation per segment (lineVertices[i], lineVertices[i + 1]).
// Segment i always owns vertex slots [i * 6, i * 6 + 6) of billboardVertices, and
// drawIndirect is filled so that a non-indexed indirect draw covers every slot.
//
// Why the count is not accumulated with atomicAdd: the reset by invocation 0 is
// only ordered by workgroupBarrier() within its own workgroup, so other
// workgroups could add before the reset landed and their counts were lost.
// Counting only non-degenerate segments also mismatched the positional slots:
// the draw covered stale slots and missed the trailing segments. Instead,
// zero-length segments emit a collapsed (zero-area) quad, and the vertex count
// is deterministic and written by a single invocation.
@compute @workgroup_size(64, 1, 1)
fn main(@builtin(global_invocation_id) globalID: vec3u)
{
    let lineIdx = globalID.x;
    let lineCount = arrayLength(&lineVertices);
    // Segments that exist in the input and fit in the output buffer
    // (max() keeps lineCount == 0 from wrapping around).
    let segmentCapacity = arrayLength(&billboardVertices) / 6u;
    let segmentCount = min(max(lineCount, 1u) - 1u, segmentCapacity);

    if (lineIdx == 0u)
    {
        atomicStore(&drawIndirect.vertexCount, segmentCount * 6u);
        drawIndirect.instanceCount = 1u;
        drawIndirect.firstVertex = 0u;
        drawIndirect.firstInstance = 0u;
    }

    if (lineIdx >= segmentCount)
    {
        return;
    }

    let v0 = lineVertices[lineIdx];
    let v1 = lineVertices[lineIdx + 1u];
    let p0 = v0.position.xyz;
    let p1 = v1.position.xyz;
    let baseIdx = lineIdx * 6u;

    // Zero-length segment: constructBillboard would normalize zero vectors.
    // Fill the slot with a collapsed quad so it rasterizes nothing and holds no stale data.
    if (all(p0 == p1))
    {
        let collapsed = BillboardVertex(
            p0, 0.0, p0, 0.0, p1, 0.0,
            vec4f(0.0), vec3f(0.0), 0.0, vec3f(0.0), 0.0
        );
        for (var i = 0u; i < 6u; i++)
        {
            billboardVertices[baseIdx + i] = collapsed;
        }
        return;
    }

    // Neighbours for the clipping normals; vec3f(-1.0, 0.0, 0.0) is the
    // "no neighbour" marker understood by constructBillboard.
    const noNeighbor = vec3f(-1.0, 0.0, 0.0);
    var vPre = noNeighbor;
    if (lineIdx > 0u)
    {
        vPre = lineVertices[lineIdx - 1u].position.xyz;
    }
    var vNext = noNeighbor;
    if (lineIdx + 2u < lineCount)
    {
        vNext = lineVertices[lineIdx + 2u].position.xyz;
    }

    // Segment length normalized by the data maximum (mapping mode 2). Uses xyz
    // only, like the zero-length test above; 0 when no maximum is configured.
    let maxLen = uniforms.dataMaxLineLength;
    let lengthFactor = select(0.0, saturate(distance(p0, p1) / maxLen), maxLen > 0.0);

    // Color: 0 = constant, 1 = attr.x, 2 = segment length, 3 = attr.y.
    let colorMin = uniforms.vertexColorMin.rgb;
    let colorMax = uniforms.vertexColorMax.rgb;
    var colA = vec4f(1.0);
    var colB = vec4f(1.0);
    switch (uniforms.vertexColorMode)
    {
        case 0u:
        {
            colA = vec4f(colorMin, 1.0);
            colB = colA;
        }
        case 1u:
        {
            colA = vec4f(mix(colorMin, colorMax, v0.attr.x), 1.0);
            colB = vec4f(mix(colorMin, colorMax, v1.attr.x), 1.0);
        }
        case 2u:
        {
            colA = vec4f(mix(colorMin, colorMax, lengthFactor), 1.0);
            colB = colA;
        }
        case 3u:
        {
            colA = vec4f(mix(colorMin, colorMax, v0.attr.y), 1.0);
            colB = vec4f(mix(colorMin, colorMax, v1.attr.y), 1.0);
        }
        default:
        {
        }
    }

    // Alpha: same sources as color, with optional inversion of the factor.
    let alphaLo = uniforms.vertexAlphaBounds.x;
    let alphaHi = uniforms.vertexAlphaBounds.y;
    let alphaInvert = uniforms.vertexAlphaInvert != 0u;
    switch (uniforms.vertexAlphaMode)
    {
        case 0u:
        {
            colA.a = alphaLo;
            colB.a = alphaLo;
        }
        case 1u:
        {
            colA.a = mix(alphaLo, alphaHi, invertIf(v0.attr.x, alphaInvert));
            colB.a = mix(alphaLo, alphaHi, invertIf(v1.attr.x, alphaInvert));
        }
        case 2u:
        {
            colA.a = mix(alphaLo, alphaHi, invertIf(lengthFactor, alphaInvert));
            colB.a = colA.a;
        }
        case 3u:
        {
            colA.a = mix(alphaLo, alphaHi, invertIf(v0.attr.y, alphaInvert));
            colB.a = mix(alphaLo, alphaHi, invertIf(v1.attr.y, alphaInvert));
        }
        default:
        {
        }
    }

    // Radius: defaults to the lower bound (mode 0 and unknown modes).
    let radiusLo = uniforms.vertexRadiusBounds.x;
    let radiusHi = uniforms.vertexRadiusBounds.y;
    let radiusInvert = uniforms.vertexRadiusInvert != 0u;
    var radA = radiusLo;
    var radB = radiusLo;
    switch (uniforms.vertexRadiusMode)
    {
        case 1u:
        {
            radA = mix(radiusLo, radiusHi, invertIf(v0.attr.x, radiusInvert));
            radB = mix(radiusLo, radiusHi, invertIf(v1.attr.x, radiusInvert));
        }
        case 2u:
        {
            radA = mix(radiusLo, radiusHi, invertIf(lengthFactor, radiusInvert));
            radB = radA;
        }
        case 3u:
        {
            radA = mix(radiusLo, radiusHi, invertIf(v0.attr.y, radiusInvert));
            radB = mix(radiusLo, radiusHi, invertIf(v1.attr.y, radiusInvert));
        }
        default:
        {
        }
    }

    constructBillboard(
        p0, p1,
        radA, radB,
        uniforms.camPos.xyz, uniforms.camDir.xyz,
        vPre, vNext,
        colA, colB,
        lineIdx
    );
}