#pragma once
#include "Scene.h"
#include "RainState.h"

#include <vector>

#include <osg/Geode>
#include <osg/Geometry>
#include <osg/NodeCallback>
#include <osg/Program>
#include <osg/Shader>
#include <osg/Texture2D>
#include <osg/Texture3D>
#include <osgUtil/CullVisitor>

namespace osgCloudyDay
{
	/**
	 * Rain particle effect.
	 *
	 * Holds the geode used to draw the rain together with its per-particle
	 * arrays (vertices, velocities and additional information). The shader
	 * program and the 3D texture are static members, i.e. shared by every
	 * Rain instance.
	 */
	class Rain
	{
	public:
		/**
		 * Constructs the rain effect.
		 */
		Rain();

		/**
		 * Destructor.
		 */
		~Rain();

		/**
		 * Sets up the rain effect.
		 * @param rain state object describing the rain (see RainState)
		 */
		void Initialize(RainState* rain);

		/**
		 * Builds the shader program shared by all rain effects
		 * (presumably stored in m_rain_shader -- see implementation).
		 */
		static void CreateShader();

		/**
		 * Builds the texture shared by all rain effects
		 * (presumably stored in m_rain_texture -- see implementation).
		 */
		static void CreateTexture();

		/**
		 * Gives access to the geode that renders the rain.
		 * @return the rain geode
		 */
		osg::ref_ptr<osg::Geode> GetGeode();

	protected:
		/**
		 * Presumably returns a random float used while building the
		 * particles; its range is defined in the implementation file.
		 */
		float frand();

		// Resources shared by every Rain instance.
		static osg::ref_ptr<osg::Program> m_rain_shader;
		static osg::ref_ptr<osg::Texture3D> m_rain_texture;

		// Per-instance scene-graph node and particle data
		// (declaration order kept: it defines member initialization order).
		osg::ref_ptr<osg::Geode> m_geode;
		std::vector<unsigned int> m_index;
		osg::ref_ptr<osg::Vec3Array> m_vertices;
		osg::ref_ptr<osg::Vec4Array> m_velocity;
		osg::ref_ptr<osg::Vec4Array> m_information;
	};

	/**
	 * Class to update the rain effect at runtime
	 */
	class RainCallback : public osg::NodeCallback
	{
	public:
		/**
		 * Constructor
		 */
		RainCallback(void)
		{

		}

		/**
		 * Class update the particle system of the rain effect
		 */
		virtual void operator()(osg::Node* node, osg::NodeVisitor* nv)
		{
			osgUtil::CullVisitor* cv = dynamic_cast<osgUtil::CullVisitor*>(nv);	   
			osg::ref_ptr<osg::Geode> geometry = dynamic_cast<osg::Geode*> (node);

			float fspeed = 2500.f;

			if(geometry && cv)
			{
				osg::ref_ptr<osg::Geometry> g = static_cast<osg::Geometry*> (geometry->getDrawable(0));
				osg::ref_ptr<osg::Vec3Array> pos = static_cast<osg::Vec3Array*>(g->getVertexAttribArray(0));
				osg::ref_ptr<osg::Vec4Array> vel = static_cast<osg::Vec4Array*>(g->getVertexAttribArray(1));
				for(unsigned int i = 0; i < pos->size(); i++)
				{
					pos->at(i).x() = pos->at(i).x() + 0.05f * vel->at(i).x() * fspeed;
					pos->at(i).y() = pos->at(i).y() + 0.05f * vel->at(i).y() * fspeed;
					pos->at(i).z() = pos->at(i).z() + 0.05f * vel->at(i).z() * fspeed;
					vel->at(i).w() -= 0.05f;

					if(vel->at(i).w() < 0.f)
					{					
						pos->at(i).z() = fspeed;
						vel->at(i).w() = 1.f;
					}
				}
				g->getVertexAttribArray(0)->dirty();
				g->getVertexAttribArray(1)->dirty();

				osg::Matrixd view(cv->getCurrentCamera()->getViewMatrix());
				osg::Matrixd invViewMatrix = osg::Matrixd::inverse(view);
				osg::Matrixd modelview(*cv->getModelViewMatrix());
				osg::Matrixd model = modelview*invViewMatrix;
				osg::Matrixd proj(cv->getCurrentCamera()->getProjectionMatrix());
		
				geometry->getOrCreateStateSet()->getUniform("ModelMatrix")->set(model);
				geometry->getOrCreateStateSet()->getUniform("ViewMatrix")->set(view);
				geometry->getOrCreateStateSet()->getUniform("ProjectionMatrix")->set(proj);		
			
	#ifdef SHADOW_MAPPING
				geometry->getOrCreateStateSet()->getUniform("light_proj_matrix")->set(Scene::GetProjectionMatrix_Light());
				geometry->getOrCreateStateSet()->getUniform("light_mv_matrix")->set(Scene::GetLightCamera()->getViewMatrix());
	#endif
			}
			traverse(node, nv);
		}
	};
}