#version 460
#extension GL_EXT_buffer_reference : require    // for raw addresses
#extension GL_ARB_gpu_shader_int64 : require    // 64‑bit ints
#extension GL_EXT_scalar_block_layout : require

// Interpolated per-fragment inputs; locations must match the previous stage's outputs.
layout(location = 0) in vec3 fragPos;      // shaded position; same space as Light.pos and ubo.viewPos
layout(location = 1) in vec3 fragColor;    // per-vertex colour, multiplied into the base colour
layout(location = 2) in vec3 fragNormal;   // not unit length after interpolation; main() normalizes it
layout(location = 3) in vec2 fragTexCoord; // UV used for every material texture lookup

// Per-frame uniforms. The `scalar` layout (GL_EXT_scalar_block_layout) packs
// members without std140 padding: a vec3 is 4-byte aligned, so reflectionOpts
// starts right after viewPos (offset 76) and gamma at offset 92. The host-side
// struct must use the same tight packing.
layout(scalar, binding = 0) uniform UniformBuffer {
    mat4 viewProj;        // not read in this shader; presumably consumed by the vertex stage
    vec3 viewPos;         // camera position, used for the specular view vector
    vec4 reflectionOpts;  // debug flags: x == 1.0 disables the reflection blend, y == 1.0 outputs reflections only
    float gamma;          // clamped to [0, 2] by main() and used as a linear brightness scale, not a true gamma curve
} ubo;

// Per-material parameters, indexed by pc.materialIndex in main().
// The buffer uses the default std430 layout. The struct is 16-byte aligned (vec4),
// and the three _pad fields make its size exactly 48 bytes, so the host-side
// array stride must be 48.
struct Material {
    vec4 baseColorFactor;  // multiplied into the base colour
    float metallicFactor;  // not read in this shader
    float roughnessFactor; // converted to a specular exponent in main()
    int baseColorTex;      // index into textures[]; negative means "no texture"
    int metallicRoughTex;  // index into textures[]; negative means "no texture"; only .g (roughness) is sampled
    int normalTex;         // not read in this shader
    uint _pad0;            // explicit padding to a 48-byte stride
    uint _pad1;
    uint _pad2;
};
layout(set = 0, binding = 1) readonly buffer Materials {
    Material materials[];
};

// Array of combined image samplers indexed by the Material *Tex fields.
// The array size is specialization constant 0 (default 128), so the pipeline
// may override it. Indexing at or beyond MAX_TEXTURES is out of bounds.
layout(constant_id = 0) const uint MAX_TEXTURES = 128;
layout(set = 0, binding = 2) uniform sampler2D textures[MAX_TEXTURES];

// Light record iterated by lighting(). std430 layout: color at offset 0,
// pos at 16 (vec3 is 16-byte aligned), intensity packed at 28, giving a 32-byte stride.
struct Light {
    vec4 color;      // only .xyz is used
    vec3 pos;        // light position, same space as fragPos
    float intensity; // scaled by an inverse-square distance falloff in lighting()
};
layout(set = 0, binding = 3) readonly buffer Lights {
    Light lights[];
};

// Per-draw push constants. model and normal are not read in this shader;
// presumably the block mirrors the vertex stage's declaration (confirm).
// materialIndex selects an entry of materials[]; a negative value selects the
// untextured vertex-colour path in main().
layout(push_constant) uniform PushConstants {
    mat4 model;
    mat4 normal;
    int materialIndex;
} pc;

// Reflection image, read once per pixel at gl_FragCoord (presumably written by a
// ray-tracing pass, per the name — confirm). Its alpha is the blend weight
// between the lit surface colour and the reflection colour.
layout(set = 0, binding = 4, rgba8) uniform readonly image2D rtReflections;

layout(location = 0) out vec4 outColor;


// Accumulates Phong-style lighting from every entry in the Lights buffer.
//   n         - surface normal; assumed unit length (main() passes normalize(fragNormal)).
//   pos       - shaded position, in the same space as Light.pos and ubo.viewPos.
//   shininess - specular exponent passed to pow().
// Each light contributes (ambient + diffuse + specular) scaled by
// l.intensity / max(dist^2, 0.001). Returns the summed colour with alpha 1.0,
// so multiplying a base colour by the result keeps the base colour's alpha.
vec4 lighting(vec3 n, vec3 pos, float shininess) {
    // Fixed reflectance coefficients shared by all lights.
    const float kAmbient  = 0.1;
    const float kDiffuse  = 0.7;
    const float kSpecular = 0.3;

    // The view vector does not depend on the light, so compute it once.
    vec3 viewDir = normalize(ubo.viewPos - pos);

    vec3 color = vec3(0.0);
    for (int i = 0; i < lights.length(); i++) {
        Light l = lights[i];
        vec3 lightColor = l.color.xyz;

        vec3 toLight = l.pos - pos;
        float dist = length(toLight);
        // Inverse-square falloff; the floor keeps very close lights finite.
        float intensity = l.intensity / max(dist * dist, 0.001);
        // A light exactly at `pos` would give toLight / dist = 0/0 = NaN;
        // treat its direction as zero instead so it contributes only ambient.
        vec3 lightDir = dist > 0.0 ? toLight / dist : vec3(0.0);

        vec3 ambient = lightColor * kAmbient;
        vec3 diffuse = lightColor * kDiffuse * max(dot(n, lightDir), 0.0);

        // Named reflectDir so the built-in reflect() is not shadowed.
        vec3 reflectDir = reflect(-lightDir, n);
        float spec = pow(max(dot(reflectDir, viewDir), 0.0), shininess);
        vec3 specular = lightColor * kSpecular * spec;

        color += (ambient + diffuse + specular) * intensity;
    }

    return vec4(color, 1.0);
}

// True when `index` may be used to index textures[]. Material texture slots
// use a negative value for "no texture", and an index >= MAX_TEXTURES would
// read past the end of the sampler array (undefined behaviour).
bool isValidTextureIndex(int index) {
    return index >= 0 && index < int(MAX_TEXTURES);
}

// Fragment entry point. It shades the surface with lighting(), blends in the
// per-pixel reflection image, and applies the ubo.gamma brightness scale.
//
// Debug switches in ubo.reflectionOpts:
//   y == 1.0 -> output the reflection image unmodified (takes precedence)
//   x == 1.0 -> disable the reflection blend
//
// Material selection via pc.materialIndex:
//   negative or >= materials.length() -> vertex colour, shininess 32
//   otherwise -> vertex colour * base-colour texture (if any) * baseColorFactor,
//                with the specular exponent derived from roughness
void main() {
    vec4 reflections = imageLoad(rtReflections, ivec2(gl_FragCoord.xy));

    // Debug view: reflections only.
    if (ubo.reflectionOpts.y == 1.0) {
        outColor = reflections;
        return;
    }

    // The reflection image's alpha is the blend weight; x == 1.0 turns it off.
    float reflectionMix = (ubo.reflectionOpts.x == 1.0) ? 0.0 : reflections.a;

    vec3 normal = normalize(fragNormal);
    vec4 color = vec4(fragColor, 1.0);
    float shininess = 32.0;

    // pc.materialIndex is a push constant, so this branch and the texture
    // lookups inside it are dynamically uniform (implicit-LOD texture() is valid).
    if (pc.materialIndex >= 0 && pc.materialIndex < materials.length()) {
        Material m = materials[pc.materialIndex];

        vec4 texColor = isValidTextureIndex(m.baseColorTex)
            ? texture(textures[m.baseColorTex], fragTexCoord)
            : vec4(1.0);
        color = color * texColor * m.baseColorFactor;

        // Roughness lives in the green channel of the metallic-roughness
        // texture (glTF convention); metallic is not used here.
        float roughness = m.roughnessFactor;
        if (isValidTextureIndex(m.metallicRoughTex)) {
            roughness *= texture(textures[m.metallicRoughTex], fragTexCoord).g;
        }
        // Clamp on every path: roughness 0 would make 2 / r^2 infinite, and an
        // infinite exponent can turn pow() in lighting() into NaN.
        roughness = clamp(roughness, 0.001, 1.0);

        // Roughness -> Phong exponent.
        shininess = max(2.0 / (roughness * roughness) - 2.0, 1.0);
    }

    outColor = mix(color * lighting(normal, fragPos, shininess), reflections, reflectionMix);

    // Linear brightness boost by (1 + gamma), not a true gamma curve.
    // Applied on every shading path, and to rgb only so alpha is not scaled.
    outColor.rgb *= 1.0 + clamp(ubo.gamma, 0.0, 2.0);
}
