#version 450 core

--Vertex

#extension GL_ARB_explicit_attrib_location : enable

// Per-vertex input position; forwarded unchanged by this stage's main().
layout(location = 0) in vec3 vertex;
//layout(location = 1) in int planeId;
// NOTE(review): the three matrices below are declared but not read by this stage's main().
uniform mat4 viewProjMatrix;
uniform mat4 projMatrix;
uniform mat4 viewMatrix;

//out int planeIdVert;
// Pass-through vertex stage: emits the input position as a homogeneous point
// (w = 1) for the geometry stage to expand into a plane quad.
void main()
{
	vec4 passthrough = vec4(vertex, 1.0);
	gl_Position = passthrough;
}

--Geometry // create light planes here

// Input primitive: one point per invocation; main() uses its xy as the centre of
// the generated plane quad (the z value is replaced by planeZ).
layout(points) in;
// Output: a single quad emitted as a 4-vertex triangle strip.
layout(triangle_strip, max_vertices = 4) out;


// NOTE(review): of the uniforms below, only planeID and planeZ are read by this stage's main().
uniform mat4 viewProjMatrix;
uniform mat4 viewMatrix;
uniform mat4 inverseViewMatrix;
uniform mat4 projMatrix;
uniform vec3 cameraPos;
// Index of the plane being generated; main() uses planeID % 2 as the output layer.
uniform int planeID;
// Depth of the plane; forwarded to the fragment stage in TexCoord.z.
uniform float planeZ;

// xy: texture coordinate of the quad corner in [0, 1]; z: planeZ.
out vec3 TexCoord;

// Expands each input point into a light/illumination plane: a 2x2 quad (NDC units)
// centred on the point's xy, emitted as a triangle strip at clip depth z = w = 1.
// Every vertex carries TexCoord = (corner uv in [0,1], planeZ) and is routed to
// layer planeID % 2.
void main()
{
	// Output layer for this plane. It is re-written before EVERY EmitVertex():
	// the GLSL spec makes all outputs (built-ins included) undefined after
	// EmitVertex(), and the layer of a primitive may be taken from any of its
	// vertices, so writing it only once is not portable.
	int layer = planeID % 2;
	vec2 center = gl_in[0].gl_Position.xy;

	// Corner offsets and matching texture coordinates in triangle-strip order:
	// bottom-left, top-left, bottom-right, top-right.
	const vec2 cornerOffsets[4] = vec2[4](vec2(-1.0, -1.0), vec2(-1.0, 1.0), vec2(1.0, -1.0), vec2(1.0, 1.0));
	const vec2 cornerUVs[4] = vec2[4](vec2(0.0, 0.0), vec2(0.0, 1.0), vec2(1.0, 0.0), vec2(1.0, 1.0));

	for (int i = 0; i < 4; ++i) {
		gl_Layer = layer;
		gl_Position = vec4(center + cornerOffsets[i], 1.0, 1.0);
		TexCoord = vec3(cornerUVs[i], planeZ);
		EmitVertex();
	}

	EndPrimitive();
}

--Fragment
// From the geometry stage: xy = texture coordinate on the plane, z = plane depth.
in vec3 TexCoord;
//in int gl_Layer;

// Distance along z between consecutive planes (used to locate previous/next planes).
uniform float sample_distance;
// Scalar volume; only the red channel is sampled.
uniform sampler3D volume;
// Pre-integration lookup tables, sampled with a pair of volume values:
// preintTableMedium: rgb = medium colour, a = refraction value (see getRefractionValue);
// preintTableColor: rgba = sample colour and opacity.
uniform sampler2D preintTableMedium;
uniform sampler2D preintTableColor;

// State of the previous plane. main() samples layer (1 - gl_Layer) of these arrays.
// NOTE(review): presumably the same textures bound as the render targets below (ping-pong) — confirm on the host side.
uniform sampler2DArray color_buffer;
uniform sampler2DArray medium_color_buffer;
uniform sampler2DArray viewray_pos_buffer;
uniform sampler2DArray viewray_dir_buffer;
uniform sampler2DArray light_color_buffer;
uniform sampler2DArray light_dir_buffer;
uniform sampler2DArray debug_buffer;

// Feature toggles.
uniform bool scatteringEnabled;          // adds the specular term when compositing the viewing ray
uniform bool filteringEnabled;           // filters the previous light direction
uniform bool intensityCorrectionEnabled; // derivative-based light intensity correction
uniform bool refractionEnabled;          // bends directions by the refraction gradient

// State of the current plane, written by main().
layout(location = 0) out vec4 color_texture_write;
layout(location = 1) out vec4 light_color_texture_write;
layout(location = 2) out vec4 light_dir_texture_write;
layout(location = 3) out vec4 medium_color_texture_write;
layout(location = 4) out vec4 viewray_dir_texture_write;
layout(location = 5) out vec4 viewray_pos_texture_write;
layout(location = 6) out vec4 debug_texture_write;

// View <-> world transforms. NOTE(review): inverseViewProjMatrix and projMatrix are not read in this stage.
uniform mat4 inverseViewMatrix;
uniform mat4 viewMatrix;
uniform mat4 inverseViewProjMatrix;
uniform mat4 projMatrix;

// Scale factor between texture-space and view-space xy (see textureToViewSpace / viewToTextureSpace).
uniform float camera_zoom;

// Per-fragment results: set by lightPropagationNew / viewingRayPropagationNew, written out by main().
vec3 Ci_out;    // accumulated colour
float Ai_out;   // accumulated opacity
vec4 vp_out;    // next viewing-ray position
vec4 vd_out;    // next viewing-ray direction
vec3 Mi_out;    // accumulated medium colour
vec4 Li_out;    // light colour
vec4 ldi_out;   // light direction
vec4 debug_out; // debug value

//Function declarations (forward prototypes)
// NOTE(review): this zero-argument prototype does not match the intersectRayPlane(Ray, Plane)
// definition in this file; it declares a separate overload that is never defined or called.
// Presumably a leftover — confirm and remove.
vec3 intersectRayPlane();
void lightPropagationNew(vec3 pos, int readingLayer);
void viewingRayPropagationNew(vec3 pos, int readingLayer);
vec3 elliptical_filter(vec3 lp_iprev);
float sampleVolume(vec3 pos, float offset);
vec3 specularBRDF(vec3 ldi, vec3 vd_i, vec3 i_d, vec3 gradient, vec3 color, float IOR_t, float IOR_i);
vec3 volGradient(vec3 current_pos);
float getRefractionValue(float value);
vec3 refractionGradient(vec3 current_pos);
vec2 textureToViewSpace(vec2 pos);
vec2 viewToTextureSpace(vec2 pos);
float GschlickGGX(float NdotV, float k);

// Pi; used by the GGX normal distribution term in specularBRDF.
#define M_PI 3.1415926535897932384626433832795

// Infinite plane given by a point on it and its normal.
struct Plane {
	vec3 position; // any point on the plane
	vec3 normal;   // plane normal (callers pass unit vectors)
};

// Ray given by an origin and a direction; points are position + t * direction.
struct Ray {
	vec3 position;  // origin
	vec3 direction; // direction (not required to be unit length)
};

// Intersects `ray` with the infinite `plane` and returns the hit point
// ray.position + t * ray.direction. t may be negative, i.e. the hit may lie behind
// the ray origin. If the ray is exactly parallel to the plane
// (dot(normal, direction) == 0), vec3(0) is returned as a sentinel.
vec3 intersectRayPlane(Ray ray, Plane plane) {
	float facing = dot(plane.normal, ray.direction);

	// TODO: a parallel ray should not occur if the light is behind the camera and
	// refraction is plausible; callers cannot distinguish this sentinel from a real hit.
	if (facing == 0.0) {
		return vec3(0.0);
	}

	float t = dot(plane.normal, plane.position - ray.position) / facing;
	return ray.position + t * ray.direction;
}

// Placeholder for an elliptical filter around the previous light position.
// Not implemented: ignores `lp_iprev` and always returns vec3(0).
// Reference: https://cseweb.ucsd.edu//~ravir/274/15/papers/a154-patel.pdf
vec3 elliptical_filter(vec3 lp_iprev) {
	const vec3 unimplemented = vec3(0.0);
	return unimplemented;
}

// Samples the red channel of `volume` at `pos`, shifted by `offset` on every axis.
// NOTE(review): callers pass 0.5, presumably to map positions in [-0.5, 0.5] to texture space [0, 1] — confirm.
float sampleVolume(vec3 pos, float offset) {
	vec3 texCoord = pos + vec3(offset);
	return texture(volume, texCoord).r;
}

// Schlick-GGX geometry (masking/shadowing) factor for one direction:
//   G1 = d / (d * (1 - k) + k), where d is the clamped cosine with the normal.
float GschlickGGX(float dotproduct, float k) {
	float denominator = dotproduct * (1.0 - k) + k;
	return dotproduct / denominator;
}

// Specular term for the viewing ray.
// Computes the ingredients of a Cook-Torrance microfacet BRDF: GGX normal distribution,
// Smith/Schlick-GGX geometry, and the exact dielectric Fresnel term. However, the value
// RETURNED is still the Phong-style DEBUG highlight
//   pow(max(dot(normalize(vd_i), reflect(-ld_i, -gradient)), 0), 64) * i_d,
// so the Cook-Torrance value `specular` does not affect the output yet.
//   ld_i         light direction at the sample
//   vd_i         viewing-ray direction at the sample
//   i_d          incoming light intensity (colour); scales the returned highlight
//   gradient     volume gradient; its negation is used as the surface normal
//   color        sample colour (currently unused)
//   IOR_t, IOR_i refractive indices for the Fresnel term
// Reference: https://learnopengl.com/PBR/Theory
vec3 specularBRDF(vec3 ld_i, vec3 vd_i, vec3 i_d, vec3 gradient, vec3 color, float IOR_t, float IOR_i) {
	vec3 halfWay = -normalize(vd_i + ld_i);
	vec3 normal = -gradient;

	/* Normal distribution factor (GGX / Trowbridge-Reitz) */
	float alpha_roughness = 0.1;
	float alpha_roughness2 = alpha_roughness * alpha_roughness;
	float NdotH = max(dot(normal, halfWay), 0.0);
	float NdotH2 = NdotH * NdotH;

	float denom = (NdotH2 * (alpha_roughness2 - 1.0) + 1.0);
	float NDF = alpha_roughness2 / (M_PI * denom * denom);

	/* Geometry (Smith with Schlick-GGX, k for direct lighting) */
	float NdotV = max(dot(normal, vd_i), 0.0);
	float NdotL = max(dot(normal, ld_i), 0.0);
	float k_direct = (alpha_roughness + 1.0) * (alpha_roughness + 1.0) / 8.0;
	float GGX = GschlickGGX(NdotV, k_direct) * GschlickGGX(NdotL, k_direct);

	/* Fresnel: exact unpolarised dielectric form (Cook & Torrance)
	     F = 1/2 * (g-c)^2/(g+c)^2 * (1 + (c(g+c)-1)^2 / (c(g-c)+1)^2)
	   with c = cos(theta) and g^2 = (n_t/n_i)^2 - 1 + c^2.
	   g^2 <= 0 means total internal reflection, so F = 1.
	   (The previous code divided by (g+2)^2 and did not square the second ratio.) */
	float c = max(dot(vd_i, halfWay), 0.0);
	float sqrtTerm = (IOR_t * IOR_t) / (IOR_i * IOR_i) - 1.0 + c * c;
	float Fresnel = 1.0;
	if (sqrtTerm > 0.0) {
		float g = sqrt(sqrtTerm);
		float ratioA = (g - c) / (g + c);
		float ratioB = (c * (g + c) - 1.0) / (c * (g - c) + 1.0);
		Fresnel = 0.5 * ratioA * ratioA * (1.0 + ratioB * ratioB);
	}

	// Cook-Torrance specular; zero when light or view is below the surface.
	// NOTE: currently not returned (see the DEBUG branch below).
	float specular = 0.0;
	if (NdotL > 0.0 && NdotV > 0.0)
		specular = NDF * Fresnel * GGX / (4.0 * NdotL * NdotV);

	// DEBUG: Phong-style highlight (exponent 64) returned in place of `specular`.
	vec3 viewDir = normalize(vd_i.xyz);
	vec3 reflectDir = reflect(-ld_i, normal);

	float factor = pow(max(dot(viewDir, reflectDir), 0.0), 64.0);
	vec3 speculardebug = factor * i_d;

	return speculardebug;
}

// Refraction value for a scalar volume sample: the alpha channel of the medium
// pre-integration table on its diagonal (both lookup coordinates equal `value`).
float getRefractionValue(float value){
	vec2 diagonal = vec2(value);
	return texture(preintTableMedium, diagonal).a;
}

// Maps texture-space xy in [0, 1] to view-space xy: recentre to [-1, 1],
// then scale by camera_zoom. Inverse of viewToTextureSpace.
vec2 textureToViewSpace(vec2 pos){
	vec2 centered = pos - 0.5;
	return centered * 2.0 * camera_zoom;
}

// Maps view-space xy back to texture-space xy: undo camera_zoom, then map
// [-1, 1] to [0, 1]. Inverse of textureToViewSpace.
vec2 viewToTextureSpace(vec2 pos){
	vec2 unzoomed = pos / camera_zoom;
	return unzoomed * 0.5 + 0.5;
}

// Small Gaussian-like blur of the RGB channels of `unfiltered_texture` (layer `read_layer`)
// around the normalized texture coordinate `pos`.
// Default kernel: centre weight 10 plus the 4 direct neighbours with weight 2.
// With use3x3 = true: additionally the 4 diagonal neighbours with weight 1.
// The result is normalised by the sum of the weights used.
// `pos` is in normalized [0, 1] coordinates, so a neighbour step is one texel
// (1 / textureSize of mip 0). The previous code stepped by 1.0, a full texture width,
// which sampled clamped edges or wrapped back onto the same texel.
// Kernel reference: https://en.wikipedia.org/wiki/Kernel_(image_processing)
vec3 filterTextureVec3(vec2 pos, sampler2DArray unfiltered_texture, int read_layer){
	const bool use3x3 = false;

	vec2 texel = 1.0 / vec2(textureSize(unfiltered_texture, 0).xy);

	// First 5 entries: cross kernel; last 4: diagonals (used only with use3x3).
	const vec2 offsets[9] = vec2[9](
		vec2( 1.0,  0.0), vec2(-1.0,  0.0), vec2( 0.0,  1.0), vec2( 0.0, -1.0),
		vec2( 0.0,  0.0),
		vec2(-1.0,  1.0), vec2(-1.0, -1.0), vec2( 1.0,  1.0), vec2( 1.0, -1.0));
	const float weights[9] = float[9](2.0, 2.0, 2.0, 2.0, 10.0, 1.0, 1.0, 1.0, 1.0);
	int taps = use3x3 ? 9 : 5;

	vec3 filtered_output = vec3(0.0);
	float norm = 0.0;
	for (int i = 0; i < taps; ++i) {
		vec2 uv = pos + offsets[i] * texel;
		filtered_output += texture(unfiltered_texture, vec3(uv, read_layer)).rgb * weights[i];
		norm += weights[i];
	}

	return filtered_output / norm;
}
// Gradient of the refraction value (getRefractionValue of the volume sample) at
// `current_pos`. It uses central differences with step h = 0.01, samples the volume
// at position + 0.5, and transforms the result into view space with viewMatrix.
// Returns the normalised view-space gradient, or vec3(0) when its length is not positive.
// Reference (slide 43): https://www.cg.tuwien.ac.at/courses/RTVis/material/02%20-%20RTVis%20-%20Real-Time%20Volume%20Graphics%201.pdf
vec3 refractionGradient(vec3 current_pos) {
	const float h = 0.01f; // "h" in the slides

	vec3 diff = vec3(0.0);
	for (int axis = 0; axis < 3; ++axis) {
		vec3 delta = vec3(0.0);
		delta[axis] = h;
		float forward = getRefractionValue(texture(volume, current_pos + delta + 0.5).r);
		float backward = getRefractionValue(texture(volume, current_pos - delta + 0.5).r);
		diff[axis] = forward - backward;
	}

	vec3 worldGradient = (1.0 / (2.0 * h)) * diff;
	vec3 viewGradient = (viewMatrix * vec4(worldGradient, 0.0)).rgb;

	return (length(viewGradient) > 0.0) ? normalize(viewGradient) : vec3(0.0);
}

// Gradient of the raw volume value (red channel) at `current_pos`. It uses central
// differences with step h = 0.01, samples the volume at position + 0.5, and transforms
// the result into view space with viewMatrix.
// Returns the normalised view-space gradient, or vec3(0) when its length is not positive.
// Reference (slide 43): https://www.cg.tuwien.ac.at/courses/RTVis/material/02%20-%20RTVis%20-%20Real-Time%20Volume%20Graphics%201.pdf
vec3 volGradient(vec3 current_pos) {
	const float h = 0.01f; // "h" in the slides

	vec3 diff = vec3(0.0);
	for (int axis = 0; axis < 3; ++axis) {
		vec3 delta = vec3(0.0);
		delta[axis] = h;
		diff[axis] = texture(volume, current_pos + delta + 0.5).r - texture(volume, current_pos - delta + 0.5).r;
	}

	vec3 worldGradient = (1.0 / (2.0 * h)) * diff;
	vec3 viewGradient = (viewMatrix * vec4(worldGradient, 0.0)).rgb;

	return (length(viewGradient) > 0.0) ? normalize(viewGradient) : vec3(0.0);
}

// http://graphics.stanford.edu/data/3Dscanrep/ VOLUMES REPOSITORY
/// http://openqvis.sourceforge.net/  another ct scans repository
// central differences and more http://graphicsrunner.blogspot.com/2009/01/volume-rendering-102-transfer-functions.html
//
// Light propagation step for one fragment of the current plane.
//   pos          (texture-space xy, plane depth z) of this fragment
//   readingLayer buffer layer holding the previous plane's state
// Traces back along the stored light direction to the previous plane
// (z + sample_distance) and fetches the light colour/direction there.
// It then applies the optional intensity correction, the medium colour from the
// pre-integration table, the previous opacity and the optional refraction bending.
// Writes the globals Li_out (light colour), ldi_out (light direction) and debug_out.
// viewingRayPropagationNew reads ldi_out, so this must run first.
// NOTE(review): the "line N" comments presumably refer to the algorithm listing of the implemented paper.
void lightPropagationNew(vec3 pos, int readingLayer)
{
	// Current sample in view space (xy mapped from texture space, z = plane depth).
	vec3 pos_viewspace = vec3(textureToViewSpace(pos.xy),pos.z);
	// Light direction stored for this texel in the previous plane's state.
	vec3 ld_i = texture(light_dir_buffer, vec3(pos.xy, readingLayer)).xyz;													// line 7
	// todo fix positions
	// Previous plane: parallel to the xy-plane, sample_distance further along +z.
	Plane plane_prev = Plane(vec3(0, 0, pos_viewspace.z + sample_distance), vec3(0, 0, 1));
	
	//TODO scale light direction?

	// Follow the light direction backwards to where it crossed the previous plane.
	Ray ray_prev = Ray(pos_viewspace, -ld_i);
	vec3 lp_iprev_viewspace = intersectRayPlane(ray_prev, plane_prev);
	
	//float fov = 45.0;
	//float Px = (2.0 * lp_iprev_viewspace.x) * tan(fov / 2.0 * M_PI / 180.0);
    //float Py = (2.0 * lp_iprev_viewspace.y) * tan(fov / 2.0 * M_PI / 180.0);
	//float Px2 = (2.0 * pos_viewspace.x) * tan(fov / 2.0 * M_PI / 180.0);
    //float Py2 = (2.0 * pos_viewspace.y) * tan(fov / 2.0 * M_PI / 180.0);
	//vec3 viewDirection = texture(viewray_dir_buffer, vec3(pos.xy, readingLayer)).xyz;
	// Previous light position back in texture space, for buffer lookups.
	vec3 lp_iprev = vec3(viewToTextureSpace(lp_iprev_viewspace.xy),lp_iprev_viewspace.z);

	//TODO filter
	// Light colour arriving at the previous position.
	vec3 L_iprev = texture(light_color_buffer, vec3(lp_iprev.xy, readingLayer)).xyz; // line 9
	// NOTE(review): the disabled call below passes the colour value (L_iprev.xy) as a position; presumably lp_iprev.xy was meant.
	//L_iprev = filterTextureVec3(L_iprev.xy, light_color_buffer, readingLayer);

	// Light direction at the previous position (optionally filtered).
	vec3 ld_iprev = texture(light_dir_buffer, vec3(lp_iprev.xy, readingLayer)).xyz;											// line 10 TODO: filter
	
	if (filteringEnabled) {
		ld_iprev = filterTextureVec3(lp_iprev.xy, light_dir_buffer, readingLayer);
	}

	// Screen-space footprint (product of fine x/y derivatives) of the current and the
	// previous light position; their ratio corrects for light rays converging/diverging.
	float S_i = abs(dFdxFine(pos_viewspace.x)) * abs(dFdyFine(pos_viewspace.y));																		// line 11 (light intensity correction with derivatives from current point)
	float S_iprev = abs(dFdxFine(lp_iprev_viewspace.x)) * abs(dFdyFine(lp_iprev_viewspace.y));																// line 12 (light intensity correction with derivatives from previous point)
	//float S_i = abs(dFdxFine(Px2)) * abs(dFdyFine(Py2));																		// line 11 (light intensity correction with derivatives from current point)
	//float S_iprev = abs(dFdxFine(Px)) * abs(dFdyFine(Py));			

	float l_i;

	if (intensityCorrectionEnabled){
		// NOTE(review): not guarded against S_i == 0 (would yield inf/NaN).
		l_i = S_iprev / S_i;																								// line 13
	} else {
		l_i = 1.0;
	}

	// Current and previous positions in world space, for volume lookups.
	vec4 sample_current_world = inverseViewMatrix * vec4(pos_viewspace, 1.0);
	vec4 sample_prev_world = inverseViewMatrix * vec4(lp_iprev_viewspace, 1.0);

	//TODO use offset or not?
	float volsample_current = sampleVolume(sample_current_world.xyz, 0.5);
	float volsample_prev = sampleVolume(sample_prev_world.xyz, 0.5);

	vec3 medium_color = texture(preintTableMedium, vec2(volsample_current, volsample_prev)).rgb;				// fetch preintegrated tf value with respect to current value and previous intersected value
	
	vec3 refraction_gradient;

	if (refractionEnabled){
		refraction_gradient = refractionGradient(sample_current_world.xyz); 
	} else {
		refraction_gradient = vec3(0.0);
	}

	//TODO use previous or current opacity?
	// Accumulated opacity stored at the previous light position.
	float opacity = texture(color_buffer, vec3(lp_iprev.xy, readingLayer)).a;

	//medium_color = vec3(1, 1, 1);

	//opacity: something (open TODO)

	// Attenuated light colour, e.g. 1.0 * 1.0 * (1 - 0.000) * 1.0
	vec3 L_i = L_iprev * l_i * (1 - opacity) * medium_color;

	//vec3 vd_i = texture(viewray_dir_buffer, vec3(pos.xy, readingLayer)).xyz;
	// New light direction: previous direction bent by the refraction gradient.
	// NOTE(review): normalize() of a zero vector is undefined if both terms cancel / are zero.
	ld_i = normalize(ld_iprev + sample_distance * refraction_gradient);

	Li_out = vec4(L_i, 1.0);//vec4(L_i, 1.0); DEBUG
	// Li_out = vec4(1.0);
	ldi_out = vec4(ld_i, 0.0);
	// ldi_out = vec4(0,0,-1, 0.0);


	// Debug output: refraction value at the current sample.
	debug_out = vec4(getRefractionValue(texture(volume, sample_current_world.xyz + 0.5).r));
}

// Viewing-ray propagation step for one fragment of the current plane.
//   pos          (texture-space xy, plane depth z) of this fragment
//   readingLayer buffer layer holding the previous plane's state
// Reads the previous viewing-ray position/direction, accumulated colour/opacity and
// medium colour. It composites this sample's pre-integrated colour (plus the optional
// specular term, which uses the light direction in ldi_out) and advances the ray to the
// next plane (z - sample_distance).
// When refractionEnabled is set, the ray is bent by the refraction gradient. This now
// matches lightPropagationNew; previously the viewing ray was bent even with refraction
// disabled.
// Writes the globals Ci_out, Ai_out, Mi_out, vp_out and vd_out.
// Precondition: lightPropagationNew() has already run for this fragment (ldi_out is read).
// The "line N" comments presumably refer to the algorithm listing of the implemented paper.
void viewingRayPropagationNew(vec3 pos, int readingLayer)
{
	vec4 vp_i = texture(viewray_pos_buffer, vec3(pos.xy, readingLayer));				// line 21 (previous viewing-ray position)
	vec4 vp_i_world = inverseViewMatrix * vp_i;										// world space, to sample the volume
	vec4 vd_i = texture(viewray_dir_buffer, vec3(pos.xy, readingLayer));				// line 22 (previous viewing-ray direction)
	vec4 color_iprev = texture(color_buffer, vec3(pos.xy, readingLayer));				// line 23 (previous colour and opacity)
	vec3 medium_iprev = texture(medium_color_buffer, vec3(pos.xy, readingLayer)).rgb;	// line 24 (previous medium colour)

	// line 27: where the viewing ray crossed the previous plane (z + sample_distance).
	Plane plane_prev = Plane(vec3(0, 0, vp_i.z + sample_distance), vec3(0, 0, 1));
	Ray ray_prev = Ray(vp_i.xyz, -vd_i.xyz);
	vec4 vp_iprev = vec4(intersectRayPlane(ray_prev, plane_prev), 1.0);
	vec4 vp_iprev_world = inverseViewMatrix * vp_iprev;

	vec3 i_d = texture(light_color_buffer, vec3(pos.xy, readingLayer)).rgb;				// line 25 (incoming light colour)

	// line 27: pre-integrated lookups for the segment between previous and current sample.
	float density = sampleVolume(vp_i_world.xyz, 0.5);
	float density_prev = sampleVolume(vp_iprev_world.xyz, 0.5);
	vec4 color = texture(preintTableColor, vec2(density_prev, density));
	vec3 medium_color = texture(preintTableMedium, vec2(density_prev, density)).rgb;

	// line 26: optional specular term (inputs are only needed when scattering is enabled).
	vec3 i_s = vec3(0.0);
	if (scatteringEnabled) {
		float IOR_t = 1.0 + texture(preintTableMedium, vec2(density_prev, density_prev)).a;
		float IOR_i = 1.0 + texture(preintTableMedium, vec2(density, density)).a;
		vec3 vol_gradient = volGradient(vp_i_world.xyz);
		i_s = specularBRDF(ldi_out.xyz, vd_i.xyz, i_d, vol_gradient, color.rgb, IOR_t, IOR_i);
	}

	// line 28: front-to-back compositing; the specular component is weighted by the sample alpha.
	vec3 C_i = color_iprev.rgb + (1.0 - color_iprev.a) * medium_iprev.rgb * (color.a * color.rgb * i_d + i_s * color.a);
	// line 29
	float A_i = color_iprev.a + (1.0 - color_iprev.a) * color.a;
	// line 30
	vec3 M_i = medium_iprev * medium_color;

	// line 31: bend the viewing ray by the refraction gradient, only when refraction is enabled.
	//TODO proper gradient
	vec3 IOR_gradient;
	if (refractionEnabled) {
		IOR_gradient = refractionGradient(vp_i_world.xyz);
	} else {
		IOR_gradient = vec3(0.0);
	}
	vec4 vd_inext = normalize(vd_i + sample_distance * vec4(IOR_gradient, 0.0));

	// line 32: advance the ray to the next plane (z - sample_distance).
	Plane plane_next = Plane(vec3(0, 0, vp_i.z - sample_distance), vec3(0, 0, -1));
	Ray ray_next = Ray(vp_i.xyz, vd_inext.xyz);
	vec4 vp_inext = vec4(intersectRayPlane(ray_next, plane_next), 1.0);

	Ci_out = C_i;
	Ai_out = A_i;
	Mi_out = M_i;
	vp_out = vp_inext;
	vd_out = vd_inext;
}


// Fragment entry point: runs once per texel of each plane emitted by the geometry stage.
// The previous plane's state is read from layer (1 - gl_Layer); this plane's state is
// written to the bound outputs.
// Order matters: lightPropagationNew() fills ldi_out, which viewingRayPropagationNew() reads.
// Reference: http://www.real-time-volume-graphics.org/?page_id=28
void main()
{
	int readingLayer = 1 - gl_Layer;
	vec3 samplePos = TexCoord.xyz;

	lightPropagationNew(samplePos, readingLayer);
	viewingRayPropagationNew(samplePos, readingLayer);

	color_texture_write        = vec4(Ci_out, Ai_out);
	light_color_texture_write  = vec4(Li_out.xyz, 1.0);
	light_dir_texture_write    = vec4(ldi_out.xyz, 0.0);
	medium_color_texture_write = vec4(Mi_out.rgb, 1.0);
	viewray_dir_texture_write  = vd_out;
	viewray_pos_texture_write  = vp_out;
	debug_texture_write        = debug_out;
}