#version 420

// Compile-time switches:
//   DEBUG     - paint diagnostic colors: green when a ray-direction component is (near) zero,
//               red when the march stops after maxI iterations.
//   DOUBLE    - do the ray-marching math in double precision (Float/Vec2/Vec3/Vec4 below).
//   DO_FRAME  - when the ray enters the height-field box below the relief surface, shade the
//               box wall/floor it came through (or call inside() if it starts inside the relief).
//   NOBIPATCH - when minLevelD >= 0, shade the cells at level minLevelD as flat boxes instead of
//               intersecting bi-linear patches; also enables the visBorders debug view.
#define DEBUG
#define DOUBLE
#define DO_FRAME
#define NOBIPATCH

// Scalar/vector types used for the ray-marching math, selected by DOUBLE.
#ifdef DOUBLE
#define Float double
#define Vec2 dvec2
#define Vec3 dvec3
#define Vec4 dvec4
#else
#define Float float
#define Vec2 vec2
#define Vec3 vec3
#define Vec4 vec4
#endif

uniform mat4 MVPMatrix; //model-view-projection matrix (not referenced in this fragment shader)
uniform mat4 ModelMatrix; //model matrix; applied to positions with xy in [0,1] (hf-coords divided by mipMapSize.xy)
uniform mat4 ModelMatrixIT; //model matrix Inverse Transpose, used to transform normals
uniform vec3 eyePosMs; //model-space eye position; used as ray origin in hf-coords (xy in texels of the highest-resolution level)
uniform sampler2D heightfield; //height in .x; presumably (mipMapSize.xy + 1) texels per axis, one per cell corner -- TODO confirm
uniform sampler2D textureNormals; //normal map: xyz = normal (z is divided by heightfieldScale), w = extra blur
uniform sampler2D maxMipMap; //per-cell maximum height for each mip level (adjusted like the height field)
uniform ivec3 mipMapSize; //xy = size of maxMipMap, z= highest mip-map level (of maxMipMap)
uniform int minLevelD; //The minimum level to sample (mainly for debugging); negative values march down to level 0 and intersect bi-linear patches
uniform float zScale; //difference of z-coordinate scale compared to x and y, i.e. approx. how many texels high the height-field is.
uniform samplerCube skybox; //environment map used for reflections
uniform float heightfieldOffset; //offset added to heightfield (after scaling)
uniform float heightfieldScale; //scaling factor of heightfield

uniform bool patchNormals; //if true, computes exact normals of the patches (faceted look), else uses the normal-map (smooth shading).
uniform bool visBorders; //(NOBIPATCH only) if true, colors the fragment red-to-green by the mip level of the first cell the ray passes over

uniform float reflectionBlurryness; //base LOD for the skybox reflection lookup
uniform vec3 reflectionCol; //tint multiplied with the reflected skybox color
uniform vec3 ambientCol; //ambient factor multiplied with the surface color
uniform float shininess; //specular exponent (the per-hit blur value is added to it)
uniform vec3 colorBase; //constant part of the surface color
uniform float colorFromHeight; //weight of the texture-derived color added to colorBase

const int maxLights = 2; //capacity of the light arrays below
uniform int lightCount; //number of active lights; should not exceed maxLights
uniform vec3 lightPos[maxLights]; //light positions, compared with ModelMatrix-transformed surface positions (presumably world space -- confirm)
uniform vec3 lightDiffuseCol[maxLights];
uniform vec3 lightSpecularCol[maxLights];


in vec4 fragmentVertex; //just the interpolated vertex coords of the rendered fragment of the height field [0,1]^3.

out vec4 fragColor;

const Float ftol = 0.00; //heights are clamped to [ftol, 1-ftol]; 0 means plain [0,1]

//general epsilon for ray/cell tests; tighter when the math runs in double precision
#ifdef DOUBLE
const Float tol = 0.000000001;
#else
const Float tol = 0.00001;
#endif

const int maxI = 200; //maximum number of iterations

/// Converts a raw height sample into the height used for ray marching:
/// first scaled by heightfieldScale, then shifted by heightfieldOffset.
float AdjustHeightfield(float v)
{
    float scaled = heightfieldScale * v;
    return scaled + heightfieldOffset;
}

/// Returns the (adjusted) maximum height stored for `cell` at mip `level` of maxMipMap.
/// The cell is clamped to the texel range of that level, so out-of-range cells read the border.
float MaxMipMap(ivec2 cell, int level)
{
    ivec2 lastTexel = textureSize(maxMipMap, level) - ivec2(1);
    ivec2 safeCell = clamp(cell, ivec2(0), lastTexel);
    float raw = texelFetch(maxMipMap, safeCell, level).r;
    return AdjustHeightfield(raw);
}

/// Returns the (adjusted) base-level height-field sample at integer position `cell`,
/// clamped to [0, mipMapSize.xy].
/// NOTE(review): the upper bound assumes the height-field texture has at least
/// mipMapSize.xy + 1 texels per axis (one per cell corner) -- confirm.
float HeightField(ivec2 cell)
{
    ivec2 safeCell = clamp(cell, ivec2(0), mipMapSize.xy);
    return AdjustHeightfield(texelFetch(heightfield, safeCell, 0).r);
}

//shading of intersection
/// Shades a ray/relief intersection and writes the result to fragColor.
/// pos    : hit position in hf-coords (xy in texels of the highest mip level, z = height).
/// normal : surface normal in hf-coords (need not be normalized, must not be zero).
/// blur   : added to the skybox LOD and to the specular exponent.
/// Lighting: ambient + skybox reflection + per-light Phong (diffuse + specular).
void shade(vec3 pos, vec3 normal, float blur)
{
    vec3 posn = pos; //hf-coords copy, used for the debug max-mip-map color lookup
    pos.xy /= mipMapSize.xy; //xy to [0,1], the space ModelMatrix expects

    //base color: from the height-field texture (minLevelD < 0) or the max-mip-map value at level minLevelD (debug)
    vec3 col = minLevelD < 0 ?
                texture(heightfield, pos.xy - .01).zyx :
                texelFetch(maxMipMap, ivec2(posn.xy / (1 << minLevelD) - .01), minLevelD).xxx;

    vec4 tmp = ModelMatrix * vec4(pos, 1);
    pos = tmp.xyz / tmp.w;

    normal = (ModelMatrixIT * vec4(normal, 0)).xyz;
    normal = normalize(normal);

    vec3 eye = eyePosMs;
    eye.xy /= mipMapSize.xy;
    tmp = ModelMatrix * vec4(eye, 1);
    eye = tmp.xyz / tmp.w;

    vec3 view = normalize(eye - pos); //surface -> eye
    col = colorBase + colorFromHeight * col;

    vec3 reflectionSB = reflect(-view, normal);
    vec3 reflectedColor = textureLod(skybox, reflectionSB, reflectionBlurryness + blur).xyz; //set here for more blurry reflection

    vec3 colTotal = ambientCol * col + reflectionCol * reflectedColor;

    //never index the light arrays beyond their declared size
    int activeLights = min(lightCount, maxLights);
    for (int i = 0; i < activeLights; i++)
    {
        //Phong convention: direction from the surface TOWARDS the light
        vec3 toLight = normalize(lightPos[i] - pos);
        float diffuseIntensity = max(dot(normal, toLight), 0.0);
        vec3 reflection = normalize(reflect(-toLight, normal));
        //clamp BEFORE pow: pow(x, y) is undefined for x < 0 in GLSL
        float specularIntensity = pow(max(dot(reflection, view), 0.0), shininess + blur);
        colTotal += lightDiffuseCol[i] * col * diffuseIntensity + lightSpecularCol[i] * specularIntensity;
    }

    fragColor = vec4(colTotal, 1);
}

//shading when started inside of relief
/// Marks a fragment whose ray starts inside the relief with solid magenta.
/// `pos` is accepted for interface symmetry with shade() but is not used.
void inside(vec3 pos)
{
    const vec4 magenta = vec4(1.0, 0.0, 1.0, 1.0);
    fragColor = magenta;
}

/// Solves the quadratic equation a x^2 + b x + c == 0 and stores in x the smallest solution
/// that lies in [xMin, xMax]. Returns true if such a solution exists; x is written only then.
/// |a| <= tol is treated as the linear equation b x + c == 0. If |b| <= tol as well, there is
/// no or a full solution (bi-patch: planar patch parallel to the ray), and false is returned.
/// The roots are computed with the cancellation-free form
///   q = -(b + sign(b) * sqrt(b^2 - 4ac)) / 2,   x0 = q / a,   x1 = c / q
/// instead of (-b +- sqrt(disc)) / (2a). The latter loses most significant digits for the root
/// near -c/b whenever |4ac| << b^2, e.g. for nearly planar patches with a just above tol.
bool SolveQuadratic(Float a, Float b, Float c, Float xMin, Float xMax, out Float x)
{
    if (-tol <= a && a <= tol)
    { //linear equation
        if (-tol <= b && b <= tol)
            return false; //no or full solution (in bi-patch: patch is planar and parallel to ray)
        Float x0 = -c / b;
        if (xMin <= x0 && x0 <= xMax)
        {
            x = x0;
            return true;
        }
        return false;
    }
    //quadratic equation
    Float disc = b * b - 4 * a * c;
    if (disc < -tol)
        return false; //imaginary solution
    Float s = sqrt(max(Float(0), disc)); //slightly negative discriminants are treated as a double root
    //b and sign(b)*s have the same sign, so no cancellation happens in this sum
    Float q = -0.5 * (b < 0 ? b - s : b + s);
    Float x0 = q / a;
    //q == 0 only if b == 0 and s == 0: double root at -b / (2a) == 0 == x0
    Float x1 = (q != Float(0)) ? c / q : x0;
    //sort ascending
    if (x1 < x0)
    {
        Float tmp = x0;
        x0 = x1;
        x1 = tmp;
    }
    //test first solution
    if (xMin <= x0 && x0 <= xMax)
    {
        x = x0;
        return true;
    }
    //test second solution
    if (xMin <= x1 && x1 <= xMax)
    {
        x = x1;
        return true;
    }
    return false; //solutions exist, but are not inside the requested interval
}


// void renderSkybox(vec3 dir){
    
    // vec4 environmentColor = texture(skybox, dir);
    // fragColor = environmentColor;
    
// }


/// Fragment entry point: relief ray casting against the height field.
/// Casts a ray from eyePosMs through the rendered bounding-box fragment and advances its start
/// to where it enters the height-field box [0,mipMapSize.x] x [0,mipMapSize.y] x [0,1] (hf-coords).
/// It then walks the ray through the max-mip-map hierarchy:
///  - a cell whose max height is reached by the ray is refined (level--);
///  - otherwise the ray advances to the cell's exit border.
/// At minLevel the ray is intersected with the bi-linear patch of the base cell. With NOBIPATCH
/// and minLevelD >= 0, the cell is shaded as a flat box instead.
/// Output:
///  - shade() on a hit;
///  - inside() if the ray starts below the relief (DO_FRAME);
///  - otherwise the fragment is discarded.
/// With DEBUG, a ray direction with a (near-)zero component is painted green, and exceeding
/// maxI iterations is painted red.
void main(void)
{
    Float tolZ = tol / zScale; //tolerance for z, which is scaled differently than x and y
    int minLevel = max(minLevelD, 0);

    //compute fragmentPos in Height-Field Coordinates (hf-coords: xy=pixel in highest mip-map-level, z = height [0,1])
    Vec3 fragmentPos = fragmentVertex.xyz / fragmentVertex.w;
    fragmentPos.xy *= mipMapSize.xy;

    Vec3 start = eyePosMs; //gl_FrontFacing ? fragmentPos : eyePosMs; //start of ray in hf-coords (3D-position of rendered Bounding Box if outside, or eye position if inside)
    Vec3 dir = normalize(fragmentPos - eyePosMs); //direction of ray in hf-coords
    Vec3 dirI = 1 / dir; //speedup for intersection computation (if division is slower than multiplication).
    ivec2 borderDir = ivec2(dir.x > tol ? 1 : 0, dir.y > tol ? 1 : 0); //add this to cellL to go to the next border
    Float t = 0; //current t-value along ray, monotonically increasing throughout program.

#ifdef DEBUG
    if((-tol <= dir.x && dir.x <= tol) || (-tol <= dir.y && dir.y <= tol) || (-tolZ <= dir.z && dir.z <= tolZ))
    {
        fragColor = vec4(0,1,0,1); //error GREEN: dir is singular
        return;
    }
#endif
    //fix instabilities due to div by zero: replace 1/~0 by the signed 1/tolerance
    if (-tol <= dir.x && dir.x <= tol)
    {
        if(dir.x < 0)
            dirI.x = -1 / tol;
        else
            dirI.x =  1 / tol;
    }
    if (-tol <= dir.y && dir.y <= tol)
    {
        if(dir.y < 0)
            dirI.y = -1 / tol;
        else
            dirI.y =  1 / tol;
    }
    if (-tolZ <= dir.z && dir.z <= tolZ)
    {
        if(dir.z < 0)
            dirI.z = -1 / tolZ;
        else
            dirI.z =  1 / tolZ;
    }

    //ensure that start is inside the volume: intersect with all 6 faces and select the largest entry t
    if (dir.x > 0 && start.x < 0)
        t = max(t, dirI.x * (0 - start.x));
    if (dir.x < 0 && start.x > mipMapSize.x)
        t = max(t, dirI.x * (mipMapSize.x - start.x));
    if (dir.y > 0 && start.y < 0)
        t = max(t, dirI.y * (0 - start.y));
    if (dir.y < 0 && start.y > mipMapSize.y)
        t = max(t, dirI.y * (mipMapSize.y - start.y));
    if (dir.z > 0 && start.z < 0)
        t = max(t, dirI.z * (0 - start.z));
    if (dir.z < 0 && start.z > 1)
        t = max(t, dirI.z * (1 - start.z));
    bool outside = t > 0;
    start += t * dir;
    t = 0;

#ifdef DO_FRAME
    //check if we hit one of the border walls from the outside
    //(texture() returns a vec4; the height is stored in .x)
    float zStart = AdjustHeightfield(texture(heightfield, vec2(start.xy + 0.5) / (mipMapSize.xy + 1)).x);

    if(start.z <= zStart - tol)
    {
        if(outside)
        {
            if(dir.x > 0 && start.x <= tol) //hit left relief wall
            {
                shade(vec3(start), vec3(-1,0,0), 0);
                return;
            }
            if(dir.x < 0 && start.x >= mipMapSize.x- tol) //hit right relief wall
            {
                shade(vec3(start), vec3(1,0,0), 0);
                return;
            }
            if(dir.y > 0 && start.y <= tol) //hit top relief wall
            {
                shade(vec3(start), vec3(0,-1,0), 0);
                return;
            }
            if(dir.y < 0 && start.y >= mipMapSize.y- tol) //hit bottom relief wall
            {
                shade(vec3(start), vec3(0,1,0), 0);
                return;
            }
            if(dir.z > 0 && start.z <= tol) //hit floor relief wall
            {
                shade(vec3(start), vec3(0,0,-1), 0);
                return;
            }
        }
        inside(vec3(start));//we obviously start inside the relief
        return;
    }
#endif

    ivec2 cell0 = ivec2(floor(start.xy + tol * dir.xy)); //cell in lowest level; for robustness go a little "inside".
    int level = mipMapSize.z; //current mipmapLevel. Starts at the highest (coarsest) level.

    //largest of the per-axis exit t-values of the box; leaving through the x/y borders is
    //detected separately below (findLSB of the crossed border), so this mainly bounds z.
    Float tEnd = 0;
    if (dir.x < 0)
        tEnd = max(tEnd, dirI.x * (0 - start.x));
    else
        tEnd = max(tEnd, dirI.x * (mipMapSize.x - start.x));
    if (dir.y < 0)
        tEnd = max(tEnd, dirI.y * (0 - start.y));
    else
        tEnd = max(tEnd, dirI.y * (mipMapSize.y - start.y));
    if (dir.z < 0)
        tEnd = max(tEnd, dirI.z * (0 - start.z));
    else
        tEnd = max(tEnd, dirI.z * (1 - start.z));

    int i = 0; //prevent infinite loops
    while (t < tEnd && i < maxI)
    {
        i++;
        int level2pow = 1 << level; //2^level
        ivec2 cellL = ivec2(floor(cell0 / level2pow)); //current cell in current level

        //t1 = find outgoing Border intersection of cellL (and the border-Id)
        ivec2 border0u = (cellL + borderDir) * level2pow; //unclamped border of cell according to walk-direction (in lowest level).
        ivec2 border0 = min(border0u, mipMapSize.xy); //required for non-power-of-2
        Vec2 txy = dirI.xy * (border0 - start.xy); //intersection ts of ray with border in x and y.
        Float t1; //end-t of current cell
        int borderId; //the "number" of the border at the end of the cell in walk-direction.
        if (txy.x < txy.y)
        { //next cell is in x-direction
            t1 = txy.x;
            borderId = border0u.x;
        }
        else
        { //next cell is in y-direction
            t1 = txy.y;
            borderId = border0u.y;
        }

        Float z = MaxMipMap(cellL, level); //height of the current cell
        z = clamp(z, ftol, 1-ftol);
        Float tLow = dir.z < 0 ? t1 : t; //where at the ray inside the cell we expect the lowest z-value
        if (start.z + dir.z * tLow <= z + tolZ)
        { //the ray crosses the cell (at least partially) under it's z-plane
            if (level == minLevel)
            { //time to intersect the bi-linear patch
#ifdef NOBIPATCH
                if (minLevelD > -1)
                {
                    Vec3 p = start + dir * t;
                    Vec3 cp = (p - round(p)) * 10;
                    vec3 n;
                    //on a cell border the ray hit a vertical wall; its normal faces against the ray
                    //(same convention as the DO_FRAME walls, e.g. (-1,0,0) for dir.x > 0)
                    if (-tol < cp.x && cp.x < tol)
                        n = vec3(dir.x > 0 ? -1.0 : 1.0, 0, 0);
                    else if (-tol < cp.y && cp.y < tol)
                        n = vec3(0, dir.y > 0 ? -1.0 : 1.0, 0);
                    else
                        n = vec3(0, 0, 1);
                    shade(vec3(p), n, 0);
                    //shade(vec3(start + tLow * dir), vec3(0,0,1), 0);
                    //shade(vec3(cell0 + 1, 0), vec3(0,0,1), 0);
                    //fragColor = vec4(float(i)/maxI,0,0,1);
                    //fragColor = vec4(t,t1,0,10)/10;
                    //fragColor = vec4((start+dir*t1)/Vec3(mipMapSize.xy,1), 1);
                    return;
                }
#endif
                //sample bi-linear patch.
                //vec4 p = textureGather(heightfield, cell0);
                //note: pyx
                Float p00 = HeightField(cell0 + ivec2(0,0)); //p.w;
                Float p01 = HeightField(cell0 + ivec2(1,0)); //p.x;
                Float p10 = HeightField(cell0 + ivec2(0,1)); //p.z;
                Float p11 = HeightField(cell0 + ivec2(1,1)); //p.y;
                p00 = clamp(p00, ftol, 1-ftol);
                p01 = clamp(p01, ftol, 1-ftol);
                p10 = clamp(p10, ftol, 1-ftol);
                p11 = clamp(p11, ftol, 1-ftol);

                //test bi-linear patch.
                Float sx = start.x - cell0.x;
                Float sy = start.y - cell0.y;
                Float v0 = p00 - p01;
                Float v1 = p00 - p10;
                Float v2 = v0 - p10 + p11;
                Float a = v2 * dir.x * dir.y;
                Float b = dir.x * (v2 * sy - v0) + dir.y * (v2 * sx - v1) - dir.z;
                Float c = p00 + sx * (sy * v2 - v0) - sy * v1 - start.z;
                //solve quadratic equation a t^2 + b t + c == 0 -> t = (-b +/- Sqrt(b^2 - 4 ac)) / (2 a)
                //use a separate out variable: SolveQuadratic leaves it unwritten on failure,
                //and GLSL would copy that undefined value back into t.
                Float tHit;
                if (SolveQuadratic(a, b, c, t-tol, t1+tol, tHit))
                { //hit found at tHit
                    t = tHit;
                    vec3 pos = vec3(start + t * dir); //update pos to the intersection
                    vec3 normal;
                    float blur;
                    if (patchNormals)
                    {
                        //normal computation of bi-patch: (-dh/dx, -dh/dy, 1)
                        float x = pos.x - cell0.x;
                        float y = pos.y - cell0.y;
                        normal = vec3(v0 - v2 * y, v1 - v2 * x, 1); //compute normal
                        blur = 0;
                    }
                    else
                    { //fetch from normal map (better smoothing possible)
                        vec4 n = texture(textureNormals, pos.xy / mipMapSize.xy);
                        normal = n.xyz;
                        normal.z /= heightfieldScale;
                        blur = n.w;
                    }

                    shade(pos, normal, blur);
                    //gl_fragDepth = ; //todo: output correct depth for compositing
                    //fragColor = 10*vec4(float(i)/maxI, float(i)/maxI, float(i)/maxI, 1); //debug: visualize iteration count
                    return;
                }
                //else: advance to next cell (shares code at end of enclosing if).
            }
            else //level > minLevel
            {
                level--;
                //intersect the ray with the cell's z-plane; a ray that does not go down has no
                //such intersection (tz = +inf), so t and cell0 stay unchanged in that case
                if (dir.z < -tolZ)
                {
                    Float tz = dirI.z * (z - start.z);
                    if (t < tz && tz < t1) //tz inside ]t,t1[
                    {
                        t = tz;
                        cell0 = ivec2(floor(start.xy + dir.xy * t));
                    }
                }
                //no z-intersection but below z -> keep t and cell0
                continue;
            }
        }
        //above z or no patch-intersection -> advance to border (on missed patch-intersection this is just the next cell)

#ifdef NOBIPATCH
        if (visBorders)
        {
            //red-to-green by level; guard against a single-level mip chain (mipMapSize.z == 0)
            float levelFrac = float(level) / float(max(mipMapSize.z, 1));
            fragColor = vec4(levelFrac, 1 - levelFrac, 0, 1);
            return;
        }
#endif

        t = t1;
        cell0 = ivec2(floor(start.xy + dir.xy * (t + tol))); //advance cell0 PAST the border(s)
        level = findLSB(borderId); //coarsest level on which the crossed border is also a cell border (-1 for border 0)
        if (mipMapSize.z < level || level < 0) //hit outside border
            break; //while
    }
#ifdef DEBUG
    if (i == maxI)
        fragColor = vec4(1,0,0,1); //error RED: too many iterations
    else
#endif

    // renderSkybox(dir);
    discard; //fragment
}
