// Source image read by DenoiseCS: .rgb is filtered as color, .a is used as the
// per-pixel depth for edge stopping.
// NOTE(review): bound at t1 rather than t0 — confirm this matches the root
// signature / descriptor table used by the host.
Texture2D<float4> input : register(t1);
// Destination written by DenoiseCS: .rgb = filtered color, .a = 1.0.
RWTexture2D<float4> output : register(u0); 

// Gaussian falloff on the integer pixel offset (x, y) from the kernel center:
//   w = exp(-(x^2 + y^2) / (2 * sigma^2))
// Returns 1 at the center offset (0, 0) and decays smoothly with distance.
float GetSpatialWeight(int x, int y, float sigma)
{
    // Sum squared offsets in integer math first, then convert once.
    const float radiusSq = float(x * x + y * y);
    const float twoSigmaSq = 2.0 * sigma * sigma;
    return exp(-radiusSq / twoSigmaSq);
}

// Edge-stopping weight based on depth similarity:
//   w = exp(-|neighborDepth - centerDepth| / sigma)
// Equal depths give 1; larger differences decay exponentially, with sigma
// setting how quickly. The result is symmetric in the two depths.
float GetDepthWeight(float centerDepth, float neighborDepth, float sigma)
{
    const float depthDelta = abs(neighborDepth - centerDepth);
    return exp(-depthDelta / sigma);
}

// Edge-aware (bilateral-style) denoise pass, one thread per output texel.
//
// Averages input.rgb over a (2*radius+1)^2 window centered on the texel.
// Each neighbor is weighted by GetSpatialWeight (Gaussian of its pixel
// offset) times GetDepthWeight (exponential falloff of its alpha/"depth"
// difference from the center), so taps across depth discontinuities
// contribute little. The center tap always has weight 1.
//
// Inputs:  input.rgb = color, input.a = depth used for edge stopping.
//          NOTE(review): the units/range of this depth are not visible here;
//          sigmaDepth = 0.1 presumes a matching scale — confirm with the
//          pass that produces `input`.
// Output:  output.rgb = filtered color, output.a = 1.0. Depth is NOT carried
//          through; TODO confirm no later pass expects it in alpha.
[numthreads(16, 16, 1)]
void DenoiseCS(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    const uint2 pos = dispatchThreadId.xy;

    // Query sizes as integers so every bounds test below is integer-vs-integer
    // (no implicit int -> float conversion of texel coordinates).
    uint outWidth, outHeight;
    output.GetDimensions(outWidth, outHeight);

    // Threads past the edge of the output (partial last group) do nothing.
    if (pos.x >= outWidth || pos.y >= outHeight)
        return;

    // Neighbor taps are read from `input`, so validate them against its extent.
    uint inWidth, inHeight;
    input.GetDimensions(inWidth, inHeight);
    const int2 inSize = int2(inWidth, inHeight);

    const float4 centerVal = input[pos];
    const float3 centerColor = centerVal.rgb;
    const float centerDepth = centerVal.a;

    // The center contributes with weight exp(0) * exp(0) = 1.
    float3 sumColor = centerColor;
    float sumWeight = 1.0;

    const int radius = 4;
    const float sigmaSpatial = 6.0;
    const float sigmaDepth = 0.1;

    // Signed copy of the center so offsets left of / above the image become
    // negative instead of wrapping around as large unsigned values.
    const int2 center = int2(pos);

    // Fully unrolled: x and y are compile-time constants per tap, which lets
    // the compiler fold GetSpatialWeight to a constant.
    [unroll]
    for (int y = -radius; y <= radius; y++)
    {
        [unroll]
        for (int x = -radius; x <= radius; x++)
        {
            if (x == 0 && y == 0)
                continue; // center already accumulated above

            const int2 neighborPos = center + int2(x, y);

            // Skip taps outside the input image (no clamping/mirroring).
            if (any(neighborPos < 0) || any(neighborPos >= inSize))
                continue;

            const float4 neighborVal = input[uint2(neighborPos)];
            const float3 neighborColor = neighborVal.rgb;
            const float neighborDepth = neighborVal.a;

            const float wSpatial = GetSpatialWeight(x, y, sigmaSpatial);
            const float wDepth = GetDepthWeight(centerDepth, neighborDepth, sigmaDepth);
            const float weight = wSpatial * wDepth;

            sumColor += neighborColor * weight;
            sumWeight += weight;
        }
    }

    // sumWeight >= 1 because of the center tap; the max() is a defensive guard
    // against division by zero.
    output[pos] = float4(sumColor / max(sumWeight, 0.0001), 1.0);
}