/**
 * SortKBuffer.wgsl
 *
 * Compute shader that sorts the transparent fragments stored in the K-buffer by depth.
 * For each pixel, an insertion sort orders the fragments from front to back.
 *
 * This was our attempt to resolve the flickering that gets worse the more the tubes
 * overlap. It mitigates the effect to some degree, but the flickering still occurs.
 * Interestingly, it also gets worse when the GPU is under heavy load from other
 * programs. We are not sure how to fix this properly; increasing the K-buffer size
 * to 32 layers would probably help a lot.
 */

// Uniform block bound at group 0, binding 0. The field order and types define the
// buffer layout, so this declaration must stay in sync with the code that fills the
// buffer (NOTE(review): presumably a host-side struct and/or other shaders declaring
// the same UniformData — confirm before changing any field).
// Of all these fields, this shader only reads kBufferInfo.
struct UniformData
{
    viewMatrix: mat4x4f,
    projMatrix: mat4x4f,
    camPos: vec4f,
    camDir: vec4f,
    clearColor: vec4f,
    helperLineColor: vec4f,
    // xyz = K-buffer width, height and layer count K; converted to integers where read.
    // The w component is not used by this shader.
    kBufferInfo: vec4f,
    
    dirLightDirection: vec4f,
    dirLightColor: vec4f,
    ambLightColor: vec4f,
    materialLightResponse: vec4f,
    
    vertexColorMin: vec4f,
    vertexColorMax: vec4f,
    vertexAlphaBounds: vec4f,
    vertexRadiusBounds: vec4f,
    
    billboardClippingEnabled: u32,
    billboardShadingEnabled: u32,
    vertexColorMode: u32,
    vertexAlphaMode: u32,
    vertexRadiusMode: u32,
    
    vertexAlphaInvert: u32,
    vertexRadiusInvert: u32,
    dataMaxLineLength: f32,
    dataMaxVertexAdjacentLineLength: f32,
}

// Shared uniforms; this shader only reads uniforms.kBufferInfo (K-buffer dimensions).
@group(0) @binding(0) var<uniform> uniforms: UniformData;
// K-buffer storage. Each entry occupies two consecutive u32 words, which main reads as
// (depth, color). The entry index comes from listPos and is multiplied by 2 to get the
// word offset. A depth word of 0xFFFFFFFF is skipped as a sort key (presumably it marks
// an empty slot — confirm against the pass that fills the buffer).
@group(0) @binding(1) var<storage, read_write> kBuffer: array<u32>;

// Maps (layer i, pixel coord) to an entry index in the K-buffer.
// Layers are stored one after another, each a full width*height plane in row-major
// order, so index = i * (width * height) + (y * width + x).
// The result counts entries, not u32 words; the caller scales it to a word offset.
fn listPos(i: u32, coord: vec2i) -> u32
{
    let dims = vec3i(uniforms.kBufferInfo.xyz);
    let pixelIndex = u32(coord.y * dims.x + coord.x);
    let planeSize = u32(dims.x * dims.y);
    return i * planeSize + pixelIndex;
}

// Capacity of the per-invocation scratch arrays in main. K-buffer layers at or beyond
// this index are left untouched, because they cannot be loaded for sorting.
const MAX_SORT_LAYERS: u32 = 32u;

// Returns true when fragment A must be placed after fragment B.
// The primary key is the depth word (ascending). On equal depths the color word decides,
// so the final order never depends on the order in which fragments happened to be
// stored. That storage order can vary between frames and would make the blend order
// (and thus the pixel color) flicker.
fn sortsAfter(depthA: u32, colorA: u32, depthB: u32, colorB: u32) -> bool
{
    if (depthA != depthB)
    {
        return depthA > depthB;
    }
    return colorA > colorB;
}

// Sorts one pixel's K-buffer entries in place, one invocation per pixel.
// Entries are loaded into private arrays, insertion-sorted ascending by depth word
// (tie-broken by color word), and written back to the same layers.
@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) dispatchID: vec3u)
{
    let coord = vec2i(dispatchID.xy);
    let imgSize = vec3i(uniforms.kBufferInfo.xyz);

    // Invocations outside the K-buffer's width/height have no pixel to sort.
    if (coord.x >= imgSize.x || coord.y >= imgSize.y)
    {
        return;
    }

    // Clamp the configured layer count to the scratch-array capacity. Without this,
    // a K above 32 would index past depths/colors, and the write-back loop would
    // store undefined/clamped values into the extra layers.
    let K = min(u32(imgSize.z), MAX_SORT_LAYERS);

    var depths: array<u32, MAX_SORT_LAYERS>;
    var colors: array<u32, MAX_SORT_LAYERS>;

    // Load: each entry is two consecutive u32 words (depth, color).
    for (var i = 0u; i < K; i++)
    {
        let idx = listPos(i, coord) * 2u;
        depths[i] = kBuffer[idx];
        colors[i] = kBuffer[idx + 1u];
    }

    // Insertion sort; entries [0, i) are sorted before each iteration.
    for (var i = 1u; i < K; i++)
    {
        let keyDepth = depths[i];
        let keyColor = colors[i];

        // No depth exceeds 0xFFFFFFFF, so such an entry never has to move ahead of an
        // entry with a smaller depth. Skipping it only leaves its order relative to
        // other 0xFFFFFFFF entries (presumably empty slots) unspecified.
        if (keyDepth == 0xFFFFFFFFu)
        {
            continue;
        }

        var j = i;
        // `&&` short-circuits, so j - 1u is never evaluated when j == 0.
        while (j > 0u && sortsAfter(depths[j - 1u], colors[j - 1u], keyDepth, keyColor))
        {
            depths[j] = depths[j - 1u];
            colors[j] = colors[j - 1u];
            j = j - 1u;
        }

        depths[j] = keyDepth;
        colors[j] = keyColor;
    }

    // Store the sorted entries back into the same layers they were loaded from.
    for (var i = 0u; i < K; i++)
    {
        let idx = listPos(i, coord) * 2u;
        kBuffer[idx] = depths[i];
        kBuffer[idx + 1u] = colors[i];
    }
}