#pragma once

#include <tamashii/core/scene/ref_entities.hpp>
#include "parameter.hpp"

#include <Eigen/Eigen>
#include <cmath>
#include <set>

/*
 * Abstract base for soft constraints on optimized light parameters.
 * A constraint holds a set of (non-owning) light pointers, a penalty factor and an active flag;
 * subclasses implement evalAndAddToGradient().
 */
class LightConstraint{
public:
	LightConstraint() = default;
	virtual ~LightConstraint() = default;

	/// Scale applied to the constraint value and to its gradient contribution.
	void setPenaltyFactor(const double aFactor = 1.0) { mPenaltyFactor = aFactor; }
	/// Enables (default) or disables the constraint; disabled constraints must contribute nothing.
	void setActive(const bool aActive = true) { mIsActive = aActive; }
	/// Registers a light with this constraint (stored by pointer; not owned). Duplicates are ignored by the set.
	void addLight(tamashii::RefLight* aRefLight) { mLights.insert(aRefLight); }
	/// Unregisters a previously added light; no-op if it was never added.
	void removeLight(tamashii::RefLight* aRefLight) { mLights.erase(aRefLight); }

	/*
	 * Evaluates the constraint and accumulates its gradient into aGradient.
	 * Contract for implementations:
	 *  - return 0 iff the constraint is satisfied and >0 otherwise; the value should measure the distance
	 *    to the closest satisfying configuration and the gradient must match it;
	 *  - both value and gradient are scaled by the penalty factor (penalty-based soft constraints);
	 *  - an inactive constraint (mIsActive == false) returns 0 and leaves the gradient untouched;
	 *  - the gradient holds LightOptParams::MAX_PARAMS entries per light, ordered as LightOptParams::PARAMS.
	 *    The existing subclasses index it by RefLight::ref_light_index, so it must be large enough for the
	 *    largest index of any registered light.
	 */
	virtual double evalAndAddToGradient(Eigen::VectorXd& aGradient) = 0;

protected:
	std::set<tamashii::RefLight*> mLights;  // non-owning
	double mPenaltyFactor = 1.0;
	bool mIsActive = true;
};

class LightsInAABBConstraint final : public LightConstraint{
public:
	LightsInAABBConstraint(const double aXmin, const double aXmax, const double aYmin, const double aYmax, const double aZmin, const double aZmax) : LightConstraint(),
		mXmin(aXmin), mXmax(aXmax), mYmin(aYmin), mYmax(aYmax), mZmin(aZmin), mZmax(aZmax) {}
	
	double evalAndAddToGradient(Eigen::VectorXd& aGradient) override
	{
		double f = 0.0;
		if( mIsActive ){
			for(const tamashii::RefLight* refLight : mLights)
			{
				const auto idx = static_cast<Eigen::Index>(refLight->ref_light_index);
				const glm::vec3& p = refLight->position;
				Eigen::Vector3d d; d.setZero();

				if (static_cast<double>(p.x) < mXmin) d[0] = static_cast<double>(p.x) - mXmin;
				else if (static_cast<double>(p.x) > mXmax) d[0] = static_cast<double>(p.x) - mXmax;

				if (static_cast<double>(p.y) < mYmin) d[1] = static_cast<double>(p.y) - mYmin;
				else if (static_cast<double>(p.y) > mYmax) d[1] = static_cast<double>(p.y) - mYmax;

				if (static_cast<double>(p.z) < mZmin) d[2] = static_cast<double>(p.z) - mZmin;
				else if (static_cast<double>(p.z) > mZmax) d[2] = static_cast<double>(p.z) - mZmax;

				f += 0.5 * mPenaltyFactor * d.squaredNorm();
				aGradient.segment(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_X, 3) += mPenaltyFactor * d;
			}
		}
		return f;
	}
private:
	double mXmin, mXmax, mYmin, mYmax, mZmin, mZmax;
};

/*
 * Quadratic penalty on light intensity: f = penaltyFactor * sum_i 0.5 * I_i^2.
 * The gradient is written to the INTENSITY entry of each registered light.
 */
class LightsIntensityPenalty final : public LightConstraint{
public:
	LightsIntensityPenalty(const double penaltyFactor_){ mPenaltyFactor = penaltyFactor_; }

	double evalAndAddToGradient(Eigen::VectorXd& aGradient) override
	{
		double f = 0.0;
		if( mIsActive ){
			for(const tamashii::RefLight* refLight : mLights)
			{
				const auto idx = static_cast<Eigen::Index>(refLight->ref_light_index);
				const double intensity = refLight->light->getIntensity();
				// Note: a quadratic penalty might excessively penalize bright lights. Energy consumption is linear in
				// intensity if efficiency is assumed constant; real lights may lose efficiency at high power, so an
				// order higher than linear may still be reasonable in practice.
				f += mPenaltyFactor * 0.5*intensity*intensity;

				// Note: this writes directly to the intensity gradient. If mixed intensity and colour optimization is
				// active, the result will not match an FD approximation (but should in principle work for optimization).
#ifdef IALT_USE_QUADRATIC_INTENSITY_OPT
				// Quadratic parametrization I = p^2 / 2, i.e. p = sqrt(2I) and dI/dp = p.
				// Chain rule: df/dp = (penaltyFactor * I) * sqrt(2I).
				// NOTE(review): assumes the parameter p is non-negative -- confirm against the parametrization.
				aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::INTENSITY) += mPenaltyFactor*intensity*std::sqrt(2.0*intensity);
#else
				aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::INTENSITY) += mPenaltyFactor*intensity;
#endif
			}
		}
		return f;
	}
};

class WindowInModelConstraint : public LightConstraint {
	public:
	WindowInModelConstraint(const double penaltyFactor_) {mPenaltyFactor = penaltyFactor_;}

	double evalAndAddToGradient(Eigen::VectorXd& aGradient) override
	{
		double f = 0.0;
		if( mIsActive ){
			for(const tamashii::RefLight* refLight : mLights) {
				const auto &windowLight = dynamic_cast<tamashii::WindowLight &>(*refLight->light);
				const auto &refModel = dynamic_cast<tamashii::RefModel&>(*refLight->connectedRefModel);

				const auto idx = static_cast<Eigen::Index>(refLight->ref_light_index);
				const glm::vec3& p = refLight->position;
				Eigen::Vector3d d; d.setZero();

				// Get Model matrices
				const auto modelMatrix = refModel.model_matrix;
				glm::vec3 modelScale, modelTranslation;
				glm::quat modelRotation;
				tamashii::math::decomposeTransform(modelMatrix, modelTranslation, modelRotation, modelScale);
				const auto inverseModelMatrix = glm::inverse(glm::scale(modelMatrix, glm::vec3(1.0f) / modelScale));

				glm::dvec3 gradients = {aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_X)
										,aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_Y)
										,aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_Z)};

				gradients = glm::dmat3(inverseModelMatrix) * gradients;

				// Get bounding box of model
				const tamashii::aabb_s AABB = refModel.model->getAABB();
				// Get dimensions of light
				glm::vec3 lightDimensions = windowLight.getDimensions();

				// Divide dimension by 2 for the calculation of the "bounding box" later
				for (int i = 0; i < 3; i++) {
					lightDimensions[i] == 0 ? lightDimensions[i] = 0 : lightDimensions[i] /= 2;
				}

				// Change axes so it aligns with the models axes
				lightDimensions[2] = lightDimensions[1];
				lightDimensions[1] = 0;

				// calculate the "bounding box" of the model
				const float factor = 0.75;
				const glm::vec3 minModelBounds = {modelScale.x * AABB.mMin.x * factor, modelScale.y * AABB.mMin.y * factor, modelScale.z * AABB.mMin.z * factor};
				const glm::vec3 maxModelBounds = {modelScale.x * AABB.mMax.x * factor, modelScale.y * AABB.mMax.y * factor, modelScale.z * AABB.mMax.z * factor};

				// Transform the translation in world Space into the models Object space
				glm::vec3 positionObjectSpace = inverseModelMatrix * glm::vec4{p, 1.0f};

				// calculate the "bounding box" of the light
				const glm::vec3 minLightBounds = {positionObjectSpace.x - lightDimensions.x, positionObjectSpace.y - lightDimensions.y, positionObjectSpace.z - lightDimensions.z};
				const glm::vec3 maxLightBounds = {positionObjectSpace.x + lightDimensions.x, positionObjectSpace.y + lightDimensions.y, positionObjectSpace.z + lightDimensions.z};

				// Determine whether the light is still inside or outside the model
				if (minLightBounds.x <= minModelBounds.x) {
					d[0] = static_cast<double>(minLightBounds.x) - static_cast<double>(minModelBounds.x);
				} else if (maxLightBounds.x >= maxModelBounds.x) {
					d[0] = static_cast<double>(maxLightBounds.x) - static_cast<double>(maxModelBounds.x);
				}
				if (minLightBounds.y <= minModelBounds.y) {
					d[1] = static_cast<double>(minLightBounds.y) - static_cast<double>(minModelBounds.y);
				} else if (maxLightBounds.y >= maxModelBounds.y) {
					d[1] = static_cast<double>(maxLightBounds.y) - static_cast<double>(maxModelBounds.y);
				}
				if (minLightBounds.z <= minModelBounds.z) {
					d[2] = static_cast<double>(minLightBounds.z) - static_cast<double>(minModelBounds.z);
				} else if (maxLightBounds.z >= maxModelBounds.z) {
					d[2] = static_cast<double>(maxLightBounds.z) - static_cast<double>(maxModelBounds.z);
				}

				gradients[0] += mPenaltyFactor * d[0];
				gradients[1] = 0;
				gradients[2] += mPenaltyFactor * d[2];

				gradients = glm::dmat3(glm::scale(modelMatrix, glm::vec3(1.0f) / glm::vec3(modelScale))) * gradients;

				f += 0.5 * mPenaltyFactor * d.squaredNorm();

				aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_X) = gradients[0];
				aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_Y) = gradients[1];
				aGradient(idx * LightOptParams::MAX_PARAMS + LightOptParams::POS_Z) = gradients[2];
			}
		}
		return f;
	}
};
