#version 330 compatibility
in vec4 vertex;		// object-space vertex position (transformed by ModelMatrix in main)

// uniform shader-parameters
uniform mat4 osg_ViewMatrix;
uniform mat4 ProjectionMatrix;
uniform mat4 osg_ViewMatrixInverse;
// The direction vector to the light source.
// NOTE(review): not read by the active code of this shader -- main() derives the
// light position from light_mv_matrix instead. Kept so host-side setters still bind.
uniform vec3 v3LightPos;
uniform vec3 v3InvWavelength;	// 1 / pow(wavelength, 4) for the red, green, and blue channels
uniform float fOuterRadius;		// The outer (atmosphere) radius
uniform float fOuterRadius2;	// fOuterRadius^2
uniform float fInnerRadius;		// The inner (planetary) radius
uniform float fInnerRadius2;	// fInnerRadius^2
uniform float fKrESun;			// Kr * ESun
uniform float fKmESun;			// Km * ESun
uniform float fKr4PI;			// Kr * 4 * PI
uniform float fKm4PI;			// Km * 4 * PI
uniform float fScale;			// 1 / (fOuterRadius - fInnerRadius)
uniform float fScaleDepth;		// The scale depth (i.e. the altitude at which the atmosphere's average density is found)
uniform float fScaleOverScaleDepth;	// fScale / fScaleDepth

// main() takes the light position as the image of the origin under inverse(light_mv_matrix).
uniform mat4 light_mv_matrix;

// vertex-shader output variables (passed to fragment-shader)
out vec4 ex_secondary_color;	// accumulated in-scatter scaled by fKmESun (Mie term), alpha = 1
out vec4 ex_front_color;		// accumulated in-scatter scaled by v3InvWavelength * fKrESun (Rayleigh term), alpha = 1
out vec3 ex_v3Direction;		// unit direction from camera to vertex, in the scaled atmosphere space
out float ex_worldheight;		// world-space z coordinate of the camera
out vec3 ex_lightPos;			// unit direction from the (scaled) light position to the vertex

// Number of ray-march samples taken per vertex.
// NOTE(review): 2000 iterations per vertex (each with several exp() calls) is very
// expensive; O'Neil's reference uses a handful of samples -- consider tuning.
const int nSamples = 2000;
// Float copy of nSamples; derived from it so the two constants can never disagree.
const float fSamples = float(nSamples);

// Scale function of the scattering model: a fitted polynomial in (1 - cos(angle)),
// exponentiated and multiplied by fScaleDepth. The coefficients appear to be the
// optical-depth fit from O'Neil's "Accurate Atmospheric Scattering" (GPU Gems 2) --
// presumably only valid for the scale depth that fit was made for (0.25); verify.
//   fCos : cosine of the angle between the ray and the local vertical.
float scale(float fCos)
{
	float t = 1.0 - fCos;
	// Horner evaluation of the fitted polynomial.
	float fPoly = -0.00287 + t*(0.459 + t*(3.83 + t*(-6.80 + t*5.25)));
	return fScaleDepth * exp(fPoly);
}

// Ray parameter of the nearer intersection between the ray v3Pos + t*v3Ray and a
// sphere of squared radius fRadius2 centred at the origin.
// The quadratic is solved with a leading coefficient of 1, so v3Ray is assumed to be
// unit length; fDistance2 is presumably dot(v3Pos, v3Pos) (confirm at call sites).
// A negative discriminant (ray misses the sphere) is clamped to zero, in which case
// the result is the parameter of closest approach, -dot(v3Pos, v3Ray).
// Not called by the active code in this shader.
float getNearIntersection(vec3 v3Pos, vec3 v3Ray, float fDistance2, float fRadius2)
{
	float b = 2.0 * dot(v3Pos, v3Ray);
	float c = fDistance2 - fRadius2;
	float disc = b*b - 4.0 * c;
	disc = max(0.0, disc);
	return 0.5 * (-b - sqrt(disc));
}

// Vertex-stage single-scattering sky: the camera-to-vertex ray is marched in
// nSamples steps; at each sample the out-scatter attenuation exp(-scatter * k) is
// evaluated and the density-weighted result accumulated into v3FrontColor. The
// structure and variable names appear to follow O'Neil's "sky from atmosphere"
// shader (GPU Gems 2, ch. 16).
// Outputs written: ex_front_color, ex_secondary_color, ex_v3Direction,
// ex_worldheight, ex_lightPos and gl_Position.
void main()
{
	// Light position = origin of light space mapped through inverse(light_mv_matrix).
	// inverse() is expensive, so evaluate it once and reuse it for both uses below.
	vec3 v3LightOrigin = (inverse(light_mv_matrix) * vec4(0.0, 0.0, 0.0, 1.0)).xyz;
	// Normalised light position, used as the direction towards the light.
	vec3 lightPos = normalize(v3LightOrigin);

	// Camera position in world space, then mapped into the scaled scattering space
	// (planet centre at the origin, camera lifted by fInnerRadius along +z).
	vec3 v3CameraPos = (osg_ViewMatrixInverse * vec4(0.0, 0.0, 0.0, 1.0)).xyz;
	vec3 vecCamera = v3CameraPos;
	// NOTE(review): this divides by (1e6 * fInnerRadius), whereas v3Pos below is
	// computed as (xyz / 1e6) * fInnerRadius. The two mappings differ by a factor of
	// fInnerRadius^2, so camera and vertex do not live in the same scaled space unless
	// fInnerRadius == 1. Confirm which one is intended.
	vecCamera /= 1000000.0*fInnerRadius;
	vecCamera.z += fInnerRadius;
	float fCameraHeight = length(vecCamera);

	// View matrix with its x and z translation removed (y translation kept), used both
	// for the dome position fed to the scattering and for gl_Position.
	mat4 viewmatrix = osg_ViewMatrix;
	viewmatrix[3][0] = 0.0;
	viewmatrix[3][2] = 0.0;

	// Model matrix recovered from the fixed-function model-view matrix
	// (requires the compatibility profile declared by #version).
	mat4 ModelMatrix = osg_ViewMatrixInverse * gl_ModelViewMatrix;

	// Vertex position (with the partial view translation applied) mapped into the
	// scaled space; the camera-to-vertex distance is the far end of the ray.
	vec4 vertex2 = (osg_ViewMatrixInverse * viewmatrix) * ModelMatrix * vertex;
	vec3 v3Pos = vertex2.xyz/1000000.0*fInnerRadius;
	v3Pos.z += fInnerRadius;

	vec3 v3Ray = v3Pos - vecCamera;
	float fFar = length(v3Ray);
	v3Ray /= fFar;

	// Direction from the light position (scaled by 1e6) to the vertex.
	ex_lightPos = normalize(vertex2.xyz - v3LightOrigin*1000000.0);

	// Ray starts at the camera; compute its scattering offset.
	vec3 v3Start = vecCamera;
	float fHeight = length(v3Start);
	float fDepth = exp(fScaleOverScaleDepth * (fInnerRadius - fCameraHeight));
	float fStartAngle = dot(v3Ray, v3Start) / fHeight;
	float fStartOffset = fDepth*scale(fStartAngle);

	// Scattering-loop setup: samples are placed at the midpoints of nSamples equal segments.
	float fSampleLength = fFar / fSamples;
	float fScaledLength = fSampleLength * fScale;
	vec3 v3SampleRay = v3Ray * fSampleLength;
	vec3 v3SamplePoint = v3Start + (v3SampleRay * 0.5);

	// Per-channel out-scatter coefficient; loop-invariant, so computed once.
	vec3 v3ExtinctionCoeff = v3InvWavelength * fKr4PI + fKm4PI;

	vec3 v3FrontColor = vec3(0.0);
	for(int i=0; i<nSamples; i++)
	{
		float fSampleHeight = length(v3SamplePoint);
		float fSampleDepth = exp(fScaleOverScaleDepth * (fInnerRadius - fSampleHeight));
		float fLightAngle = dot(lightPos, v3SamplePoint) / fSampleHeight;
		float fCameraAngle = dot(v3Ray, v3SamplePoint) / fSampleHeight;
		float fScatter = (fStartOffset + fSampleDepth*(scale(fLightAngle) - scale(fCameraAngle)));
		vec3 v3Attenuate = exp(-fScatter * v3ExtinctionCoeff);
		v3FrontColor += v3Attenuate * vec3(fSampleDepth * fScaledLength);
		v3SamplePoint += v3SampleRay;
	}

	// Scale the accumulated colour into the Mie and Rayleigh terms for the fragment stage.
	ex_secondary_color.rgb = v3FrontColor * fKmESun;
	ex_secondary_color.w = 1.0;
	ex_front_color.rgb = v3FrontColor * (v3InvWavelength * fKrESun);
	ex_front_color.w = 1.0;
	ex_v3Direction = -normalize(vecCamera - v3Pos);
	ex_worldheight = v3CameraPos.z;

	gl_Position = ProjectionMatrix * viewmatrix * ModelMatrix * vertex;
}