#version 460
#pragma shader_stage(closest)
#extension GL_EXT_debug_printf : enable
#extension GL_EXT_ray_tracing : require
#extension GL_EXT_buffer_reference : require
#extension GL_ARB_gpu_shader_int64 : require
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_nonuniform_qualifier : require

// One mesh vertex as it is read through a VertexBuffer reference.
// With the `scalar` layout used by VertexBuffer the members are tightly packed:
// 3 * 12 (vec3) + 8 (vec2) = 44 bytes, and `_pad` brings the element to 48 bytes.
// NOTE(review): the host-side vertex struct must use the same member order and size - confirm.
struct Vertex {
    vec3 pos;    // position; getVertex() transforms the interpolated value with the object-to-world matrix
    vec3 color;  // per-vertex colour, multiplied into the base colour in main()
    vec3 normal; // normal; getVertex() transforms the interpolated value with Mesh.normalMat
    vec2 uv; // only 44 bytes, need 4 more for alignment
    uint _pad;   // explicit padding, never read by this shader
};
// Buffer reference (device address) to a runtime-sized array of vertices.
// main() builds it from Mesh.vertexBufferAddr.
layout(buffer_reference, scalar, buffer_reference_align = 16) readonly buffer VertexBuffer {
    Vertex data[];
};

// Buffer reference (device address) to a runtime-sized array of 32-bit indices.
// main() builds it from Mesh.indexBufferAddr; getVertex() reads three consecutive
// entries per triangle and uses them to index VertexBuffer.data.
layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer IndexBuffer {
    uint data[];
};


// Meshes buffer
// Per-mesh record; main() selects the entry with gl_InstanceCustomIndexEXT.
struct Mesh {
    mat4 normalMat;            // its upper-left 3x3 is applied to the interpolated normal in getVertex()
    uint64_t vertexBufferAddr; // address converted to a VertexBuffer reference in main()
    uint64_t indexBufferAddr;  // address converted to an IndexBuffer reference in main()
    int materialIndex; // only 88 bytes, needs padding to 96 bytes (mod 16)
                       // index into `materials`; a negative value means "no material" (see main())
    uint64_t _pad;             // explicit padding up to 96 bytes, never read by this shader
};
// Storage buffer holding one Mesh record per mesh (std430, set 0 / binding 2).
layout(std430, set = 0, binding = 2) readonly buffer Meshes {
    Mesh meshes[];
};

// Materials buffer
// Material parameters, looked up through Mesh.materialIndex.
// The texture members are indices into `textures`; main() treats a negative index as
// "no texture bound" for baseColorTex and metallicRoughTex.
struct Material {
    vec4 baseColorFactor;  // multiplied with the vertex colour and the base colour texel
    float metallicFactor;  // scaled by the metallic/roughness texel when that texture is present
    float roughnessFactor; // scaled by the metallic/roughness texel when that texture is present
    int baseColorTex;
    int metallicRoughTex;
    int normalTex;         // not read anywhere in this shader
    uint _pad0;            // _pad0.._pad2: explicit padding (16 + 5 * 4 + 3 * 4 = 48 bytes in total)
    uint _pad1;
    uint _pad2;
};
// Storage buffer with all materials (set 0 / binding 3).
// NOTE(review): no explicit layout qualifier here, unlike `Meshes`; Vulkan GLSL defaults
// buffer blocks to std430 - confirm the host writes this buffer with that packing.
layout(set = 0, binding = 3) readonly buffer Materials {
    Material materials[];
};

// Size of the texture array; specialization constant 0, defaulting to 128 when the
// pipeline does not override it.
layout(constant_id = 0) const uint MAX_TEXTURES = 128;
// Combined image samplers addressed by the texture indices stored in Material
// (set 0 / binding 4).
layout(set = 0, binding = 4) uniform sampler2D textures[MAX_TEXTURES];

// One light as consumed by the lighting loop in main().
struct Light {
    vec4 color;      // only .xyz is read
    vec3 pos;        // position the light direction and distance are measured to
    float intensity; // divided by the squared distance to the shaded point
};
// Storage buffer with all lights (set 0 / binding 6); main() iterates over lights.length() entries.
// NOTE(review): no explicit layout qualifier; Vulkan GLSL defaults buffer blocks to std430 -
// confirm the host writes this buffer with that packing.
layout(set = 0, binding = 6) readonly buffer Lights {
    Light lights[];
};

// Push-constant block. Neither member is read in this shader.
// NOTE(review): the names suggest inverse view / projection matrices consumed by another
// stage of the pipeline (presumably ray generation) - confirm.
layout(push_constant) uniform PushConstants {
    mat4 viewInv;
    mat4 projInv;
} pc;

// Ray payload exchanged with the shader that traced the ray.
// NOTE(review): the declaration has to match the payload declared by the other stages
// of the pipeline - confirm when changing it.
struct Payload {
    vec3 rayOrigin;         // read as the viewer position for the specular term; main() overwrites it
                            // with the offset hit position
    vec3 rayDirection;      // main() overwrites it with the incoming direction reflected about the normal
    vec3 color;             // colour output written by this shader
    float reflectionFactor; // weight applied to this hit's lighting; main() then scales it
                            // by metallic * (1 - roughness)
};
// Incoming payload of the ray that produced this hit (payload location 0).
layout(location = 0) rayPayloadInEXT Payload payload;

// Hit attribute of the intersected triangle. getVertex() uses .x and .y as the weights of
// the triangle's second and third vertex; the first vertex gets 1 - x - y.
hitAttributeEXT vec2 baryCoord;

// Returns the vertex attributes at the current hit point, interpolated across the
// triangle whose three indices start at `baseIndex` in `indices`.
//
//   baseIndex - offset of the triangle's first index inside indices.data
//   normalMat - matrix whose upper-left 3x3 is applied to the interpolated normal
//   vertices  - buffer reference the indices point into
//   indices   - buffer reference holding the triangle indices
//
// In the result `pos` is transformed by the object-to-world matrix of the hit instance,
// `normal` is transformed by `normalMat` and normalised, and `color` / `uv` are plain
// barycentric blends of the three vertices.
Vertex getVertex(int baseIndex, mat4 normalMat, VertexBuffer vertices, IndexBuffer indices) {
    uvec3 tri = uvec3(
        indices.data[baseIndex + 0],
        indices.data[baseIndex + 1],
        indices.data[baseIndex + 2]
    );

    Vertex v0 = vertices.data[tri.x];
    Vertex v1 = vertices.data[tri.y];
    Vertex v2 = vertices.data[tri.z];

    // The hit attribute stores the weights of the second and third vertex;
    // the weight of the first one is whatever remains.
    vec3 barys = vec3(1.0f - baryCoord.x - baryCoord.y, baryCoord.x, baryCoord.y);

    Vertex res;
    res.color = v0.color * barys.x + v1.color * barys.y + v2.color * barys.z;
    res.uv = v0.uv * barys.x + v1.uv * barys.y + v2.uv * barys.z;
    res._pad = 0u; // keep the returned struct fully initialised

    // Normalising once after the matrix multiply is sufficient: the transform is linear,
    // so the length of its input does not affect the direction of the result.
    vec3 objNormal = v0.normal * barys.x + v1.normal * barys.y + v2.normal * barys.z;
    res.normal = normalize(mat3(normalMat) * objNormal);

    // gl_ObjectToWorldEXT is a mat4x3, so multiplying it by the homogeneous position applies
    // both the linear part and the translation of the instance transform.
    // gl_ObjectToWorld3x4EXT is its transpose: `gl_ObjectToWorld3x4EXT * vec3` would apply
    // the transposed linear part and lose the translation entirely.
    vec3 objPos = v0.pos * barys.x + v1.pos * barys.y + v2.pos * barys.z;
    res.pos = gl_ObjectToWorldEXT * vec4(objPos, 1.0);

    return res;
}

// Closest-hit entry point.
//
// Looks up the mesh record for the hit instance, interpolates the hit vertex, evaluates
// a Phong-style lighting sum over every entry of `lights`, and then updates the payload:
//   - payload.color            += lighting weighted by the incoming reflectionFactor
//   - payload.reflectionFactor *= metallic * (1 - roughness)
//   - payload.rayOrigin / payload.rayDirection describe the mirrored follow-up ray
void main() {
    uint meshId = gl_InstanceCustomIndexEXT;
    Mesh mesh = meshes[meshId];

    VertexBuffer vertices = VertexBuffer(mesh.vertexBufferAddr);
    IndexBuffer indices = IndexBuffer(mesh.indexBufferAddr);

    Vertex v = getVertex(3 * gl_PrimitiveID, mesh.normalMat, vertices, indices);

    if (mesh.materialIndex < 0) {
        // NOTE: Technically the lighting is missing here, but there should be no mesh without a material
        // Accumulate like the regular path does so that colour gathered by earlier bounces is
        // not overwritten, and zero the weight so nothing further is added for this path.
        payload.color += v.color * payload.reflectionFactor;
        payload.reflectionFactor = 0.0;
        return;
    }

    Material m = materials[mesh.materialIndex];

    // The texture index depends on which mesh this particular ray hit, so it is not
    // dynamically uniform and has to be marked with nonuniformEXT.
    // NOTE(review): this needs shaderSampledImageArrayNonUniformIndexing enabled on the
    // device - confirm on the host side.
    vec4 texColor = vec4(1.0);
    if (m.baseColorTex >= 0) {
        texColor = texture(textures[nonuniformEXT(m.baseColorTex)], v.uv);
    }
    vec4 color = vec4(v.color, 1.0) * texColor * m.baseColorFactor;

    float metallic = m.metallicFactor;
    float roughness = m.roughnessFactor;
    if (m.metallicRoughTex >= 0) {
        // NOTE(review): metallic is taken from R and roughness from G. glTF packs roughness
        // in G and metallic in B - confirm which packing the texture loader produces.
        vec2 metallicRoughness = texture(textures[nonuniformEXT(m.metallicRoughTex)], v.uv).rg;
        metallic *= metallicRoughness.r;
        roughness *= metallicRoughness.g;
    }
    // Clamp on every path: a roughness of 0 would otherwise divide by zero just below.
    metallic = clamp(metallic, 0.0, 1.0);
    roughness = clamp(roughness, 0.001, 1.0);
    float shininess = max(2.0 / (roughness * roughness) - 2.0, 1.0);

    // NOTE: This is also hardcoded, fix later
    // Ambient (x), diffuse (y) and specular (z) weights.
    const vec3 phongWeights = vec3(0.1, 0.7, 0.3);

    // The viewer direction does not depend on the light, so compute it once.
    vec3 viewDir = normalize(payload.rayOrigin - v.pos);

    vec3 lighting = vec3(0.0);
    for (int i = 0; i < lights.length(); i++) {
        Light l = lights[i];
        vec3 lightColor = l.color.xyz;

        // Inverse-square falloff. The squared distance is floored so that a light located
        // exactly at the hit point yields neither an infinite intensity nor a NaN direction.
        vec3 toLight = l.pos - v.pos;
        float distSq = max(dot(toLight, toLight), 0.001);
        float intensity = l.intensity / distSq;
        vec3 lightDir = toLight * inversesqrt(distSq);

        vec3 ambient = lightColor * phongWeights.x;

        vec3 diffuse = lightColor * phongWeights.y * max(dot(v.normal, lightDir), 0.0);

        vec3 reflectDir = reflect(-lightDir, v.normal);
        float spec = pow(max(dot(reflectDir, viewDir), 0.0), shininess);
        vec3 specular = lightColor * phongWeights.z * spec;

        lighting += color.xyz * ((ambient + diffuse + specular) * intensity);
    }

    // Weight this hit by what earlier bounces left over, then attenuate the weight
    // for the next bounce.
    payload.color += lighting * payload.reflectionFactor;
    payload.reflectionFactor *= metallic * (1.0 - roughness);

    // Follow-up ray: start slightly off the surface along the normal and mirror the
    // incoming direction about it.
    vec3 hitPos = gl_WorldRayOriginEXT + gl_HitTEXT * gl_WorldRayDirectionEXT;
    payload.rayOrigin = hitPos + (v.normal * 0.01);
    payload.rayDirection = reflect(gl_WorldRayDirectionEXT, v.normal);
}

