#version 460
#pragma shader_stage(raygen)
#extension GL_EXT_ray_tracing : require

// Top-level acceleration structure that every traceRayEXT call in main() traverses.
layout(set = 0, binding = 0) uniform accelerationStructureEXT tlas;
// Output colour target (rgba8); main() writes exactly one texel per launch cell
// with imageStore and never reads it back, hence writeonly.
layout(set = 0, binding = 1, rgba8) uniform writeonly image2D outImage;

// Per-dispatch camera matrices supplied via push constants (2 x mat4 = 128 bytes,
// which is exactly Vulkan's minimum guaranteed maxPushConstantsSize).
layout(push_constant) uniform PushConstants {
    mat4 viewInv;   // inverse view matrix: view space -> world space (eye position and ray directions)
    mat4 projInv;   // inverse projection matrix: NDC -> view space (used to build primary ray directions)
} pc;

// Ray payload exchanged with the hit/miss shaders via payload location 0.
// The hit/miss stages must declare a rayPayloadInEXT of an identical layout.
// main() reads rayOrigin/rayDirection back after every trace to launch the next
// bounce. It stops bouncing when rayDirection's length is <= 0.1 or when
// reflectionFactor is <= 0.
// NOTE(review): the meaning of color/reflectionFactor updates is defined by the
// closest-hit/miss shaders, which are not visible here.
struct Payload {
    vec3 rayOrigin;         // origin of the next ray to trace
    vec3 rayDirection;      // direction of the next ray; a near-zero length acts as a stop signal
    vec3 color;             // colour written to outImage once bouncing ends
    float reflectionFactor; // starts at 1.0 in main(); a value <= 0 stops bouncing
};
layout(location = 0) rayPayloadEXT Payload payload;

// Upper bound on traceRayEXT calls per pixel (primary ray included).
const int MAX_BOUNCES = 5;

// Ray-generation entry point.
//
// For each launch cell (one output pixel) it builds a primary camera ray through
// the pixel centre. It then calls traceRayEXT repeatedly, feeding the payload's
// rayOrigin/rayDirection back in as the next ray. The hit/miss shaders (not
// visible here) must update those fields, accumulate payload.color and adjust
// payload.reflectionFactor.
//
// The loop ends after MAX_BOUNCES traces. It ends earlier if a shader signals
// termination by shrinking payload.rayDirection to length <= 0.1 or by driving
// payload.reflectionFactor to <= 0.
//
// Output: vec4(payload.color, 1) at gl_LaunchIDEXT.xy in outImage. If at most one
// trace was issued, transparent black is written instead (see below).
void main() {
    uvec2 pixel = gl_LaunchIDEXT.xy;
    uvec2 dim = gl_LaunchSizeEXT.xy;

    // Aim through the pixel centre. The raw integer index would target the
    // pixel's corner instead. That shifts the image by half a pixel and makes
    // the NDC range asymmetric ([-1, 1 - 2/dim] instead of centred).
    vec2 pixelCenter = vec2(pixel) + vec2(0.5);
    vec2 ndc = pixelCenter / vec2(dim) * 2.0 - 1.0;

    // Unproject onto the z = -1 NDC plane. Only the direction from the eye is
    // needed, so xyz is normalized without a perspective divide.
    // NOTE(review): this assumes projInv yields w > 0 for this plane (true for a
    // standard perspective projection). A negative w would flip the ray.
    vec4 viewCoords = pc.projInv * vec4(ndc, -1.0, 1.0);
    // Eye position: the view-space origin transformed to world space.
    vec3 rayOrigin = (pc.viewInv * vec4(0.0, 0.0, 0.0, 1.0)).xyz;
    // w = 0 applies only the linear part of viewInv (no translation) to the direction.
    vec3 rayDirection = (pc.viewInv * vec4(normalize(viewCoords.xyz), 0.0)).xyz;

    payload.rayOrigin = rayOrigin;
    payload.rayDirection = rayDirection;
    payload.color = vec3(0.0);
    payload.reflectionFactor = 1.0;
    int level = 0;

    while (level < MAX_BOUNCES && length(payload.rayDirection) > 0.1 && payload.reflectionFactor > 0.0) {
        traceRayEXT(
            tlas,
            gl_RayFlagsOpaqueEXT,   // treat all geometry as opaque (no any-hit invocations)
            0xFFu,                  // cull mask: accept every instance
            0,                      // sbtRecordOffset
            0,                      // sbtRecordStride: 0 => every geometry uses the same hit group record
            0,                      // missIndex
            payload.rayOrigin,
            0.001,                  // tMin: small offset, presumably to avoid re-hitting the surface a bounce starts on
            payload.rayDirection,
            1e30,                   // tMax: effectively unbounded
            0                       // payload location (matches `payload`)
        );
        level++;
    }

    vec4 finalColor = vec4(payload.color, 1.0);
    // level <= 1 means at most the primary ray was traced. With a rigid viewInv the
    // first iteration always runs, so level == 0 only for a degenerate camera ray.
    // Such pixels are written as transparent black, discarding whatever the
    // primary trace stored in payload.color.
    // NOTE(review): presumably intentional (e.g. a reflections-only pass
    // composited over a rasterized image) — confirm with the host side.
    if (level <= 1) {
        finalColor = vec4(0.0);
    }
    imageStore(outImage, ivec2(pixel), finalColor);
}

