#version 330 core

// Per-vertex atmospheric-scattering vertex shader. The uniform names and the
// scale() polynomial appear to follow Sean O'Neil's "Accurate Atmospheric
// Scattering" (GPU Gems 2, ch. 16) -- confirm against the host application.

// Input vertex position. main() transforms it only by the view and projection
// matrices (no model matrix), so it is treated as a world-space position.
in vec4 vertex;

// uniform shader-parameters
uniform mat4 osg_ViewMatrix;		// World -> view transform
uniform mat4 ProjectionMatrix;		// View -> clip transform
uniform mat4 osg_ViewMatrixInverse;	// View -> world; main() uses it to recover the camera position
uniform vec3 v3LightPos;		// Light direction (normalized in main()). NOTE(review): main() dots samples with -v3LightPos, so confirm whether this points toward or away from the light
uniform vec3 v3InvWavelength;	// 1 / pow(wavelength, 4) for the red, green, and blue channels
uniform float fOuterRadius;		// The outer (atmosphere) radius
uniform float fOuterRadius2;	// fOuterRadius^2
uniform float fInnerRadius;		// The inner (planetary) radius
uniform float fInnerRadius2;	// fInnerRadius^2 (not referenced by the code in this shader)
uniform float fKrESun;			// Kr * ESun (Rayleigh coefficient times sun brightness)
uniform float fKmESun;			// Km * ESun (Mie coefficient times sun brightness)
uniform float fKr4PI;			// Kr * 4 * PI
uniform float fKm4PI;			// Km * 4 * PI
uniform float fScale;			// 1 / (fOuterRadius - fInnerRadius)
uniform float fScaleDepth;		// The scale depth (i.e. the altitude at which the atmosphere's average density is found)
uniform float fScaleOverScaleDepth;	// fScale / fScaleDepth


// vertex-shader output variables (passed to fragment-shader)
out vec4 ex_secondary_color;	// Accumulated in-scatter * fKmESun (Mie term), alpha = 1
out vec4 ex_front_color;	// Accumulated in-scatter * v3InvWavelength * fKrESun (Rayleigh term), alpha = 1
out vec3 ex_v3Direction;	// Camera position minus vertex position, in scaled atmosphere units
out float ex_worldheight;	// Camera world-space z, unscaled
//out mat4 ModelViewProjectionMatrix;

// Number of integration samples along each view ray. fSamples is derived from
// nSamples so the loop count and the per-sample step length cannot drift apart.
const int nSamples = 6;
const float fSamples = float(nSamples);

// Scale function for optical depth: a fitted polynomial in (1 - fCos), where
// fCos is the cosine of the angle between a ray and the local vertical,
// exponentiated and multiplied by fScaleDepth.
float scale(float fCos)
{
	float t = 1.0 - fCos;

	// Horner evaluation, innermost term first (same operation order as the
	// nested form: -0.00287 + t*(0.459 + t*(3.83 + t*(-6.80 + t*5.25)))).
	float poly = -6.80 + t * 5.25;
	poly = 3.83 + t * poly;
	poly = 0.459 + t * poly;
	poly = -0.00287 + t * poly;

	return fScaleDepth * exp(poly);
}

// Returns the ray parameter of the nearer intersection between the ray
// v3Pos + t * v3Ray and a sphere centred at the origin with squared radius
// fRadius2.
//   v3Pos      - ray origin
//   v3Ray      - ray direction; assumed unit length (the quadratic's leading
//                coefficient is taken as 1)
//   fDistance2 - squared distance of v3Pos from the origin (|v3Pos|^2)
//   fRadius2   - squared sphere radius
// If the ray misses the sphere, the discriminant is clamped to zero and the
// parameter of closest approach (-b / 2) is returned instead.
float getNearIntersection(vec3 v3Pos, vec3 v3Ray, float fDistance2, float fRadius2)
{
	// Solve t^2 + b*t + c = 0.
	float b = 2.0 * dot(v3Pos, v3Ray);
	float c = fDistance2 - fRadius2;

	float disc = max(0.0, b * b - 4.0 * c);
	float root = sqrt(disc);

	// Smaller root = nearer hit.
	return (-b - root) * 0.5;
}

// Per-vertex scattering along the ray from the camera toward this vertex.
// The ray starts at its near intersection with the outer (fOuterRadius)
// sphere. The integral is taken with nSamples midpoint samples. Outputs:
//   ex_front_color     - accumulated in-scatter weighted by v3InvWavelength * fKrESun
//   ex_secondary_color - accumulated in-scatter weighted by fKmESun
//   ex_v3Direction     - camera minus vertex, in scaled atmosphere units
//   ex_worldheight     - unscaled camera world-space z
void main()
{
	// World units are divided by this factor (for both camera and vertex) to
	// bring them into the units used by the radius uniforms.
	const float fWorldToAtmosphere = 10000.0;

	vec3 lightPos = normalize(v3LightPos);

	// Camera position in world space: the view-space origin mapped back to world.
	vec3 v3CameraPos = (osg_ViewMatrixInverse * vec4(0.0, 0.0, 0.0, 1.0)).xyz;
	vec3 vecCamera = v3CameraPos / fWorldToAtmosphere;
	// NOTE(review): a clamp keeping vecCamera.z above fInnerRadius was previously
	// commented out here; the camera is not lifted by fInnerRadius the way the
	// vertex is below -- confirm this asymmetry is intended.

	// Camera "height" is its z coordinate only.
	float fCameraHeight = vecCamera.z;
	// Square by multiplication: GLSL pow(x, y) is undefined for x < 0, and nothing
	// here guarantees the camera z is non-negative.
	float fCameraHeight2 = fCameraHeight * fCameraHeight;

	// Vertex in atmosphere units, lifted by fInnerRadius along z.
	vec3 posWS = vertex.xyz / fWorldToAtmosphere;
	posWS.z += fInnerRadius;

	// Unit ray from the camera to the vertex; far is its length.
	vec3 ray = posWS - vecCamera;
	float far = length(ray);
	ray /= far;

	// Start the integration where the ray meets the outer sphere.
	// NOTE(review): getNearIntersection expects |vecCamera|^2, but z^2 is passed;
	// these agree only when the camera x/y are zero. Also, when the camera is
	// inside the outer sphere, near is negative and rayStart lies behind the
	// camera. Confirm both are intended.
	float near = getNearIntersection(vecCamera, ray, fCameraHeight2, fOuterRadius2);
	vec3 rayStart = vecCamera + (ray * near);
	far -= near;

	// Optical-depth offset at the start point, evaluated as if the start point
	// lies on the outer sphere (depth exp(-1 / fScaleDepth)).
	float startAngle = dot(ray, rayStart) / fOuterRadius;
	float startDepth = exp(-1.0 / fScaleDepth);
	float startOffset = startDepth * scale(startAngle);

	// Midpoint sampling: nSamples equal steps, the first sample at half a step.
	float sampleLength = far / fSamples;
	float scaledLength = sampleLength * fScale;
	vec3 sampleRay = ray * sampleLength;
	vec3 samplePoint = rayStart + (sampleRay * 0.5);

	// Wavelength-dependent extinction coefficient; loop invariant.
	vec3 extinction = (v3InvWavelength * fKr4PI) + fKm4PI;

	vec3 frontColor = vec3(0.0);

	for (int i = 0; i < nSamples; i++)
	{
		float height = length(samplePoint);
		float depth = exp(fScaleOverScaleDepth * (fInnerRadius - height));
		// NOTE(review): the light direction is negated here; see v3LightPos.
		float lightAngle = dot(-lightPos, samplePoint) / height;
		float cameraAngle = dot(ray, samplePoint) / height;
		float scatter = startOffset + (depth * (scale(lightAngle) - scale(cameraAngle)));
		vec3 attenuate = exp(-scatter * extinction);
		frontColor += attenuate * (depth * scaledLength);
		samplePoint += sampleRay;
	}

	ex_front_color = vec4(frontColor * (v3InvWavelength * fKrESun), 1.0);
	ex_secondary_color = vec4(frontColor * fKmESun, 1.0);
	ex_v3Direction = vecCamera - posWS;
	ex_worldheight = v3CameraPos.z;

	// vertex is used as a world-space position (no model matrix).
	gl_Position = ProjectionMatrix * osg_ViewMatrix * vertex;
}