#version 330 core
out vec4 FragColor;

in vec3 color;
in vec3 Normal;
in vec3 currentPosition;
in vec2 textCoordinates;
in vec4 FragPosLightSpace;

// ===== TEXTURE SAMPLERS =====
uniform sampler2D diffuse0;
uniform sampler2D shadowMap;

// ===== LIGHT UNIFORMS =====
uniform vec4 lightColor;
uniform vec3 camPos;
uniform vec3 lightPos;
uniform vec3 pointLightPos;
uniform bool bothLight;


// ===== Shadow map ======

// Returns how much of the fragment is in shadow, in [0, 1]
// (0.0 = fully lit, 1.0 = fully shadowed), using percentage-closer filtering
// over a (2 * PCF_RADIUS + 1) x (2 * PCF_RADIUS + 1) texel neighbourhood of shadowMap.
//   fragPosLightSpace : fragment position in the light's clip space (divided by w below).
//   normal            : normalized surface normal.
//   lightDir          : normalized direction from the fragment TOWARD the light; only
//                       used to scale the depth bias, which then grows at grazing angles.
// Fragments whose projected xy or depth falls outside [0, 1] are treated as lit.
float ShadowCalculation(vec4 fragPosLightSpace, vec3 normal, vec3 lightDir){
    // PCF kernel half-width: 2 gives a 5x5 kernel (25 samples).
    const int PCF_RADIUS = 2;
    const float PCF_SAMPLE_COUNT = float((2 * PCF_RADIUS + 1) * (2 * PCF_RADIUS + 1));

    // Perspective divide, then remap NDC [-1, 1] to texture/depth space [0, 1].
    vec3 projCoords = fragPosLightSpace.xyz / fragPosLightSpace.w;
    projCoords = projCoords * 0.5 + 0.5;

    // Outside the light's frustum (or beyond its far plane): no shadow information.
    if (any(lessThan(projCoords, vec3(0.0))) || any(greaterThan(projCoords, vec3(1.0))))
        return 0.0;

    // Slope-scaled bias against shadow acne, never below 0.0005.
    float bias = max(0.003 * (1.0 - dot(normal, lightDir)), 0.0005);
    float biasedDepth = projCoords.z - bias;

    vec2 texelSize = 1.0 / vec2(textureSize(shadowMap, 0));

    float shadow = 0.0;
    for (int x = -PCF_RADIUS; x <= PCF_RADIUS; ++x)
    {
        for (int y = -PCF_RADIUS; y <= PCF_RADIUS; ++y)
        {
            float depth = texture(shadowMap, projCoords.xy + vec2(x, y) * texelSize).r;
            shadow += biasedDepth > depth ? 1.0 : 0.0;
        }
    }
    return shadow / PCF_SAMPLE_COUNT; // average over all kernel samples
}

// ===== LIGHT FUNCTIONS =====
// Diffuse + weak Phong specular contribution of the directional light,
// attenuated by the PCF shadow term.
//   baseColor         : surface colour sampled from diffuse0 by the caller.
//   fragPosLightSpace : fragment position in the light's clip space.
// lightPos is used as a direction toward the light, not as a position: only
// normalize(lightPos) is taken, with no dependence on currentPosition.
vec3 computeDirectionalLight(vec3 baseColor, vec4 fragPosLightSpace)
{
    vec3 normal = normalize(Normal);
    vec3 lightDir = normalize(lightPos); // fragment -> light

    float diff = max(dot(normal, lightDir), 0.0);

    // Reduced specular for palette textures to preserve colors.
    // Gated on diff so that faces turned away from the light get no highlight.
    vec3 viewDir = normalize(camPos - currentPosition);
    vec3 reflectDir = reflect(-lightDir, normal);
    float specAmount = diff > 0.0 ? pow(max(dot(viewDir, reflectDir), 0.0), 32.0) : 0.0;
    float spec = specAmount * 0.1;

    // ShadowCalculation expects the toward-the-light direction so that its
    // slope bias 0.003 * (1 - dot(normal, lightDir)) grows at grazing angles.
    // Passing -lightDir would invert it (largest bias on light-facing surfaces).
    float shadow = ShadowCalculation(fragPosLightSpace, normal, lightDir);

    float lighting = (1.0 - shadow) * (diff + spec);

    return baseColor * lighting * lightColor.rgb;
}

// Diffuse + faint Phong specular contribution of the point light at
// pointLightPos, attenuated by 1 / (3 d^2 + 0.7 d + 1), where d is the
// distance to the fragment. This light is not shadowed and shares the
// lightColor uniform with the directional light.
//   baseColor : surface colour sampled from diffuse0 by the caller.
vec3 computePointLight(vec3 baseColor)
{
    vec3 lightVec = pointLightPos - currentPosition;
    float dist = length(lightVec);

    float intensity = 1.0 / (3.0 * dist * dist + 0.7 * dist + 1.0);

    vec3 normal = normalize(Normal);
    vec3 lightDir = normalize(lightVec);

    float diff = max(dot(normal, lightDir), 0.0);

    // Untinted highlight. Gated on diff because nothing shadows this light:
    // without the gate, Phong can produce a highlight on faces turned away from it.
    vec3 viewDir = normalize(camPos - currentPosition);
    vec3 reflectDir = reflect(-lightDir, normal);
    float spec = diff > 0.0 ? pow(max(dot(viewDir, reflectDir), 0.0), 16.0) : 0.0;

    return (baseColor * diff + spec * 0.05) * intensity * lightColor.rgb;
}

// Fragment entry point. Final colour = 20% ambient of the diffuse0 texel
// + the shadowed directional light
// + (only when bothLight is set) the point light scaled by 20.
void main()
{
    // Base colour from the palette texture.
    vec3 albedo = texture(diffuse0, textCoordinates).rgb;

    // Direct lighting: the directional light always contributes; the point
    // light is added on top only when both lights are enabled.
    vec3 direct = computeDirectionalLight(albedo, FragPosLightSpace);
    if (bothLight)
        direct += computePointLight(albedo) * 20.0;

    FragColor = vec4(albedo * 0.20 + direct, 1.0);
}
