#pragma once
#include <tamashii/core/render/render_backend_implementation.hpp>
#include <tamashii/renderer_vk/convenience/texture_to_gpu.hpp>
#include <tamashii/renderer_vk/convenience/geometry_to_gpu.hpp>
#include <tamashii/renderer_vk/convenience/material_to_gpu.hpp>
#include <tamashii/renderer_vk/convenience/light_to_gpu.hpp>
#include <tamashii/renderer_vk/render_backend.hpp>
#include <tamashii/core/scene/render_scene.hpp>
#include <rvk/rvk.hpp>

#include "light_trace_opti.hpp"

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include <utility>
#include <vector>

/**
 * Render backend implementation that reports the name "ialt".
 *
 * The class owns a LightTraceOptimizer (mLto), the Vulkan objects used for
 * drawing (frame independent ones in VkData, per-frame ones in VkFrameData) and
 * the visualization / path tracer settings declared in the private section.
 * A worker thread (mOptimizerThread) is joined in the destructor if it is
 * still joinable.
 *
 * The type is neither copyable nor movable.
 */
class InteractiveAdjointLightTracing final : public tamashii::RenderBackendImplementation {
public:
	/** Value returned by runOptimizer(). */
	struct OptimizerResult {
		// NOTE(review): presumably one parameter vector per optimizer iteration - confirm in the implementation file.
		std::deque<Eigen::VectorXd> history;
		// NOTE(review): presumably the objective value of the last iteration - confirm.
		// Zero-initialized so that a default constructed result never carries an indeterminate value.
		double lastPhi{ 0.0 };
	};

					/**
					 * Sets every setting to its default value and fills mPTGlobalUBO with
					 * the default path tracer parameters.
					 *
					 * The member initializers are listed in declaration order, which is the
					 * order in which C++ executes them regardless of how they are written.
					 *
					 * @param aRoot render root; a copy is stored in mRoot and it is also
					 *              handed to the LightTraceOptimizer.
					 */
					InteractiveAdjointLightTracing(const tamashii::VulkanRenderRoot& aRoot) : mRoot{ aRoot },
						mUpdates{mRoot.frameCount()}, mLights{ nullptr }, mLto{aRoot},
						// value-initialize so that fields which are not assigned below do not start out indeterminate
						mPTGlobalUBO{},
						mSceneLoaded{ false }, mShowTarget{ false }, mShowAlpha{ false }, mShowAdjointDeriv{ false },
						mAdjointVisRange{ 1 }, mAdjointVisLog{ false }, mShowWireframeOverlay{ false },
						mWireframeColor{ 0.5f }, mShowSH{ false }, mShowInterpolatedSH{ false }, mShowSHSize{ 0.01f },
						mCullMode{ { "None", "Front", "Back" } }, mActiveCullMode{ 2 }, mShowGrad{ false },
						mGradRange{ 1.0f }, mGradVisLog{ false }, mGradVisAcc{ false }, mGradVisWeights{ false }, mShowFDGrad{ false },
						mFdGradH{ 0.5f }, mDrawSetSH{ false },
						mPenalty{ 0.0f }, mPTRenderer{ true }, mRecalculate{ false }, mUsePathGuiding{ false },
						mEnvLightIntensity{ 1.f }, mUpdatePathGuidingBuffers{ false }, mTakeScreenshot{ false }, mScreenshotSpp{ 1 }, mScreenshotTimeLimit{ 3600000 },
						mOptimizerChoices{ LightTraceOptimizer::optimizerNames }, mOptimizerChoice{ 0 }, mOptimizerStepSize{ 1.0f }
					{
						mPTGlobalUBO.dither_strength = 1.0f;
						// FIXME: the default path tracer impl uses ccli vars for this value. Should
						// look into doing the same probably
						mPTGlobalUBO.pixelSamplesPerFrame = 1;
						mPTGlobalUBO.accumulatedFrames = 0;

						mPTGlobalUBO.fr_tmin = 0.2f;
						mPTGlobalUBO.fr_tmax = 10000.0f;
						mPTGlobalUBO.br_tmin = 0.001f;
						mPTGlobalUBO.br_tmax = 10000.0f;
						mPTGlobalUBO.sr_tmin = 0.001f;
						mPTGlobalUBO.sr_tmax_offset = -0.01f; // offset for shadow ray tmax

						mPTGlobalUBO.pixel_filter_type = 0;
						mPTGlobalUBO.pixel_filter_width = 1.0f;
						mPTGlobalUBO.pixel_filter_extra = glm::vec2{ 2.0f };

						mPTGlobalUBO.exposure_film = 1;

						mPTGlobalUBO.exposure_tm = 0;
						mPTGlobalUBO.gamma_tm = 1;
						// FIXME: see `pixelSamplesPerFrame`
						mPTGlobalUBO.rrpt = 5;
						mPTGlobalUBO.clamp_direct = 0;
						mPTGlobalUBO.clamp_indirect = 0;
						mPTGlobalUBO.filter_glossy = 0;
						mPTGlobalUBO.light_geometry = true;
						// TODO: use correct bg cli value
						//mEnvLight = glm::vec4{ var::bg.value()[0], var::bg.value()[1], var::bg.value()[2], 255.f } / 255.0f;
						// mEnvLight is a vec3: drop the alpha channel explicitly instead of
						// relying on an implicit vec4 -> vec3 conversion
						mEnvLight = glm::vec3{ glm::vec4{ 0.0f, 0.0f, 0.0f, 255.f } / 255.0f };
						// FIXME: see `pixelSamplesPerFrame` (both `max_depth` and `env_shade`)
						mPTGlobalUBO.max_depth = -1;
						mPTGlobalUBO.env_shade = true;

						mPTGlobalUBO.sampling_strategy = 1;
						mPTGlobalUBO.use_cached_radiance = false;
						mPTGlobalUBO.accumulate = true;
						mPTGlobalUBO.shade = false;
						mPTGlobalUBO.tone_mapping_type = 0;
						mPTGlobalUBO.pg_hsh_subdivision_depth = 5;
					}

					/** Waits for the optimizer thread to finish if it is still joinable. */
					~InteractiveAdjointLightTracing() override
					{
						if(mOptimizerThread.joinable()) mOptimizerThread.join();
					}
					// neither copyable nor movable
					InteractiveAdjointLightTracing(const InteractiveAdjointLightTracing&) = delete;
					InteractiveAdjointLightTracing& operator=(const InteractiveAdjointLightTracing&) = delete;
					InteractiveAdjointLightTracing(InteractiveAdjointLightTracing&&) = delete;
					InteractiveAdjointLightTracing& operator=(InteractiveAdjointLightTracing&&) = delete;

					/** @return the identifier of this backend implementation, "ialt". */
	const char*		getName() override { return "ialt"; }

	void			windowSizeChanged(uint32_t aWidth, uint32_t aHeight) override;
	bool			drawOnMesh(const tamashii::DrawInfo* aDrawInfo) override;

					// implementation preparation
	void			prepare(tamashii::RenderInfo_s& aRenderInfo) override;
	void			destroy() override;

					// scene
	void			sceneLoad(tamashii::SceneBackendData aScene) override;
	void			sceneUnload(tamashii::SceneBackendData aScene) override;

					// frame
	void			drawView(tamashii::ViewDef_s* aViewDef) override;
	void			drawUI(tamashii::UiConf_s* aUiConf) override;

					// NOTE(review): defined in the implementation file; presumably blocks on
					// mNextFrameCV until a frame has been drawn - confirm.
	void			waitForNextFrame();
					// NOTE(review): presumably runs the given callback on mOptimizerThread - confirm.
	void			startOptimizerThread(std::function<void(InteractiveAdjointLightTracing*, LightTraceOptimizer*, rvk::Buffer*, unsigned int, float, int)>);


	// python helper functions
	/** @return the optimizer owned by this object; the reference is valid as long as this object lives. */
	LightTraceOptimizer& getOptimizer() { return mLto; }
	void			runForward(const Eigen::Map<Eigen::VectorXd>&);
	double			runBackward(Eigen::VectorXd&);
	void			useCurrentRadianceAsTarget(bool clearWeights);
	OptimizerResult runOptimizer(LightTraceOptimizer::Optimizers, float, int);

	/** Sets the mShowTarget flag. */
	void			showTarget(const bool b) { mShowTarget = b; }
	void			clearGradVis();
	/**
	 * Stores the (light, parameter) selection used for the gradient visualization.
	 * mShowGrad and mGradVisAcc are both set to whether a selection was passed,
	 * i.e. std::nullopt switches both off.
	 */
	void			showGradVis(const std::optional<std::pair<tamashii::RefLight*, LightOptParams::PARAMS>>& p)
	{
		mGradImageSelection = p;
		mShowGrad = mGradVisAcc = mGradImageSelection.has_value();
	}
	/**
	 * Stores the (light, parameter) selection and the step size h (in mFdGradH)
	 * used for the finite difference gradient visualization.
	 * mShowFDGrad and mGradVisAcc are both set to whether a selection was passed.
	 * Note that mGradVisAcc is shared with showGradVis().
	 */
	void			showFDGradVis(const std::optional<std::pair<tamashii::RefLight*, LightOptParams::PARAMS>>& p, const float h)
	{
		mFDGradImageSelection = p;
		mFdGradH = h;
		mShowFDGrad = mGradVisAcc = mFDGradImageSelection.has_value();
	}

	std::unique_ptr<tamashii::Image>	getFrameImage();
	std::unique_ptr<tamashii::Image>	getGradImage();
private:
	// copy of the render root passed to the constructor
	tamashii::VulkanRenderRoot mRoot;

	// data that is frame independent
	struct VkData {
		// every member is constructed with a pointer to aDevice, except mCurrentPipeline which starts as nullptr
		explicit VkData(rvk::LogicalDevice& aDevice) :
			mGpuTd(&aDevice), mGpuMd(&aDevice), mGpuLd(&aDevice), mGpuBlas(&aDevice),
			mGlobalBuffer(&aDevice), mDescriptor(&aDevice), mShader(&aDevice), mPipelineCullNone(&aDevice),
			mPipelineCullFront(&aDevice), mPipelineCullBack(&aDevice), mShaderSHVis(&aDevice),
			mPipelineSHVisNone(&aDevice), mShaderSingleSHVis(&aDevice),
			mPipelineSingleSHVisNone(&aDevice), mDescriptorDrawOnMesh(&aDevice), mShaderDrawOnMesh(&aDevice),
			mPipelineDrawOnMesh(&aDevice), mCurrentPipeline(nullptr), mRadianceBufferCopy(&aDevice), 
			mFdRadianceBufferCopy(&aDevice), mFd2RadianceBufferCopy(&aDevice),
			mDerivVisPipeline(&aDevice), mDerivVisShader(&aDevice),
			mDerivVisImageAccumulate(&aDevice), mDerivVisImageAccumulateCount(&aDevice), mPTPipeline(&aDevice),
			mPTShader(&aDevice), mPathGuidingPipeline(&aDevice), mPathGuidingShader(&aDevice), mPTImageAccumulate(&aDevice), 
			mPTImageAccumulateCount(&aDevice), mPTCacheImage(&aDevice), mIncidentRadianceBufferCopy(&aDevice), mPGPhiMapBuffer(&aDevice),
			mPGLegendreMapBuffer(&aDevice), mPGNormalizationConstantsBuffer(&aDevice) {}
		// convenience classes for textures, geometry and as
		tamashii::TextureDataVulkan							mGpuTd;
		tamashii::MaterialDataVulkan						mGpuMd;
		tamashii::LightDataVulkan							mGpuLd;
		tamashii::GeometryDataBlasVulkan					mGpuBlas;

		rvk::Buffer											mGlobalBuffer;
		rvk::Descriptor										mDescriptor;
		rvk::RShader										mShader;
		rvk::RPipeline										mPipelineCullNone;
		rvk::RPipeline										mPipelineCullFront;
		rvk::RPipeline										mPipelineCullBack;
		rvk::RShader										mShaderSHVis;
		rvk::RPipeline										mPipelineSHVisNone;
		rvk::RShader										mShaderSingleSHVis;
		rvk::RPipeline										mPipelineSingleSHVisNone;
		rvk::Descriptor										mDescriptorDrawOnMesh;
		rvk::CShader										mShaderDrawOnMesh;
		rvk::CPipeline										mPipelineDrawOnMesh;
		// non-owning; NOTE(review): presumably points at one of the mPipelineCull* members - confirm
		rvk::RPipeline*										mCurrentPipeline;
		rvk::Buffer											mRadianceBufferCopy;
		rvk::Buffer											mFdRadianceBufferCopy;
		rvk::Buffer											mFd2RadianceBufferCopy;

		// derivative visualization data
		rvk::RTPipeline										mDerivVisPipeline;
		rvk::RTShader										mDerivVisShader;
		rvk::Image											mDerivVisImageAccumulate;
		rvk::Image											mDerivVisImageAccumulateCount;

		// path tracing data
		rvk::RTPipeline										mPTPipeline;
		rvk::RTShader										mPTShader;
		rvk::RTPipeline										mPathGuidingPipeline;
		rvk::RTShader										mPathGuidingShader;
		rvk::Image											mPTImageAccumulate;
		rvk::Image											mPTImageAccumulateCount;
		rvk::Image											mPTCacheImage;

		// path guiding data
		rvk::Buffer											mIncidentRadianceBufferCopy;
		rvk::Buffer											mPGPhiMapBuffer;
		rvk::Buffer											mPGLegendreMapBuffer;
		rvk::Buffer											mPGNormalizationConstantsBuffer;
	};
	// data that exists once per entry of mFrameData
	struct VkFrameData {
		// every member is constructed with a pointer to aDevice
		explicit VkFrameData(rvk::LogicalDevice& aDevice) : mGpuTlas(&aDevice), mColor(&aDevice),
		mDepth(&aDevice), mDerivVisDescriptor(&aDevice), mDerivVisImage(&aDevice),
		mPTGlobalDescriptor(&aDevice), mPTGlobalUniformBuffer(&aDevice), mPTImage(&aDevice),
		mPTDebugImage(&aDevice) {}
		tamashii::GeometryDataTlasVulkan					mGpuTlas;
		rvk::Image											mColor;
		rvk::Image											mDepth;
		rvk::Descriptor										mDerivVisDescriptor;
		rvk::Image											mDerivVisImage;

		// path tracing data
		rvk::Descriptor										mPTGlobalDescriptor;
		rvk::Buffer											mPTGlobalUniformBuffer;
		rvk::Image											mPTImage;
		rvk::Image											mPTDebugImage;
	};

	// empty after construction; NOTE(review): presumably created in prepare() and reset in destroy() - confirm
	std::optional<VkData>									mData;
	std::vector<VkFrameData>								mFrameData;
	// constructed from mRoot.frameCount(); NOTE(review): presumably one entry per frame - confirm
	std::vector<tamashii::SceneUpdateInfo>					mUpdates;
	// non-owning, nullptr after construction
	std::deque<std::shared_ptr<tamashii::RefLight>>		*mLights;
	LightTraceOptimizer										mLto;

	// Included inside the class body so that its declarations become members of this class.
	// NOTE(review): presumably this header declares PTGlobalBuffer, which is used right below - confirm.
	#include "../../../assets/shader/ialt/defines.h"
	// value-initialized and then filled with the default path tracer parameters in the constructor
	PTGlobalBuffer											mPTGlobalUBO;

	// settings, all initialized in the constructor
	bool													mSceneLoaded;
	bool													mShowTarget;
	bool													mShowAlpha;
	bool													mShowAdjointDeriv;
	float													mAdjointVisRange;
	bool													mAdjointVisLog;
	bool													mShowWireframeOverlay;
	glm::vec3												mWireframeColor;
	bool													mShowSH;
	bool													mShowInterpolatedSH;
	float													mShowSHSize;
	// names of the selectable cull modes: "None", "Front", "Back"
	std::vector<std::string>								mCullMode;
	// defaults to 2; NOTE(review): presumably an index into mCullMode (2 = "Back") - confirm
	uint32_t												mActiveCullMode;
	bool													mShowGrad;
	float													mGradRange;
	bool													mGradVisLog;
	// set by both showGradVis() and showFDGradVis()
	bool													mGradVisAcc;
	bool													mGradVisWeights;
	bool													mShowFDGrad;
	// step size passed to showFDGradVis(), defaults to 0.5
	float													mFdGradH;
	bool													mDrawSetSH;
	float													mPenalty;
	bool													mPTRenderer;
	bool													mRecalculate;
	bool													mUsePathGuiding;
	// assigned in the constructor body (currently black)
	glm::vec3												mEnvLight;
	float													mEnvLightIntensity;
	bool													mUpdatePathGuidingBuffers;
	bool													mTakeScreenshot;
	uint32_t												mScreenshotSpp;
	// defaults to 3600000; NOTE(review): presumably milliseconds (= 1 hour) - confirm
	uint32_t												mScreenshotTimeLimit;
	std::chrono::time_point<std::chrono::steady_clock>      mScreenshotStartTime;
	// static member, defined outside of this header
	static ccli::Var<int> mOptMaxIters;


	// joined in the destructor if still joinable
	std::thread												mOptimizerThread;
	// copied from LightTraceOptimizer::optimizerNames
	std::vector<std::string>								mOptimizerChoices;
	// defaults to 0; NOTE(review): presumably an index into mOptimizerChoices - confirm
	uint8_t													mOptimizerChoice;
	float													mOptimizerStepSize;

	// selections stored by showGradVis() / showFDGradVis(); the RefLight pointers are non-owning
	std::optional<std::pair<tamashii::RefLight*, LightOptParams::PARAMS>> mGradImageSelection;
	std::optional<std::pair<tamashii::RefLight*, LightOptParams::PARAMS>> mFDGradImageSelection;

	// NOTE(review): the mutex used together with this condition variable is not declared here - confirm where it lives
	std::condition_variable									mNextFrameCV;

	void computeHSHIntegrals();
	void computeNormalizationConstants();
	void recomputePathGuidingBuffers();
};
