#version 330 core

// ---- Per-vertex inputs ----
// Vertex position; main() displaces it with displaceVertex() before transforming it.
layout(location = 0) in vec3 aPos;
// Vertex color; forwarded unchanged as vColor.
layout(location = 1) in vec3 aColor;
// NOTE(review): aNormal is not read anywhere in the visible code; vNormal is
// derived from the displaced surface by computeNormal() instead.
layout(location = 2) in vec3 aNormal;
// Texture coordinate; forwarded unchanged as vTex.
layout(location = 3) in vec2 aTex;

// ---- Outputs to the next stage ----
// Copy of aColor.
out vec3 vColor;
// Finite-difference surface normal multiplied by the inverse-transpose of `model`.
out vec3 vNormal;
// Displaced vertex position after the `model` transform.
out vec3 vFragPos;
// Copy of aTex.
out vec2 vTex;
// getDistanceToIsland() evaluated at the undisplaced aPos.xz.
out float vDistanceToIsland;
// NOTE(review): despite the name, main() writes the raw aPos here
// (no model transform, no wave displacement).
out vec3 vWorldPos;

// Transform applied first to the displaced vertex.
uniform mat4 model;
// Applied after `model` to produce gl_Position; presumably a combined
// view-projection matrix — confirm against the host code.
uniform mat4 camMatrix;
// Animation time driving wave phases and noise offsets; units are not visible
// here (presumably seconds — confirm against the host code).
uniform float uTime;

// Used by wave() to convert a wavelength into an angular wave number (2*PI / wavelength).
const float PI = 3.14159265359;

// Simplex noise functions for more natural wave patterns
// Component-wise x modulo 289, built from floor() and a reciprocal multiply
// rather than the built-in mod().
vec3 mod289(vec3 x)
{
    vec3 wholePeriods = floor(x * (1.0 / 289.0));
    return x - wholePeriods * 289.0;
}
// Two-component overload: component-wise x modulo 289 via floor() and a
// reciprocal multiply.
vec2 mod289(vec2 x)
{
    vec2 wholePeriods = floor(x * (1.0 / 289.0));
    return x - wholePeriods * 289.0;
}
// Hash step used by snoise(): evaluates (34*x + 1) * x per component and wraps
// the result with mod289().
vec3 permute(vec3 x)
{
    vec3 poly = ((x * 34.0) + 1.0) * x;
    return mod289(poly);
}

// 2D simplex-style gradient noise.
//   v       : sample position.
//   returns : a signed scalar, the weighted sum of three corner contributions
//             scaled by 130.0 (presumably to bring the result into roughly
//             [-1, 1] — not verifiable from this file).
// The statement order below is numerically delicate; do not reorder.
float snoise(vec2 v) {
    const vec4 C = vec4(0.211324865405187,  // (3.0-sqrt(3.0))/6.0
                        0.366025403784439,  // 0.5*(sqrt(3.0)-1.0)
                        -0.577350269189626, // -1.0 + 2.0 * C.x
                        0.024390243902439); // 1.0 / 41.0
    // Skew the input by C.y and floor it to get the integer cell origin i.
    vec2 i  = floor(v + dot(v, C.yy));
    // Offset from that first corner back in unskewed space (C.x undoes the skew).
    vec2 x0 = v -   i + dot(i, C.xx);
    // Second corner is one step along +x or +y, whichever offset component is larger.
    vec2 i1;
    i1 = (x0.x > x0.y) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
    // x12.xy = offset to the second corner, x12.zw = offset to the third corner (i + (1,1)).
    vec4 x12 = x0.xyxy + C.xxzz;
    x12.xy -= i1;
    // Wrap the cell index before hashing so the permute polynomial stays in a small range.
    i = mod289(i);
    // One hash value per corner, built from the corner's y index and then its x index.
    vec3 p = permute(permute(i.y + vec3(0.0, i1.y, 1.0))
                   + i.x + vec3(0.0, i1.x, 1.0));
    // Per-corner falloff: max(0.5 - squared distance, 0), raised to the 4th power below.
    vec3 m = max(0.5 - vec3(dot(x0,x0), dot(x12.xy,x12.xy), dot(x12.zw,x12.zw)), 0.0);
    m = m*m;
    m = m*m;
    // Derive a per-corner gradient (a0, h) from the hash; C.w = 1/41 spreads the hash over [-1, 1).
    vec3 x = 2.0 * fract(p * C.www) - 1.0;
    vec3 h = abs(x) - 0.5;
    vec3 ox = floor(x + 0.5);
    vec3 a0 = x - ox;
    // Scale the falloff by a linear function of the squared gradient length
    // (appears to be an inexpensive stand-in for normalizing the gradients — TODO confirm).
    m *= 1.79284291400159 - 0.85373472095314 * (a0*a0 + h*h);
    // Dot each corner's gradient with the offset to that corner.
    vec3 g;
    g.x  = a0.x  * x0.x  + h.x  * x0.y;
    g.yz = a0.yz * x12.xz + h.yz * x12.yw;
    return 130.0 * dot(m, g);
}

// Distance from `pos` (an XZ-plane point) to the closest of three hard-coded
// island centers. Only the centers are considered; island size/shape is ignored.
// TODO: keep these centers in sync with the real island placement in the scene.
// To add an island, append one more min(...) line below.
float getDistanceToIsland(vec2 pos) {
    float nearest = length(pos - vec2(0.0, -15.0));           // island 1
    nearest = min(nearest, length(pos - vec2(-35.0, -30.0))); // island 2
    nearest = min(nearest, length(pos - vec2(25.0, -25.0)));  // island 3
    return nearest;
}

// One directional sine wave whose phase is perturbed by simplex noise.
//   p              : XZ sample position.
//   direction      : propagation axis; normalized here, so only its direction matters.
//   wavelength     : distance between crests, in the same units as p.
//   speed          : multiplier on uTime in the phase.
//   amplitude      : peak height of the returned value.
//   timeShift      : constant phase offset.
//   noiseInfluence : scale of the noise term added to the phase.
// Returns a value in [-amplitude, amplitude].
float wave(vec2 p, vec2 direction, float wavelength, float speed, float amplitude, float timeShift, float noiseInfluence)
{
    vec2 travelDir = normalize(direction);
    float waveNumber = 2.0 * PI / wavelength;

    // Slowly drifting noise bends the wave fronts so they do not look perfectly straight.
    float wobble = snoise(p * 0.1 + uTime * 0.1) * noiseInfluence;

    float phase = dot(p, travelDir) * waveNumber + speed * uTime + timeShift;
    return sin(phase + wobble) * amplitude;
}

// Wave strength factor in [0, 1] based on the distance to the nearest island center:
// 0 within 5 units, 1 beyond 15 units, with a smoothstep ramp in between.
float getWaveAttenuation(vec2 pos) {
    const float CALM_RADIUS = 5.0;       // at or inside this distance waves are fully damped
    const float FULL_WAVE_RADIUS = 15.0; // at or beyond this distance waves are untouched

    return smoothstep(CALM_RADIUS, FULL_WAVE_RADIUS, getDistanceToIsland(pos));
}

// Returns `pos` with its Y coordinate offset by the animated water height at pos.xz.
// The height is the sum of a slow noise drift, four directional wave() layers that are
// damped near islands, and a small ring-shaped shore wave around the nearest island center.
vec3 displaceVertex(vec3 pos)
{
    vec2 xz = pos.xz;

    // Damping factor applied to the amplitude of every directional layer.
    float calm = getWaveAttenuation(xz);

    // Low-frequency noise drift; added at full strength (not damped near islands).
    float drift = snoise(xz * 0.05 + uTime * 0.02) * 0.02;

    // Layers: wave(position, direction, wavelength, speed, amplitude, phase offset, noise influence).
    float longSwell  = wave(xz, vec2(1.0, 0.2),  8.0, 0.6, 0.10 * calm, 0.0, 0.5); // gentle long wave
    float crossSwell = wave(xz, vec2(-0.3, 1.0), 5.0, 0.9, 0.06 * calm, 1.7, 0.3); // crossing direction
    float ripples    = wave(xz, vec2(0.7, -1.0), 2.5, 1.8, 0.03 * calm, 3.1, 0.4); // small and fast
    float chop       = wave(xz, vec2(0.5, 0.8),  3.5, 1.2, 0.04 * calm, 2.5, 0.6); // turbulent detail

    float height = drift + longSwell + crossSwell + ripples + chop;

    // Shore effect: a sine of the distance to the nearest island center, animated by uTime,
    // at full strength within 5 units and faded to zero by 20 units.
    float shoreDist = getDistanceToIsland(xz);
    if (shoreDist < 20.0) {
        float fade = 1.0 - smoothstep(5.0, 20.0, shoreDist);
        float ring = sin(shoreDist * 1.5 - uTime * 2.0) * 0.015;
        height += ring * fade;
    }

    return vec3(pos.x, pos.y + height, pos.z);
}


// Approximates the normal of the displaced water surface at `pos` with forward
// finite differences: the surface is sampled at pos and at pos offset by a small
// step along +X and +Z, and the two resulting edge vectors are crossed.
// The cross-product order (Z edge first, X edge second) gives a +Y-facing normal
// for a flat surface.
vec3 computeNormal(vec3 pos)
{
    const float STEP = 0.1; // sampling distance for the finite differences

    vec3 sampleX = displaceVertex(pos + vec3(STEP, 0, 0));
    vec3 sampleZ = displaceVertex(pos + vec3(0, 0, STEP));
    vec3 center  = displaceVertex(pos);

    vec3 edgeX = sampleX - center;
    vec3 edgeZ = sampleZ - center;

    return normalize(cross(edgeZ, edgeX));
}

// Vertex entry point: displaces the incoming vertex with the wave field, derives a
// matching surface normal, and writes all varyings plus the clip-space position.
void main()
{
    vec3 displaced = displaceVertex(aPos);
    vec3 normal    = computeNormal(aPos);

    // Apply the model transform once and reuse it for vFragPos and gl_Position.
    // The previous `camMatrix * model * vec4(...)` is left-associative, so it formed
    // a mat4 * mat4 product per vertex and transformed the position by `model` twice.
    vec4 modelPos = model * vec4(displaced, 1.0);

    vColor   = aColor;
    // Inverse-transpose keeps normals perpendicular under non-uniform scale in `model`.
    vNormal  = mat3(transpose(inverse(model))) * normal;
    vFragPos = modelPos.xyz;
    vTex     = aTex;
    // NOTE(review): despite the name, this is the raw, undisplaced aPos with no model
    // transform; left unchanged because the consuming fragment stage is not visible here.
    vWorldPos = aPos;

    // Distance to the nearest island center (from the undisplaced XZ position),
    // for depth/color effects in the fragment shader.
    vDistanceToIsland = getDistanceToIsland(aPos.xz);

    gl_Position = camMatrix * modelPos;
}
