
// Soft Shadows implemented based on: https://medium.com/@alexander.wester/ray-tracing-soft-shadows-in-real-time-a53b836d123b
// by Alexander Wester

// The scene is simply loaded from a glTF file.

// The ray tracing part is based on the basic starter sample project from Microsoft.

// Per-ray payload for the shadow rays traced in RayGeneration.
struct Payload
{
    // Initialized to false before each trace; the Miss shader sets it to true,
    // so a value of true after TraceRay means the ray reached the light sample
    // without hitting any geometry.
    bool missed;
};

// Per-frame constants bound at register b1.
// NOTE(review): the pad* members presumably exist to mirror the CPU-side struct
// layout (16-byte rows) — confirm against the host code before reordering.
cbuffer ConstantBuffer : register(b1)
{
    // Camera parameters. None of the camera* members are read by the shaders
    // visible in this file.
    float3 cameraPos;
    float cameraFov;
    float3 cameraDir;
    float pad2;
    float3 cameraUp;
    float pad3;
    float3 cameraRight;
    float pad4;
    // Centre of the area light; shadow samples are distributed around it.
    float3 lightPos;
    float pad5;
    // Radius used to scale the unit-disk sample offsets around lightPos.
    float lightRadiusSphere;
    // Not read by the shaders visible in this file.
    float time;
    float2 pad6;
    // Applied to (clip x, clip y, depth, 1) in worldPosFromDepth to recover a
    // world-space position from the depth buffer.
    float4x4 invViewProj;
};

// Acceleration structure that the shadow rays are traced against.
RaytracingAccelerationStructure scene : register(t0);
// Depth buffer, read once per dispatched pixel in RayGeneration.
Texture2D<float> depthTex : register(t1);
// Output target: rgb receives the shadow factor (or the sky colour where
// depth == 0), a receives the depth (or 10000.0 for sky pixels).
RWTexture2D<float4> uav : register(u0);

// Sky colours. Only skyBottom is referenced by the code visible in this file.
static const float3 skyTop = float3(0.24, 0.44, 0.72);
static const float3 skyBottom = float3(0.75, 0.86, 0.93);

// Builds two vectors T and B that are perpendicular to N and to each other.
//   N - direction to build the basis around (expected to be normalized).
//   T - out: normalized tangent, perpendicular to N.
//   B - out: cross(N, T); unit length when N is unit length.
void GetBasis(float3 N, out float3 T, out float3 B)
{
    // Pick a reference axis that is not (nearly) parallel to N, otherwise the
    // cross product below would collapse towards zero length.
    float3 reference;
    if (abs(N.z) < 0.999)
    {
        reference = float3(0, 0, 1);
    }
    else
    {
        reference = float3(1, 0, 0);
    }

    T = normalize(cross(reference, N));
    B = cross(N, T);
}

// Hashes a screen-space position into a noise value in [0, 1).
//   position_screen - pixel coordinates used as the hash input.
// The result depends only on the input position, so it is stable from call
// to call for the same pixel.
float InterleavedGradientNoise(float2 position_screen)
{
    const float2 weights = float2(0.06711056, 0.00583715);
    const float scale = 52.9829189;

    float inner = frac(dot(position_screen, weights));
    return frac(scale * inner);
}

// Reconstructs a position from a depth-buffer sample by un-projecting it with
// invViewProj.
//   depth - depth value, used directly as the clip-space z coordinate.
//   uv    - texture coordinates in [0, 1], origin at the top-left.
// Returns the un-projected position after the perspective divide.
float3 worldPosFromDepth(float depth, float2 uv)
{
    // Map uv to [-1, 1]; y is flipped because uv.y grows downwards.
    float2 ndcXY = float2(uv.x * 2.0 - 1.0, (uv.y * -2.0) + 1.0);
    float4 clipPos = float4(ndcXY, depth, 1.0);

    float4 unprojected = mul(invViewProj, clipPos);
    return unprojected.xyz / unprojected.w;
}

// Ray generation shader: computes a soft-shadow factor for every dispatched pixel.
//
// For each pixel the position is reconstructed from the depth buffer, then
// numSamples shadow rays are traced towards points on a disk of radius
// lightRadiusSphere centred on lightPos and facing the shaded point. The
// fraction of rays that reach the light (floored at 0.1) is written to
// uav.rgb and the depth to uav.a.
[shader("raygeneration")]
void RayGeneration()
{
    uint2 idx = DispatchRaysIndex().xy;
    float2 size = DispatchRaysDimensions().xy;

    // Texture coordinates of the pixel centre.
    float2 uv = (idx + 0.5) / size;

    float depth = depthTex[idx];
    if (depth == 0.0) {
      // A depth of exactly 0 is treated as "no geometry": write the sky colour.
      // NOTE(review): presumably 0 is the depth clear value — confirm against
      // the pass that produces depthTex.
      uav[idx] = float4(skyBottom, 10000.0);
      return;
    }

    float3 worldPos = worldPosFromDepth(depth, uv);

    float3 lightVec = lightPos - worldPos;
    float3 lightDir = normalize(lightVec);

    // Basis spanning the plane perpendicular to the direction towards the
    // light; the sample disk lies in this plane.
    float3 tangent, bitangent;
    GetBasis(lightDir, tangent, bitangent);

    const int numSamples = 2;
    // Offset applied at both ends of the shadow ray to avoid self-intersection.
    const float shadowBias = 0.05;
    const float twoPi = 6.28318530718;
    const float goldenAngle = 2.39996323;

    // Per-pixel rotation of the sample pattern so that neighbouring pixels use
    // different sample positions.
    float noiseRotation = twoPi * InterleavedGradientNoise(float2(idx));

    // Loop-invariant denominator of the radius formula.
    float sqrtSampleCount = sqrt(float(numSamples));

    float shadowAccumulator = 0.0;

    for (int i = 0; i < numSamples; i++)
    {
        // Golden-angle spiral on the unit disk: the radius grows with
        // sqrt(i + 0.5) and the angle advances by the golden angle per sample.
        float r = sqrt(float(i) + 0.5) / sqrtSampleCount;
        float theta = float(i) * goldenAngle + noiseRotation;

        float2 diskPoint = float2(r * cos(theta), r * sin(theta));

        float3 lightSamplePos = lightPos + (tangent * diskPoint.x * lightRadiusSphere) + (bitangent * diskPoint.y * lightRadiusSphere);

        float3 toSample = lightSamplePos - worldPos;
        float sampleDist = length(toSample);

        // If the sample is no further away than the bias there is no room for
        // an occluder, and the ray interval would be empty or negative
        // (TMax < TMin), which DXR does not allow. Count the sample as lit and
        // skip the trace. The negated comparison also catches a NaN distance.
        if (!(sampleDist > shadowBias))
        {
            shadowAccumulator += 1.0;
            continue;
        }

        float3 shadowDir = normalize(toSample);

        // The origin is pushed shadowBias along the ray, so the remaining
        // distance to the light sample is sampleDist - shadowBias.
        RayDesc shadowRay;
        shadowRay.Origin = worldPos + (shadowDir * shadowBias);
        shadowRay.Direction = shadowDir;
        shadowRay.TMin = 0.0;
        shadowRay.TMax = sampleDist - shadowBias;

        Payload shadowPayload;
        shadowPayload.missed = false;

        // Occlusion query only: any hit ends the search and no closest-hit
        // shader runs, so the payload stays false unless the Miss shader fires.
        TraceRay(
            scene,
            RAY_FLAG_FORCE_OPAQUE | RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_SKIP_CLOSEST_HIT_SHADER,
            0xFF, 0, 0, 0, shadowRay, shadowPayload
        );

        if (shadowPayload.missed)
        {
            shadowAccumulator += 1.0;
        }
    }

    // Fraction of unoccluded samples, floored so fully shadowed pixels are not
    // completely black.
    float shadowFactor = shadowAccumulator / float(numSamples);
    shadowFactor = max(shadowFactor, 0.1);
    uav[idx] = float4(shadowFactor, shadowFactor, shadowFactor, depth);
}

// Miss shader: runs when a traced ray hits nothing within its [TMin, TMax]
// interval. Flags the payload so RayGeneration counts the sample as lit.
[shader("miss")]
void Miss(inout Payload payload)
{
    payload.missed = true;
}