#pragma once

#include <Eigen/Eigen>
#include <vector>

class ObjectiveFunction {
public:
						ObjectiveFunction() = default;
	virtual				~ObjectiveFunction() = default;
	virtual float		operator()(Eigen::VectorXf& aX, Eigen::VectorXf& aDx) {
							aDx.setZero();
							constexpr float phi = 0.0f;
							return phi;
						}
protected:
	static float		weightedResidualNormAndDerivative(const Eigen::Ref<Eigen::VectorXf>& aVertexWeights, const Eigen::Ref<Eigen::VectorXf>& aVertexAreas,
							const Eigen::Ref<Eigen::VectorXf>& aTarget, const Eigen::Ref<Eigen::VectorXf>& aX, Eigen::Ref<Eigen::VectorXf> aDx);
};

class SimpleObjectiveFunction final : public ObjectiveFunction {
public:
						// SimpleObjectiveFunction computes (x-target)^T A (x-target) == ||x-target||^2_M
						SimpleObjectiveFunction(const Eigen::Ref<Eigen::VectorXf>& aVertexWeights, const Eigen::Ref<Eigen::VectorXf>& aVertexAreas, 
							const Eigen::Matrix<float, -1, -1, Eigen::RowMajor>& aTarget) : mTarget(aTarget), mVertexWeights(aVertexWeights), mVertexAreas(aVertexAreas) {}

	float				operator()(Eigen::VectorXf& aX, Eigen::VectorXf& aDx) override;

	Eigen::VectorXf		mTarget;
	Eigen::VectorXf		mVertexWeights;
	Eigen::VectorXf		mVertexAreas;
};

// MultiChannelObjectiveFunction extends SimpleObjectiveFunction such that each channel (column) of x and target are compared using the same metric M and then we compute a weighted sum over all channels
// i.e. we compute sum_j ( w_j (x_j - target_j)^T A (x_j - target_j) ), where j indicates a column
class MultiChannelObjectiveFunction final : public ObjectiveFunction {
public:
						//MultiChannelObjectiveFunction(const Eigen::Ref<Eigen::VectorXf>& aVertexWeights, const Eigen::Ref<Eigen::VectorXf>& aVertexAreas, 
						//	const Eigen::Ref<Eigen::VectorXf>& aVertexColor, const Eigen::Matrix<float, -1, -1, Eigen::RowMajor>& aTarget) :
						//		mTarget(aTarget), mVertexWeights(aVertexWeights), mVertexAreas(aVertexAreas), mVertexColor(aVertexColor) {
						//		mChannelWeights.resize(mTarget.cols());
						//		mChannelWeights.setOnes();
						//}
						MultiChannelObjectiveFunction(const Eigen::Ref<Eigen::VectorXf>& aVertexWeights, const Eigen::Ref<Eigen::VectorXf>& aVertexAreas, 
							const Eigen::Ref<Eigen::VectorXf>& aChannelWeights, const Eigen::Ref<Eigen::VectorXf>& aVertexColor, const Eigen::Matrix<float, -1, -1, Eigen::RowMajor>& aXTarget) :
								mTarget(aXTarget), mVertexWeights(aVertexWeights), mVertexAreas(aVertexAreas), mChannelWeights(aChannelWeights), mVertexColor(aVertexColor) {}

	float				operator()(Eigen::VectorXf& aX, Eigen::VectorXf& aDx) override;

	Eigen::MatrixXf		mTarget;
	Eigen::VectorXf		mVertexWeights;
	Eigen::VectorXf		mVertexAreas;
	Eigen::VectorXf		mVertexColor;
	Eigen::VectorXf		mChannelWeights; // channel weights
};


class ConsistentMassMultiChannelObjectiveFunction final : public ObjectiveFunction {
public:
						//ConsistentMassMultiChannelObjectiveFunction(const Eigen::Ref<Eigen::VectorXf>& aVertexWeights, Eigen::MatrixXi& elems, Eigen::MatrixXf& coords, 
						//	const Eigen::Ref<Eigen::VectorXf>& aVertexColor, const Eigen::Matrix<float, -1, -1, Eigen::RowMajor>& aTarget) :
						//		mTarget(aTarget), mVertexColor(aVertexColor) {
						//		mChannelWeights.resize(mTarget.cols());
						//		mChannelWeights.setOnes();
						//		buildConsistentMassMatrix(elems, coords, aVertexWeights);
						//}
						ConsistentMassMultiChannelObjectiveFunction(const Eigen::Ref<Eigen::VectorXf>& aVertexWeights, Eigen::MatrixXi& elems, Eigen::MatrixXf& coords,
							const Eigen::Ref<Eigen::VectorXf>& aChannelWeights, const Eigen::Ref<Eigen::VectorXf>& aVertexColor, const Eigen::Matrix<float, -1, -1, Eigen::RowMajor>& aXTarget) :
								mTarget(aXTarget), mChannelWeights(aChannelWeights), mVertexColor(aVertexColor) {
							buildConsistentMassMatrix(elems, coords, aVertexWeights);
						}

	float				operator()(Eigen::VectorXf& aX, Eigen::VectorXf& aDx) override;
	void buildConsistentMassMatrix(Eigen::MatrixXi& elems, Eigen::MatrixXf& coords, const Eigen::Ref<Eigen::VectorXf>& aVertexWeights);

	Eigen::MatrixXf		mTarget;
	Eigen::SparseMatrix<float> mM;
	Eigen::VectorXf		mVertexColor;
	Eigen::VectorXf		mChannelWeights; // channel weights
};
