#include "ialt.hpp"
#include <imgui.h>
#include <algorithm>
#include <glm/gtc/color_space.hpp>
#include <tamashii/core/common/common.hpp>
#include <tamashii/core/common/vars.hpp>
#include <tamashii/core/common/input.hpp>
#include <tamashii/core/scene/model.hpp>
#include <tamashii/core/platform/filewatcher.hpp>
#include <tamashii/core/io/io.hpp>

T_USE_NAMESPACE
RVK_USE_NAMESPACE

// Capacities handed to the prepare() calls of the GPU scene containers in
// InteractiveAdjointLightTracing::prepare(). Whether each value counts bytes or
// elements is decided by the respective container — TODO confirm per container.
constexpr uint32_t MAX_INDEX_DATA_SIZE{ 10u * 1024u * 1024u };	// BLAS index data
constexpr uint32_t MAX_VERTEX_DATA_SIZE{ 8u * 512u * 1024u };	// BLAS vertex data
constexpr uint32_t MAX_GLOBAL_IMAGE_SIZE{ 1024u };				// texture container
constexpr uint32_t MAX_MATERIAL_SIZE{ 2u * 1024u };				// material container
constexpr uint32_t MAX_LIGHT_SIZE{ 2u * 128u };					// light container
constexpr uint32_t MAX_INSTANCE_SIZE{ 4u * 1024u };				// per-frame TLAS instances
constexpr uint32_t MAX_GEOMETRY_DATA_SIZE{ 2u * 1024u };		// per-frame TLAS geometry data

// Formats of the per-frame raster attachments (also written into the global render state).
constexpr VkFormat COLOR_FORMAT{ VK_FORMAT_R32G32B32A32_SFLOAT };
constexpr VkFormat DEPTH_FORMAT{ VK_FORMAT_D32_SFLOAT };

// Formats of the shared path-tracing accumulation image and its count image.
constexpr VkFormat PT_ACCUMULATION_FORMAT{ VK_FORMAT_R32G32B32A32_SFLOAT };
constexpr VkFormat PT_COUNT_FORMAT{ VK_FORMAT_R32_SFLOAT };

// Pi constant. <cmath> (included transitively, e.g. through glm) already defines M_PI as a
// macro on POSIX/glibc and on MSVC with _USE_MATH_DEFINES; declaring a variable with that
// name would then expand to `constexpr float 3.14... = ...` and fail to compile. Only
// provide our own definition when no macro exists.
#ifndef M_PI
constexpr float M_PI = 3.1415926535897932384626433832795028841f;
#endif

// Shorthand for the light-trace optimizer's configuration variables.
using LTOVars = LightTraceOptimizer::vars;

// Config variable "maxIters": limits the number of optimizer iterations. 200 appears to be
// the default value; ccli::Flag::ConfigRead presumably allows reading it from the config — confirm.
ccli::Var<int>		InteractiveAdjointLightTracing::mOptMaxIters("", "maxIters", 200, ccli::Flag::ConfigRead, "Limit optimizer iterations.");

void InteractiveAdjointLightTracing::windowSizeChanged(const uint32_t aWidth, const uint32_t aHeight)
{
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	stc.begin();
	// accumulate Image
	mData->mDerivVisImageAccumulate.destroy();
	mData->mDerivVisImageAccumulateCount.destroy();
	// rtImage
	mData->mDerivVisImageAccumulate.createImage2D(aWidth, aHeight, VK_FORMAT_R32G32B32A32_SFLOAT, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
	mData->mDerivVisImageAccumulate.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
	mData->mDerivVisImageAccumulateCount.createImage2D(aWidth, aHeight, VK_FORMAT_R32_UINT, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
	mData->mDerivVisImageAccumulateCount.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);

	mData->mPTImageAccumulate.destroy();
	mData->mPTImageAccumulateCount.destroy();
	
	mData->mPTImageAccumulate.createImage2D(aWidth, aHeight, PT_ACCUMULATION_FORMAT, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
	mData->mPTImageAccumulate.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
	mData->mPTImageAccumulateCount.createImage2D(aWidth, aHeight, PT_COUNT_FORMAT, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
	mData->mPTImageAccumulateCount.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);

	for (uint32_t si_idx = 0; si_idx < mRoot.frameCount(); si_idx++) {
		VkFrameData& frameData = mFrameData[si_idx];
		frameData.mColor.destroy();
		frameData.mDepth.destroy();

		frameData.mColor.createImage2D(aWidth, aHeight, COLOR_FORMAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::COLOR_ATTACHMENT | rvk::Image::Use::UPLOAD);
		frameData.mDepth.createImage2D(aWidth, aHeight, DEPTH_FORMAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::DEPTH_STENCIL_ATTACHMENT);
		frameData.mColor.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
		frameData.mDepth.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL);

		frameData.mDerivVisImage.destroy();
		// rtImage
		frameData.mDerivVisImage.createImage2D(aWidth, aHeight, VK_FORMAT_B8G8R8A8_UNORM, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
		frameData.mDerivVisImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
		frameData.mDerivVisDescriptor.setImage(DERIV_VIS_DESC_DERIV_IMAGE_BINDING, &frameData.mDerivVisImage);
		frameData.mDerivVisDescriptor.setImage(DERIV_VIS_DESC_DERIV_ACC_IMAGE_BINDING, &mData->mDerivVisImageAccumulate);
		frameData.mDerivVisDescriptor.setImage(DERIV_VIS_DESC_DERIV_ACC_COUNT_IMAGE_BINDING, &mData->mDerivVisImageAccumulateCount);
		if(mSceneLoaded) frameData.mDerivVisDescriptor.update();

		frameData.mPTImage.destroy();
		frameData.mPTDebugImage.destroy();
		frameData.mPTImage.createImage2D(aWidth, aHeight, VK_FORMAT_R8G8B8A8_UNORM, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
		frameData.mPTImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
		frameData.mPTDebugImage.createImage2D(aWidth, aHeight, VK_FORMAT_R8G8B8A8_UNORM, VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
		frameData.mPTDebugImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
		frameData.mPTGlobalDescriptor.setImage(GLSL_GLOBAL_RT_OUT_IMAGE_BINDING, &frameData.mPTImage);
		frameData.mPTGlobalDescriptor.setImage(GLSL_GLOBAL_RT_ACC_IMAGE_BINDING, &mData->mPTImageAccumulate);
		frameData.mPTGlobalDescriptor.setImage(GLSL_GLOBAL_RT_ACC_C_IMAGE_BINDING, &mData->mPTImageAccumulateCount);
		frameData.mPTGlobalDescriptor.setImage(GLSL_GLOBAL_DEBUG_IMAGE_BINDING, &frameData.mPTDebugImage);
		if (mSceneLoaded) frameData.mPTGlobalDescriptor.update();
	}
	stc.end();
}

/**
 * Paint onto the mesh under the cursor with the draw-on-mesh compute shader.
 *
 * While the left mouse button is down, mColor0 is used; while the right button is down,
 * mColor1 is used. The RGB part is converted from sRGB to linear and alpha is passed through.
 * The brush parameters are pushed as a DrawBuffer push constant, and the compute pipeline
 * is dispatched on the current command buffer with the hit mesh's vertex count.
 * Nothing is recorded when no button is pressed or when the hit carries no mesh/model.
 *
 * @param aDrawInfo brush settings and cursor hit information (may be null: no-op)
 * @return always false
 */
bool InteractiveAdjointLightTracing::drawOnMesh(const tamashii::DrawInfo* aDrawInfo)
{
	if (aDrawInfo == nullptr) return false;

	auto color = glm::vec4(0.0f);
	bool draw = false;
	bool left = false;
//#ifdef IALT_USE_SPHERICAL_HARMONICS
//	if (InputSystem::getInstance().wasPressed(Input::MOUSE_LEFT)) {
//#else
	if (InputSystem::getInstance().isDown(Input::MOUSE_LEFT)) {
//#endif
		color = glm::vec4(glm::convertSRGBToLinear(glm::vec3(aDrawInfo->mColor0)), aDrawInfo->mColor0.w);
		draw = true;
		left = true;
	}
//#ifdef IALT_USE_SPHERICAL_HARMONICS
//	else if (InputSystem::getInstance().wasPressed(Input::MOUSE_RIGHT)) {
//#else
	else if (InputSystem::getInstance().isDown(Input::MOUSE_RIGHT)) {
//#endif
		color = glm::vec4(glm::convertSRGBToLinear(glm::vec3(aDrawInfo->mColor1)), aDrawInfo->mColor1.w);
		draw = true;
	}
	if (!draw) return false;

	// Without a hit mesh and model there is nothing to paint on; the code below would
	// otherwise dereference null pointers.
	const RefMesh* refMesh = aDrawInfo->mHitInfo.mRefMeshHit;
	if (refMesh == nullptr || !refMesh->mesh || !aDrawInfo->mHitInfo.mHit) return false;

	const triangle_s tri = refMesh->mesh->getTriangle(aDrawInfo->mHitInfo.mPrimitiveIndex);
	const auto refModel = std::static_pointer_cast<RefModel>(aDrawInfo->mHitInfo.mHit);
	// presumably marks the mesh as carrying vertex colors (color_0) from now on — confirm
	refMesh->mesh->hasColors0(true);
	// offset of this mesh's vertices inside the BLAS vertex data
	const GeometryDataVulkan::primitveBufferOffset_s offset = mData->mGpuBlas.getOffset(refMesh->mesh.get());

	DrawBuffer db = {};
	db.modelMatrix = refModel->model_matrix;
	db.positionWS = aDrawInfo->mPositionWs;
	// geometric normal of the hit triangle, taken to world space with the inverse-transpose model matrix
	db.normal = glm::normalize(glm::mat3(glm::transpose(glm::inverse(refModel->model_matrix))) * tri.mGeoN);
	db.radius = aDrawInfo->mRadius;
	db.color = color;
	db.originWS = aDrawInfo->mHitInfo.mOriginPos;
	db.vertexOffset = offset.mVertexOffset;
	db.softBrush = aDrawInfo->mSoftBrush;
	db.drawAll = aDrawInfo->mDrawAll;
	db.drawRGB = aDrawInfo->mDrawRgb;
	db.drawALPHA = aDrawInfo->mDrawAlpha;
	db.leftMouse = left;
	db.setSH = mDrawSetSH;

	const CommandBuffer cb = mRoot.currentCmdBuffer();
	mData->mPipelineDrawOnMesh.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(db), &db);

	mData->mPipelineDrawOnMesh.CMD_BindDescriptorSets(&cb, { &mData->mDescriptorDrawOnMesh });
	mData->mPipelineDrawOnMesh.CMD_BindPipeline(&cb);

	mData->mPipelineDrawOnMesh.CMD_Dispatch(&cb, static_cast<uint32_t>(refMesh->mesh->getVertexCount()));
	return false;
}

/**
 * One-time GPU setup of this implementation.
 *
 * Creates mData and one VkFrameData per frame (mRoot.frameCount()), prepares the scene data
 * containers (textures, materials, lights, BLAS and one TLAS per frame), and hands them to the
 * light-trace optimizer (mLto.init). It then builds the shaders, descriptors and pipelines used here:
 *   - the rasterizer (no / front / back culling);
 *   - the spherical-harmonics visualizations (only with IALT_USE_SPHERICAL_HARMONICS);
 *   - the draw-on-mesh compute pass;
 *   - the derivative-visualization ray tracer;
 *   - the path tracer, with and without PATH_GUIDING.
 * Target-sized images are created at aRenderInfo.targetSize, and their layout transitions are
 * recorded into one single-time command buffer. Ends by calling computeHSHIntegrals() and
 * computeNormalizationConstants().
 *
 * @param aRenderInfo render info; only targetSize is read here
 */
void InteractiveAdjointLightTracing::prepare(tamashii::RenderInfo_s& aRenderInfo) {
#ifdef IALT_USE_SPHERICAL_HARMONICS
	// The SH order is read once from its config var. Each vertex stores RGB per SH coefficient,
	// i.e. 3 * (order + 1)^2 entries.
	LightTraceOptimizer::sphericalHarmonicOrder = LightTraceOptimizer::vars::shOrder.value();
	LightTraceOptimizer::entries_per_vertex = 3 * (LightTraceOptimizer::sphericalHarmonicOrder + 1) * (LightTraceOptimizer::sphericalHarmonicOrder + 1); // RGB for each SH coeff
	spdlog::info("SH-order is {}", LightTraceOptimizer::sphericalHarmonicOrder);
	// ToDo: where can we allow changes at runtime? Probably need to re-run ::prepare(...) and ::init(...)
#endif
	mData.emplace(mRoot.device);
	mFrameData.resize(mRoot.frameCount(), VkFrameData(mRoot.device));

	// Global raster state for every RPipeline finished below: no render pass object
	// (presumably dynamic rendering), with COLOR_FORMAT / DEPTH_FORMAT attachments.
	RPipeline::global_render_state.renderpass = nullptr;
	RPipeline::global_render_state.colorFormat = { COLOR_FORMAT };
	RPipeline::global_render_state.depthFormat = DEPTH_FORMAT;

	// scene data containers; capacities come from the MAX_* constants at the top of the file
	mData->mGpuTd.prepare(rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::ANY_HIT | rvk::Shader::Stage::COMPUTE | rvk::Shader::Stage::FRAGMENT, MAX_GLOBAL_IMAGE_SIZE);
	mData->mGpuMd.prepare(rvk::Buffer::Use::STORAGE, MAX_MATERIAL_SIZE);
	mData->mGpuLd.prepare(rvk::Buffer::Use::STORAGE, MAX_LIGHT_SIZE);
	constexpr uint32_t flags = rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::AS_INPUT;
	mData->mGpuBlas.prepare(MAX_INDEX_DATA_SIZE, MAX_VERTEX_DATA_SIZE, flags | rvk::Buffer::Use::INDEX, flags | rvk::Buffer::Use::VERTEX);
	for (auto& fd : mFrameData) fd.mGpuTlas.prepare(MAX_INSTANCE_SIZE, rvk::Buffer::Use::STORAGE, MAX_GEOMETRY_DATA_SIZE);

	// The optimizer receives raw pointers into mData and into frame 0's TLAS. It presumably
	// keeps them, so these objects must outlive mLto — confirm in LightTraceOptimizer.
	mLto.init(&mData->mGpuTd, &mData->mGpuMd, &mData->mGpuLd, &mData->mGpuBlas, &mFrameData.front().mGpuTlas);

	// host-coherent uniform buffer holding GlobalBufferR, mapped once here
	mData->mGlobalBuffer.create(rvk::Buffer::Use::UNIFORM, sizeof(GlobalBufferR), rvk::Buffer::Location::HOST_COHERENT);
	mData->mGlobalBuffer.mapBuffer();

	// Rasterizer descriptor set. Buffers other than the globals are not bound here.
	// NOTE(review): reserve(4) while 8 bindings are added; harmless if reserve() is only a
	// capacity hint — confirm against rvk::Descriptor.
	mData->mDescriptor.reserve(4);
	mData->mDescriptor.addUniformBuffer(RASTERIZER_DESC_GLOBAL_BUFFER_BINDING, rvk::Shader::Stage::VERTEX | rvk::Shader::Stage::GEOMETRY | rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::VERTEX | rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_TARGET_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_TARGET_RADIANCE_WEIGHTS_BUFFER_BINDING, rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_CHANNEL_WEIGHTS_BUFFER_BINDING, rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_VTX_COLOR_BUFFER_BINDING, rvk::Shader::Stage::VERTEX);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_AREA_BUFFER_BINDING, rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.addStorageBuffer(RASTERIZER_DESC_MATERIAL_BUFFER_BINDING, rvk::Shader::Stage::FRAGMENT);
	mData->mDescriptor.finish(false);

	// Specialization-constant data shared by most shaders:
	// [0] SH order, [1] entries per vertex, [2] unphysicalNicePreview as 0/1.
	// Only the rasterizer uses all 12 bytes; the other shaders consume the first 8 bytes.
	uint32_t constData[3] = { LightTraceOptimizer::sphericalHarmonicOrder, LightTraceOptimizer::entries_per_vertex, (uint32_t)(LightTraceOptimizer::vars::unphysicalNicePreview.asBool().value()) };

	// Rasterizer shader. The first addConstant argument appears to be the stage index in
	// addStage order (0 = vertex, 2 = fragment); the geometry stage gets no constants.
	mData->mShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::VERTEX, IALT_SHADER_DIR "rasterizer_vertex.glsl", LightTraceOptimizer::shaderDefines);
	mData->mShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::GEOMETRY, IALT_SHADER_DIR "rasterizer_geometry.glsl", LightTraceOptimizer::shaderDefines);
	mData->mShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::FRAGMENT, IALT_SHADER_DIR "rasterizer_fragment.glsl", LightTraceOptimizer::shaderDefines);
	mData->mShader.addConstant(0, 0, 4u, 0u);
	mData->mShader.addConstant(0, 1, 4u, 4u);
	mData->mShader.addConstant(0, 2, 4u, 8u);
	mData->mShader.setConstantData(0, constData, 12u);
	mData->mShader.addConstant(2, 0, 4u, 0u);
	mData->mShader.addConstant(2, 1, 4u, 4u);
	mData->mShader.addConstant(2, 2, 4u, 8u);
	mData->mShader.setConstantData(2, constData, 12u);
	mData->mShader.finish();

	// Rasterizer pipeline: 17 floats of push constants for vertex + fragment; set 0 = textures,
	// set 1 = rasterizer descriptor; vertex input is vertex_s.
	mData->mPipelineCullNone.setShader(&mData->mShader);
	mData->mPipelineCullNone.addPushConstant(rvk::Shader::Stage::VERTEX | rvk::Shader::Stage::FRAGMENT, 0, 17 * sizeof(float));
	mData->mPipelineCullNone.addDescriptorSet({ mData->mGpuTd.getDescriptor(), &mData->mDescriptor });

	mData->mPipelineCullNone.addBindingDescription(0, sizeof(vertex_s), VK_VERTEX_INPUT_RATE_VERTEX);
	mData->mPipelineCullNone.addAttributeDescription(0, 0, VK_FORMAT_R32G32B32_SFLOAT, offsetof(vertex_s, position));
	mData->mPipelineCullNone.addAttributeDescription(0, 1, VK_FORMAT_R32G32B32_SFLOAT, offsetof(vertex_s, normal));
	mData->mPipelineCullNone.addAttributeDescription(0, 2, VK_FORMAT_R32G32B32A32_SFLOAT, offsetof(vertex_s, tangent));
	mData->mPipelineCullNone.addAttributeDescription(0, 3, VK_FORMAT_R32G32_SFLOAT, offsetof(vertex_s, texture_coordinates_0));
	mData->mPipelineCullNone.addAttributeDescription(0, 4, VK_FORMAT_R32G32_SFLOAT, offsetof(vertex_s, texture_coordinates_1));
	mData->mPipelineCullNone.addAttributeDescription(0, 5, VK_FORMAT_R32G32B32A32_SFLOAT, offsetof(vertex_s, color_0));

	mData->mPipelineCullNone.finish();
	// The front/back-culling variants copy the no-cull pipeline and are finished again under a
	// temporarily changed global cull mode; push/pop restores the previous render state.
	rvk::RPipeline::pushRenderState();
	rvk::RPipeline::global_render_state.cullMode = VK_CULL_MODE_FRONT_BIT;
	mData->mPipelineCullFront = mData->mPipelineCullNone;
	mData->mPipelineCullFront.finish();
	rvk::RPipeline::global_render_state.cullMode = VK_CULL_MODE_BACK_BIT;
	mData->mPipelineCullBack = mData->mPipelineCullNone;
	mData->mPipelineCullBack.finish();
	rvk::RPipeline::popRenderState();

#ifdef IALT_USE_SPHERICAL_HARMONICS
	// SH visualization: point-list input expanded by a geometry shader.
	// Vertex constant = entries per vertex; fragment constants = SH order + entries per vertex.
	mData->mShaderSHVis.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::VERTEX, IALT_SHADER_DIR "sh_vis.vert", LightTraceOptimizer::shaderDefines);
	mData->mShaderSHVis.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::GEOMETRY, IALT_SHADER_DIR "sh_vis.geom", LightTraceOptimizer::shaderDefines);
	mData->mShaderSHVis.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::FRAGMENT, IALT_SHADER_DIR "sh_vis.frag", LightTraceOptimizer::shaderDefines);
	mData->mShaderSHVis.addConstant(0, 0, 4u, 0u);
	mData->mShaderSHVis.addConstant(2, 0, 4u, 0u);
	mData->mShaderSHVis.addConstant(2, 1, 4u, 4u);
	mData->mShaderSHVis.setConstantData(0, &LightTraceOptimizer::entries_per_vertex, 4u);
	mData->mShaderSHVis.setConstantData(2, &constData[0], 8u);
	mData->mShaderSHVis.finish();

	// push constants: a mat4 for the vertex stage, followed by one float for geometry + fragment
	mData->mPipelineSHVisNone.setShader(&mData->mShaderSHVis);
	mData->mPipelineSHVisNone.addPushConstant(rvk::Shader::Stage::VERTEX, 0, sizeof(glm::mat4));
	mData->mPipelineSHVisNone.addPushConstant(rvk::Shader::Stage::GEOMETRY | rvk::Shader::Stage::FRAGMENT, sizeof(glm::mat4), sizeof(float));

	mData->mPipelineSHVisNone.addBindingDescription(0, sizeof(vertex_s), VK_VERTEX_INPUT_RATE_VERTEX);
	mData->mPipelineSHVisNone.addAttributeDescription(0, 0, VK_FORMAT_R32G32B32_SFLOAT, offsetof(vertex_s, position));
	mData->mPipelineSHVisNone.addAttributeDescription(0, 1, VK_FORMAT_R32G32B32_SFLOAT, offsetof(vertex_s, normal));
	mData->mPipelineSHVisNone.addAttributeDescription(0, 2, VK_FORMAT_R32G32B32_SFLOAT, offsetof(vertex_s, tangent));
	mData->mPipelineSHVisNone.addDescriptorSet({ &mData->mDescriptor });
	rvk::RPipeline::pushRenderState();
	rvk::RPipeline::global_render_state.primitiveTopology = VK_PRIMITIVE_TOPOLOGY_POINT_LIST;
	rvk::RPipeline::global_render_state.polygonMode = VK_POLYGON_MODE_FILL;
	rvk::RPipeline::global_render_state.cullMode = VK_CULL_MODE_NONE;
	mData->mPipelineSHVisNone.finish();
	rvk::RPipeline::popRenderState();

	// single sh vis
	mData->mShaderSingleSHVis.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::VERTEX, IALT_SHADER_DIR "single_sh_vis.vert", LightTraceOptimizer::shaderDefines);
	mData->mShaderSingleSHVis.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::GEOMETRY, IALT_SHADER_DIR "single_sh_vis.geom", LightTraceOptimizer::shaderDefines);
	mData->mShaderSingleSHVis.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::FRAGMENT, IALT_SHADER_DIR "single_sh_vis.frag", LightTraceOptimizer::shaderDefines);
	mData->mShaderSingleSHVis.addConstant(2, 0, 4u, 0u);
	mData->mShaderSingleSHVis.addConstant(2, 1, 4u, 4u);
	mData->mShaderSingleSHVis.setConstantData(2, &constData[0], 8u);
	mData->mShaderSingleSHVis.finish();

	// Push-constant layout in bytes: vertex [0,16), fragment [16,44), geometry [40,44).
	// The fragment and geometry ranges overlap on purpose (shared field), with distinct stages.
	mData->mPipelineSingleSHVisNone.setShader(&mData->mShaderSingleSHVis);
	mData->mPipelineSingleSHVisNone.addPushConstant(rvk::Shader::Stage::VERTEX, 0u, 16u);
	mData->mPipelineSingleSHVisNone.addPushConstant(rvk::Shader::Stage::FRAGMENT, 16u, 16u + 8u + 4u);
	mData->mPipelineSingleSHVisNone.addPushConstant(rvk::Shader::Stage::GEOMETRY, 16u * 2u + 8u, 4u);

	mData->mPipelineSingleSHVisNone.addDescriptorSet({ &mData->mDescriptor });
	rvk::RPipeline::pushRenderState();
	rvk::RPipeline::global_render_state.primitiveTopology = VK_PRIMITIVE_TOPOLOGY_POINT_LIST;
	rvk::RPipeline::global_render_state.polygonMode = VK_POLYGON_MODE_FILL;
	rvk::RPipeline::global_render_state.cullMode = VK_CULL_MODE_NONE;
	mData->mPipelineSingleSHVisNone.finish();
	rvk::RPipeline::popRenderState();
#endif

	// Draw-on-mesh compute pass (used by drawOnMesh); its buffers are not bound here.
	// NOTE(review): reserve(2) while 3 bindings are added — see the reserve() note above.
	mData->mDescriptorDrawOnMesh.reserve(2);
	mData->mDescriptorDrawOnMesh.addStorageBuffer(DRAW_DESC_VERTEX_BUFFER_BINDING, rvk::Shader::Stage::COMPUTE);
	mData->mDescriptorDrawOnMesh.addStorageBuffer(DRAW_DESC_TARGET_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::COMPUTE);
	mData->mDescriptorDrawOnMesh.addStorageBuffer(DRAW_DESC_TARGET_RADIANCE_WEIGHTS_BUFFER_BINDING, rvk::Shader::Stage::COMPUTE);
	mData->mDescriptorDrawOnMesh.finish(false);

	mData->mShaderDrawOnMesh.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::COMPUTE, IALT_SHADER_DIR "draw_on_mesh.comp", LightTraceOptimizer::shaderDefines);
	mData->mShaderDrawOnMesh.addConstant(0, 0, 4u, 0u);
	mData->mShaderDrawOnMesh.addConstant(0, 1, 4u, 4u);
	mData->mShaderDrawOnMesh.setConstantData(0, &constData[0], 8u);
	mData->mShaderDrawOnMesh.finish();

	// the whole DrawBuffer is passed as a compute push constant
	mData->mPipelineDrawOnMesh.setShader(&mData->mShaderDrawOnMesh);
	mData->mPipelineDrawOnMesh.addDescriptorSet({ &mData->mDescriptorDrawOnMesh });
	mData->mPipelineDrawOnMesh.addPushConstant(rvk::Shader::Stage::COMPUTE, 0u, sizeof(DrawBuffer));
	mData->mPipelineDrawOnMesh.finish();

	// derivative visualization ray-tracing shader: raygen / closest hit / miss groups
	mData->mDerivVisShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "deriv_vis_rgen.glsl", LightTraceOptimizer::shaderDefines);		// idx 0
	mData->mDerivVisShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::CLOSEST_HIT, IALT_SHADER_DIR "forward_rchit.glsl", LightTraceOptimizer::shaderDefines);	// idx 1
	mData->mDerivVisShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::MISS, IALT_SHADER_DIR "forward_rmiss.glsl", LightTraceOptimizer::shaderDefines);			// idx 2
	mData->mDerivVisShader.addGeneralShaderGroup(0);
	mData->mDerivVisShader.addHitShaderGroup(1);
	mData->mDerivVisShader.addGeneralShaderGroup(2);
	mData->mDerivVisShader.addConstant(0, 0, 4u, 0u);
	mData->mDerivVisShader.addConstant(0, 1, 4u, 4u);
	mData->mDerivVisShader.setConstantData(0, &constData[0], 8u);
	mData->mDerivVisShader.finish();

	// All image layout transitions from here on are recorded into this command buffer,
	// which is ended by stc.end() after the path-tracing setup.
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	stc.begin();
	// data for individual frames
	for (uint32_t frameIndex = 0; frameIndex < mRoot.frameCount(); frameIndex++) {
		VkFrameData& frameData = mFrameData[frameIndex];
		frameData.mColor.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, COLOR_FORMAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::COLOR_ATTACHMENT | rvk::Image::Use::UPLOAD);
		frameData.mDepth.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, DEPTH_FORMAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::DEPTH_STENCIL_ATTACHMENT);
		frameData.mColor.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
		frameData.mDepth.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL);
	}

	// derivative visualization: accumulation + count images shared by all frames
	mData->mDerivVisImageAccumulate.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, VK_FORMAT_R32G32B32A32_SFLOAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
	mData->mDerivVisImageAccumulate.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
	mData->mDerivVisImageAccumulateCount.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, VK_FORMAT_R32_UINT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
	mData->mDerivVisImageAccumulateCount.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
	for (uint32_t idx = 0; idx < mRoot.frameCount(); idx++) {
		VkFrameData& frameData = mFrameData[idx];
		frameData.mDerivVisImage.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, VK_FORMAT_B8G8R8A8_UNORM, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
		frameData.mDerivVisImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);

		// Only the images and the UBO are bound here; the storage buffers are presumably bound
		// once scene data exists — confirm. NOTE(review): reserve(3) while 8 bindings are added.
		frameData.mDerivVisDescriptor.reserve(3);
		frameData.mDerivVisDescriptor.addStorageImage(DERIV_VIS_DESC_DERIV_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::COMPUTE);
		frameData.mDerivVisDescriptor.addStorageImage(DERIV_VIS_DESC_DERIV_ACC_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN);
		frameData.mDerivVisDescriptor.addStorageImage(DERIV_VIS_DESC_DERIV_ACC_COUNT_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN);
		frameData.mDerivVisDescriptor.addUniformBuffer(DERIV_VIS_DESC_UBO_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);
		frameData.mDerivVisDescriptor.addStorageBuffer(DERIV_VIS_DESC_TARGET_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);
		frameData.mDerivVisDescriptor.addStorageBuffer(DERIV_VIS_DESC_TARGET_RADIANCE_WEIGHTS_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);
		frameData.mDerivVisDescriptor.addStorageBuffer(DERIV_VIS_DESC_FD_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);
		frameData.mDerivVisDescriptor.addStorageBuffer(DERIV_VIS_DESC_FD2_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);

		frameData.mDerivVisDescriptor.setImage(DERIV_VIS_DESC_DERIV_IMAGE_BINDING, &frameData.mDerivVisImage);
		frameData.mDerivVisDescriptor.setImage(DERIV_VIS_DESC_DERIV_ACC_IMAGE_BINDING, &mData->mDerivVisImageAccumulate);
		frameData.mDerivVisDescriptor.setImage(DERIV_VIS_DESC_DERIV_ACC_COUNT_IMAGE_BINDING, &mData->mDerivVisImageAccumulateCount);
		frameData.mDerivVisDescriptor.setBuffer(DERIV_VIS_DESC_UBO_BUFFER_BINDING, &mData->mGlobalBuffer);
		frameData.mDerivVisDescriptor.finish(false);
	}

	// The pipeline layout is built from frame 0's descriptor; every frame's descriptor has the same layout.
	mData->mDerivVisPipeline.setShader(&mData->mDerivVisShader);
	mData->mDerivVisPipeline.addDescriptorSet({ mData->mGpuTd.getDescriptor(), mLto.getAdjointDescriptor(), &mFrameData[0].mDerivVisDescriptor });
	mData->mDerivVisPipeline.finish();

	// path tracing
	{
		// accumulation / count / cache images shared by all frames
		mData->mPTImageAccumulate.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, PT_ACCUMULATION_FORMAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
		mData->mPTImageAccumulate.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);

		mData->mPTImageAccumulateCount.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, PT_COUNT_FORMAT, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
		mData->mPTImageAccumulateCount.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);

		mData->mPTCacheImage.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, VK_FORMAT_B8G8R8A8_UNORM, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
		mData->mPTCacheImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);

		for (auto& frameData : mFrameData)
		{
			// per-frame output + debug images and a mapped host-coherent PTGlobalBuffer UBO
			frameData.mPTImage.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, VK_FORMAT_R8G8B8A8_UNORM, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
			frameData.mPTImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
			
			frameData.mPTDebugImage.createImage2D(aRenderInfo.targetSize.x, aRenderInfo.targetSize.y, VK_FORMAT_R8G8B8A8_UNORM, rvk::Image::Use::DOWNLOAD | rvk::Image::Use::UPLOAD | rvk::Image::Use::STORAGE);
			frameData.mPTDebugImage.CMD_TransitionImage(stc.buffer(), VK_IMAGE_LAYOUT_GENERAL);
			
			frameData.mPTGlobalUniformBuffer.create(rvk::Buffer::Use::UNIFORM, sizeof(PTGlobalBuffer), rvk::Buffer::Location::HOST_COHERENT);
			frameData.mPTGlobalUniformBuffer.mapBuffer();

			// 16 bindings are declared. Geometry/index/vertex buffers, the acceleration structure,
			// radiance and path-guiding buffers are not set here; presumably they are bound when a
			// scene is loaded — confirm. NOTE(review): reserve(11) is below the binding count.
			auto& desc = frameData.mPTGlobalDescriptor;
			desc.reserve(11);
			desc.addUniformBuffer(GLOBAL_DESC_UBO_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageBuffer(GLOBAL_DESC_GEOMETRY_DATA_BINDING, rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::ANY_HIT);
			desc.addStorageBuffer(GLOBAL_DESC_INDEX_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::ANY_HIT);
			desc.addStorageBuffer(GLOBAL_DESC_VERTEX_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::ANY_HIT);
			desc.addStorageBuffer(GLOBAL_DESC_MATERIAL_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::ANY_HIT);
			desc.addStorageBuffer(GLSL_GLOBAL_LIGHT_DATA_BINDING, rvk::Shader::Stage::RAYGEN | rvk::Shader::Stage::INTERSECTION);
			desc.addAccelerationStructureKHR(GLOBAL_DESC_AS_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageImage(GLSL_GLOBAL_RT_OUT_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageImage(GLSL_GLOBAL_RT_ACC_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageImage(GLSL_GLOBAL_RT_ACC_C_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageImage(GLSL_GLOBAL_DEBUG_IMAGE_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageBuffer(GLSL_GLOBAL_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageBuffer(PATH_GUIDING_INCIDENT_RADIANCE_BUFFER_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageBuffer(PATH_GUIDING_PHI_MAP_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageBuffer(PATH_GUIDING_LEGENDRE_MAP_BINDING, rvk::Shader::Stage::RAYGEN);
			desc.addStorageBuffer(PATH_GUIDING_NORMALIZATION_CONSTANTS_BINDING, rvk::Shader::Stage::RAYGEN);

			
			desc.setBuffer(GLOBAL_DESC_UBO_BINDING, &frameData.mPTGlobalUniformBuffer);
			desc.setBuffer(GLOBAL_DESC_MATERIAL_BUFFER_BINDING, mData->mGpuMd.getMaterialBuffer());
			desc.setBuffer(GLSL_GLOBAL_LIGHT_DATA_BINDING, mData->mGpuLd.getLightBuffer());
			desc.setImage(GLSL_GLOBAL_RT_OUT_IMAGE_BINDING, &frameData.mPTImage);
			desc.setImage(GLSL_GLOBAL_RT_ACC_IMAGE_BINDING, &mData->mPTImageAccumulate);
			desc.setImage(GLSL_GLOBAL_RT_ACC_C_IMAGE_BINDING, &mData->mPTImageAccumulateCount);
			desc.setImage(GLSL_GLOBAL_DEBUG_IMAGE_BINDING, &frameData.mPTDebugImage);
			desc.finish(false);
		}

		// shaders
		// three raygen variants of the same source, differing only in the sampling define
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "pt/path_tracer.rgen", { "GLSL", "SAMPLE_BRDF" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "pt/path_tracer.rgen", { "GLSL", "SAMPLE_LIGHTS" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "pt/path_tracer.rgen", { "GLSL", "SAMPLE_LIGHTS_MIS" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::MISS, IALT_SHADER_DIR "pt/path_tracer.rmiss", { "GLSL" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::CLOSEST_HIT, IALT_SHADER_DIR "pt/path_tracer.rchit", { "GLSL" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::ANY_HIT, IALT_SHADER_DIR "pt/path_tracer.rahit", { "GLSL" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::INTERSECTION, IALT_SHADER_DIR "pt/sphere.rint", { "GLSL" });
		// shadow
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::MISS, IALT_SHADER_DIR "pt/shadow_ray.rmiss", { "GLSL" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::CLOSEST_HIT, IALT_SHADER_DIR "pt/shadow_ray.rchit", { "GLSL" });
		mData->mPTShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::ANY_HIT, IALT_SHADER_DIR "pt/shadow_ray.rahit", { "GLSL" });
		mData->mPTShader.addGeneralShaderGroup(IALT_SHADER_DIR "pt/path_tracer.rmiss"); // idx 0
		mData->mPTShader.addGeneralShaderGroup(IALT_SHADER_DIR "pt/shadow_ray.rmiss"); // idx 1
		mData->mPTShader.addGeneralShaderGroup(0); // rgen idx 0
		mData->mPTShader.addGeneralShaderGroup(1); // rgen idx 1
		mData->mPTShader.addGeneralShaderGroup(2); // rgen idx 2

		// triangle + procedural (sphere) hit groups for primary rays, then the same for shadow rays
		mData->mPTShader.addHitShaderGroup(IALT_SHADER_DIR "pt/path_tracer.rchit", IALT_SHADER_DIR "pt/path_tracer.rahit"); // idx 0
		mData->mPTShader.addProceduralShaderGroup(IALT_SHADER_DIR "pt/path_tracer.rchit", "", IALT_SHADER_DIR "pt/sphere.rint"); // idx 0

		mData->mPTShader.addHitShaderGroup(IALT_SHADER_DIR "pt/shadow_ray.rchit", IALT_SHADER_DIR "pt/shadow_ray.rahit"); // idx 1
		mData->mPTShader.addProceduralShaderGroup(IALT_SHADER_DIR "pt/shadow_ray.rchit", "", IALT_SHADER_DIR "pt/sphere.rint"); // idx 1
		// SH order + entries per vertex for each of the three raygen stages (stage indices 0-2)
		mData->mPTShader.addConstant(0, 0, 4u, 0u);
		mData->mPTShader.addConstant(0, 1, 4u, 4u);
		mData->mPTShader.setConstantData(0, constData, 8);
		mData->mPTShader.addConstant(1, 0, 4u, 0u);
		mData->mPTShader.addConstant(1, 1, 4u, 4u);
		mData->mPTShader.setConstantData(1, constData, 8);
		mData->mPTShader.addConstant(2, 0, 4u, 0u);
		mData->mPTShader.addConstant(2, 1, 4u, 4u);
		mData->mPTShader.setConstantData(2, constData, 8);
		mData->mPTShader.finish();

		mData->mPTPipeline.setShader(&mData->mPTShader);
		mData->mPTPipeline.addDescriptorSet({ mData->mGpuTd.getDescriptor(), &mFrameData[0].mPTGlobalDescriptor});
		mData->mPTPipeline.finish();

		// path guiding shaders
		// Same layout as mPTShader; only the raygen stages add the PATH_GUIDING define.
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "pt/path_tracer.rgen", { "GLSL", "SAMPLE_BRDF", "PATH_GUIDING" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "pt/path_tracer.rgen", { "GLSL", "SAMPLE_LIGHTS", "PATH_GUIDING" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::RAYGEN, IALT_SHADER_DIR "pt/path_tracer.rgen", { "GLSL", "SAMPLE_LIGHTS_MIS", "PATH_GUIDING" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::MISS, IALT_SHADER_DIR "pt/path_tracer.rmiss", { "GLSL" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::CLOSEST_HIT, IALT_SHADER_DIR "pt/path_tracer.rchit", { "GLSL" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::ANY_HIT, IALT_SHADER_DIR "pt/path_tracer.rahit", { "GLSL" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::INTERSECTION, IALT_SHADER_DIR "pt/sphere.rint", { "GLSL" });
		// shadow
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::MISS, IALT_SHADER_DIR "pt/shadow_ray.rmiss", { "GLSL" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::CLOSEST_HIT, IALT_SHADER_DIR "pt/shadow_ray.rchit", { "GLSL" });
		mData->mPathGuidingShader.addStage(rvk::Shader::Source::GLSL, rvk::Shader::Stage::ANY_HIT, IALT_SHADER_DIR "pt/shadow_ray.rahit", { "GLSL" });
		mData->mPathGuidingShader.addGeneralShaderGroup(IALT_SHADER_DIR "pt/path_tracer.rmiss"); // idx 0
		mData->mPathGuidingShader.addGeneralShaderGroup(IALT_SHADER_DIR "pt/shadow_ray.rmiss"); // idx 1
		mData->mPathGuidingShader.addGeneralShaderGroup(0); // rgen idx 0
		mData->mPathGuidingShader.addGeneralShaderGroup(1); // rgen idx 1
		mData->mPathGuidingShader.addGeneralShaderGroup(2); // rgen idx 2

		mData->mPathGuidingShader.addHitShaderGroup(IALT_SHADER_DIR "pt/path_tracer.rchit", IALT_SHADER_DIR "pt/path_tracer.rahit"); // idx 0
		mData->mPathGuidingShader.addProceduralShaderGroup(IALT_SHADER_DIR "pt/path_tracer.rchit", "", IALT_SHADER_DIR "pt/sphere.rint"); // idx 0

		mData->mPathGuidingShader.addHitShaderGroup(IALT_SHADER_DIR "pt/shadow_ray.rchit", IALT_SHADER_DIR "pt/shadow_ray.rahit"); // idx 1
		mData->mPathGuidingShader.addProceduralShaderGroup(IALT_SHADER_DIR "pt/shadow_ray.rchit", "", IALT_SHADER_DIR "pt/sphere.rint"); // idx 1
		mData->mPathGuidingShader.addConstant(0, 0, 4u, 0u);
		mData->mPathGuidingShader.addConstant(0, 1, 4u, 4u);
		mData->mPathGuidingShader.setConstantData(0, constData, 8);
		mData->mPathGuidingShader.addConstant(1, 0, 4u, 0u);
		mData->mPathGuidingShader.addConstant(1, 1, 4u, 4u);
		mData->mPathGuidingShader.setConstantData(1, constData, 8);
		mData->mPathGuidingShader.addConstant(2, 0, 4u, 0u);
		mData->mPathGuidingShader.addConstant(2, 1, 4u, 4u);
		mData->mPathGuidingShader.setConstantData(2, constData, 8);
		mData->mPathGuidingShader.finish();

		mData->mPathGuidingPipeline.setShader(&mData->mPathGuidingShader);
		mData->mPathGuidingPipeline.addDescriptorSet({ mData->mGpuTd.getDescriptor(), &mFrameData[0].mPTGlobalDescriptor });
		mData->mPathGuidingPipeline.finish();
	}
	stc.end();

	computeHSHIntegrals();
	computeNormalizationConstants();
}
// Tear down the backend's resources.
// Releases the shared backend data (mData: pipelines, shaders, descriptors, buffers),
// then the per-frame data (mFrameData), and finally the optimizer's own resources.
// NOTE(review): the release order is kept exactly as written; mLto is told to
// destroy itself last even though it may have been handed pointers into mFrameData
// (see mLto.updateTlas in drawView) — confirm mLto.destroy() does not dereference them.
void InteractiveAdjointLightTracing::destroy() {
	mData.reset();
	mFrameData.clear();

	mLto.destroy();
}

// scene load/unload
// Upload a newly loaded scene to the GPU and (re)bind every descriptor that references
// per-scene buffers.
//
// Steps:
//  - counts all vertices of all models; if there are none, returns early without
//    marking the scene as loaded,
//  - loads textures/materials/BLAS/lights and one TLAS per frame in flight,
//  - allocates and zero-fills the radiance buffers (`entries_per_vertex` floats per vertex),
//  - lets the optimizer build its scene data and runs an initial forward pass if lights exist,
//  - updates the rasterizer, draw-on-mesh, derivative-visualization and path-tracing descriptors,
//  - optionally starts the predefined test case named by `runPredefinedTest`,
//  - enables path-tracer shading when the scene has lights or light-emitting materials.
//
// @param aScene  backend view of the scene being loaded.
void InteractiveAdjointLightTracing::sceneLoad(tamashii::SceneBackendData aScene) {
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	uint64_t vertexCount = 0;

	// FIXME: Always clear the map of light references, even if the scene is empty to prevent
	// dangling pointers. There should be a better way than reaching into the optimizer object
	// like this.
	mLights = nullptr;
	mLto.getLightOptParams().clear();

	// count all vertices, if zero return else create radiance buffers with vertex_count * stride size
	// (bind models by reference so no handle is copied per iteration)
	for (const auto& model : aScene.models) for (const auto& mesh : *model) vertexCount += mesh->getVertexCount();
	if (!vertexCount) return;

	// NOTE(review): aScene is taken by value; this pointer only stays valid after the call
	// if SceneBackendData::refLights refers to storage owned elsewhere — confirm.
	mLights = &aScene.refLights;

	mData->mGpuTd.loadScene(&stc, aScene);
	mData->mGpuMd.loadScene(&stc, aScene, &mData->mGpuTd);
	mData->mGpuBlas.loadScene(&stc, aScene);
	mData->mGpuLd.loadScene(&stc, aScene, &mData->mGpuTd, &mData->mGpuBlas);
	for (auto& fd : mFrameData) fd.mGpuTlas.loadScene(&stc, aScene, &mData->mGpuBlas, &mData->mGpuMd);

	// every radiance buffer stores `entries_per_vertex` floats for each vertex
	const auto radianceBufferSize = vertexCount * LightTraceOptimizer::entries_per_vertex * sizeof(float);
	mData->mRadianceBufferCopy.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, radianceBufferSize, rvk::Buffer::Location::DEVICE);
	mData->mFdRadianceBufferCopy.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, radianceBufferSize, rvk::Buffer::Location::DEVICE);
	mData->mFd2RadianceBufferCopy.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, radianceBufferSize, rvk::Buffer::Location::DEVICE);
	mData->mIncidentRadianceBufferCopy.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, radianceBufferSize, rvk::Buffer::Location::DEVICE);

	stc.begin();
	mData->mRadianceBufferCopy.CMD_FillBuffer(stc.buffer(), 0);
	mData->mIncidentRadianceBufferCopy.CMD_FillBuffer(stc.buffer(), 0);
	mData->mFdRadianceBufferCopy.CMD_FillBuffer(stc.buffer(), 0);
	mData->mFd2RadianceBufferCopy.CMD_FillBuffer(stc.buffer(), 0);
	stc.end();
	mLto.sceneLoad(aScene, vertexCount);
	if (mData->mGpuLd.getLightCount()) {
		mLto.forward(mLto.getCurrentParams(), &mData->mRadianceBufferCopy, &mData->mIncidentRadianceBufferCopy);
	}

	// rasterizer descriptor
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_GLOBAL_BUFFER_BINDING, &mData->mGlobalBuffer);
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_RADIANCE_BUFFER_BINDING, &mData->mRadianceBufferCopy);
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_VTX_COLOR_BUFFER_BINDING, mLto.getVtxTextureColorBuffer());
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_AREA_BUFFER_BINDING, mLto.getVtxAreaBuffer());
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_MATERIAL_BUFFER_BINDING, mData->mGpuMd.getMaterialBuffer());
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_TARGET_RADIANCE_BUFFER_BINDING, mLto.getTargetRadianceBuffer());
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_TARGET_RADIANCE_WEIGHTS_BUFFER_BINDING, mLto.getTargetRadianceWeightsBuffer());
	mData->mDescriptor.setBuffer(RASTERIZER_DESC_CHANNEL_WEIGHTS_BUFFER_BINDING, mLto.getChannelWeightsBuffer());
	mData->mDescriptor.update();

	// draw-on-mesh descriptor
	mData->mDescriptorDrawOnMesh.setBuffer(DRAW_DESC_VERTEX_BUFFER_BINDING, mData->mGpuBlas.getVertexBuffer());
	mData->mDescriptorDrawOnMesh.setBuffer(DRAW_DESC_TARGET_RADIANCE_BUFFER_BINDING, mLto.getTargetRadianceBuffer());
	mData->mDescriptorDrawOnMesh.setBuffer(DRAW_DESC_TARGET_RADIANCE_WEIGHTS_BUFFER_BINDING, mLto.getTargetRadianceWeightsBuffer());
	mData->mDescriptorDrawOnMesh.update();

	for (uint32_t idx = 0; idx < mRoot.frameCount(); idx++) {
		VkFrameData& frameData = mFrameData[idx];
		frameData.mDerivVisDescriptor.setBuffer(DERIV_VIS_DESC_TARGET_BUFFER_BINDING, mLto.getTargetRadianceBuffer());
		frameData.mDerivVisDescriptor.setBuffer(DERIV_VIS_DESC_TARGET_RADIANCE_WEIGHTS_BUFFER_BINDING, mLto.getTargetRadianceWeightsBuffer());
		frameData.mDerivVisDescriptor.setBuffer(DERIV_VIS_DESC_FD_RADIANCE_BUFFER_BINDING, &mData->mFdRadianceBufferCopy);
		frameData.mDerivVisDescriptor.setBuffer(DERIV_VIS_DESC_FD2_RADIANCE_BUFFER_BINDING, &mData->mFd2RadianceBufferCopy);
		frameData.mDerivVisDescriptor.update();

		// update path-tracing data
		frameData.mPTGlobalDescriptor.setBuffer(GLOBAL_DESC_GEOMETRY_DATA_BINDING, frameData.mGpuTlas.getGeometryDataBuffer());
		frameData.mPTGlobalDescriptor.setBuffer(GLOBAL_DESC_INDEX_BUFFER_BINDING, mData->mGpuBlas.getIndexBuffer());
		frameData.mPTGlobalDescriptor.setBuffer(GLOBAL_DESC_VERTEX_BUFFER_BINDING, mData->mGpuBlas.getVertexBuffer());
		frameData.mPTGlobalDescriptor.setBuffer(GLSL_GLOBAL_RADIANCE_BUFFER_BINDING, &mData->mRadianceBufferCopy);
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_INCIDENT_RADIANCE_BUFFER_BINDING, &mData->mIncidentRadianceBufferCopy);
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_PHI_MAP_BINDING, &mData->mPGPhiMapBuffer);
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_LEGENDRE_MAP_BINDING, &mData->mPGLegendreMapBuffer);
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_NORMALIZATION_CONSTANTS_BINDING, &mData->mPGNormalizationConstantsBuffer);
		if (mData->mGpuMd.bufferChanged(false)) frameData.mPTGlobalDescriptor.setBuffer(GLOBAL_DESC_MATERIAL_BUFFER_BINDING, mData->mGpuMd.getMaterialBuffer());

		frameData.mPTGlobalDescriptor.setAccelerationStructureKHR(GLOBAL_DESC_AS_BINDING, frameData.mGpuTlas.getTlas());
		frameData.mPTGlobalDescriptor.update();
	}

	if (mFDGradImageSelection.has_value()) mLto.fillFiniteDiffRadianceBuffer(mFDGradImageSelection.value().first, mFDGradImageSelection.value().second, mFdGradH, &mData->mFdRadianceBufferCopy, &mData->mFd2RadianceBufferCopy);
	mSceneLoaded = true;

	// optionally run a predefined test case on the optimizer thread
	const auto testname = LightTraceOptimizer::vars::runPredefinedTest.value();
	if (!testname.empty()) {
		spdlog::info("running test case {}", testname);

		startOptimizerThread([aScene, testname](InteractiveAdjointLightTracing* ialt, LightTraceOptimizer* aLto, rvk::Buffer* aRadianceBufferOut, unsigned int, float, int) {
			aLto->runPredefinedTestCase(ialt, aScene, testname, aRadianceBufferOut);
		});
	}

	// shade if there are explicit scene lights or any mesh whose material reports isLight()
	const auto modelHasLight = [](const auto& model) {
		return std::any_of(model->begin(), model->end(), [](const auto& mesh) {
			return mesh->getMaterial()->isLight();
		});
	};
	mPTGlobalUBO.shade = !aScene.refLights.empty() || std::any_of(aScene.models.begin(), aScene.models.end(), modelHasLight);
}

// Release everything that sceneLoad created for the current scene.
// Marks the scene as unloaded, unloads the GPU texture/material/light data, the BLAS
// and every frame's TLAS, lets the optimizer drop its scene state, and finally
// destroys the four radiance buffers.
//
// @param aScene  backend view of the scene being unloaded (forwarded to the optimizer).
void InteractiveAdjointLightTracing::sceneUnload(tamashii::SceneBackendData aScene) {
	mSceneLoaded = false;

	// GPU-side scene representations; unload order is intentional and preserved
	mData->mGpuTd.unloadScene();
	mData->mGpuMd.unloadScene();
	mData->mGpuLd.unloadScene();
	mData->mGpuBlas.unloadScene();
	for (VkFrameData& frame : mFrameData) {
		frame.mGpuTlas.unloadScene();
	}

	mLto.sceneUnload(aScene);

	// radiance buffers allocated in sceneLoad, destroyed in the original order
	const std::array<rvk::Buffer*, 4> radianceBuffers = {
		&mData->mRadianceBufferCopy,
		&mData->mIncidentRadianceBufferCopy,
		&mData->mFdRadianceBufferCopy,
		&mData->mFd2RadianceBufferCopy
	};
	for (rvk::Buffer* radianceBuffer : radianceBuffers) {
		radianceBuffer->destroy();
	}
}

// draw frame
// Record all GPU work for one frame into the current command buffer.
//
// Steps, in order:
//  1. Upload the rasterizer's global uniform data (camera matrices, visualization toggles).
//  2. Apply pending scene updates: instances/materials/lights -> material data, TLAS and
//     path-tracing descriptor; images/textures -> texture data; lights -> optimizer
//     parameters and a new forward (or finite-difference) radiance pass; geometry -> BLAS.
//     Any update (or mRecalculate) resets path-tracing accumulation.
//  3. Either path-trace (mPTRenderer) and blit the result into mColor, or rasterize all
//     surfaces with the pipeline matching the active cull mode.
//  4. Optionally draw spherical-harmonics overlays (IALT_USE_SPHERICAL_HARMONICS).
//  5. Blit mColor to the swapchain image and, if a gradient image is selected, trace the
//     derivative visualization and blit it on top.
//
// @param aViewDef  camera, target size, pending scene updates, and the surfaces/lights to draw.
void InteractiveAdjointLightTracing::drawView(tamashii::ViewDef_s* aViewDef) {
	// wake up whoever waits on mNextFrameCV for a new frame
	mNextFrameCV.notify_all();

	//if (var::headless.getBool()) return;
	CommandBuffer& cb = mRoot.currentCmdBuffer();
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	const uint32_t fi = mRoot.currentIndex();
	VkFrameData& frameData = mFrameData[fi];
	// background color: the `bg` variable divided by 255
	const auto cc = glm::vec3{ var::varToVec(var::bg) } / 255.0f;

	GlobalBufferR buffer = {};
	buffer.viewMat = aViewDef->view_matrix;
	buffer.projMat = aViewDef->projection_matrix;
	buffer.inverseViewMat = aViewDef->inv_view_matrix;
	buffer.inverseProjMat = aViewDef->inv_projection_matrix;
	buffer.viewPos = glm::vec4(aViewDef->view_pos, 1);
	buffer.viewDir = glm::vec4(aViewDef->view_dir, 0);
	buffer.wireframeColor[0] = mWireframeColor.x; buffer.wireframeColor[1] = mWireframeColor.y; buffer.wireframeColor[2] = mWireframeColor.z; buffer.wireframeColor[3] = 1.0f;
	buffer.size[0] = static_cast<float>(aViewDef->target_size.x); buffer.size[1] = static_cast<float>(aViewDef->target_size.y);
	buffer.shade = 0;
	// with a constant random seed, frameCount stays at its zero-initialized value
	if (!LightTraceOptimizer::vars::constRandSeed) buffer.frameCount = aViewDef->frame_index;
	buffer.grad_vis_accumulate = mGradVisAcc;
	buffer.show_target = mShowTarget;
	buffer.show_alpha = mShowAlpha;
	buffer.show_adjoint_deriv = mShowAdjointDeriv;
	buffer.show_wireframe_overlay = mShowWireframeOverlay;
	buffer.adjoint_range = mAdjointVisRange;
	buffer.log_adjoint_vis = mAdjointVisLog;
	buffer.dither_strength = 0;
	buffer.grad_range = mGradRange;
	buffer.bg_color = glm::vec4(cc, 1.0f);
	buffer.cull_mode = mActiveCullMode;
	buffer.log_grad_vis = mGradVisLog;
	buffer.grad_vis_weights = mGradVisWeights;
	buffer.fd_grad_vis = mShowFDGrad;
	buffer.fd_grad_h = mFdGradH;
	// the analytic gradient selection takes precedence over the finite-difference one
	if (mGradImageSelection.has_value())
	{
		buffer.light_deriv_idx = mGradImageSelection.value().first->ref_light_index;
		buffer.param_deriv_idx = mGradImageSelection.value().second;
	}
	else if (mFDGradImageSelection.has_value())
	{
		buffer.light_deriv_idx = mFDGradImageSelection.value().first->ref_light_index;
		buffer.param_deriv_idx = mFDGradImageSelection.value().second;
	}
	mData->mGlobalBuffer.STC_UploadData(&stc, &buffer, sizeof(GlobalBufferR));

	// process updates: record them for every frame in flight, then handle this frame's share
	if (aViewDef->updates.any()) for (uint32_t i = 0; i < mRoot.frameCount(); i++) mUpdates[i] = mUpdates[i] | aViewDef->updates;
	if (mUpdates[fi].mModelInstances || mUpdates[fi].mMaterials || mUpdates[fi].mLights) {
		//spdlog::debug("update {}", vd->frame_index);
		mUpdates[fi].mModelInstances = mUpdates[fi].mMaterials = false;
		SingleTimeCommand updateStc = mRoot.singleTimeCommand();
		mData->mGpuMd.update(&updateStc, aViewDef->scene, &mData->mGpuTd);
		frameData.mGpuTlas.update(&cb, &updateStc, aViewDef->scene, &mData->mGpuBlas, &mData->mGpuMd, mPTGlobalUBO.light_geometry ? &mData->mGpuLd : nullptr);

		// make the rebuilt TLAS visible to the ray-tracing shaders
		cb.cmdMemoryBarrier(VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR);
		frameData.mPTGlobalDescriptor.setAccelerationStructureKHR(GLOBAL_DESC_AS_BINDING, frameData.mGpuTlas.getTlas());
		frameData.mPTGlobalDescriptor.update();
		// mLto uses the tlas from `mFrameData[0]`
		if (fi == 0) mLto.updateTlas(&frameData.mGpuTlas);
	}
	mUpdates[fi].reset();
	if (aViewDef->updates.mImages || aViewDef->updates.mTextures) {
		mRoot.device.waitIdle();
		SingleTimeCommand textureStc = mRoot.singleTimeCommand();
		mData->mGpuTd.update(&textureStc, aViewDef->scene);
	}
	if (aViewDef->updates.mLights) {
		mData->mGpuLd.update(&stc, aViewDef->scene, &mData->mGpuTd, &mData->mGpuBlas);
		mLto.updateParamsFromScene();
		if (mFDGradImageSelection.has_value()) mLto.fillFiniteDiffRadianceBuffer(mFDGradImageSelection.value().first, mFDGradImageSelection.value().second, mFdGradH, &mData->mFdRadianceBufferCopy, &mData->mFd2RadianceBufferCopy);
		else mLto.forward(mLto.getCurrentParams(), &mData->mRadianceBufferCopy, &mData->mIncidentRadianceBufferCopy);
	}
	if (aViewDef->updates.any() || mRecalculate) {
		// restart path-tracing accumulation
		mRecalculate = false;
		mPTGlobalUBO.accumulatedFrames = 0;
		mData->mPTImageAccumulate.CMD_ClearColor(&cb, 0, 0, 0, 0);
		mData->mPTImageAccumulateCount.CMD_ClearColor(&cb, 0, 0, 0, 0);
	}
	if (aViewDef->updates.mModelGeometries) {
		mData->mGpuBlas.update(&stc, aViewDef->scene);

		// allow painting of targets (--> new objective), but don't update anything else ...
		//mLto.buildObjectiveFunction(aViewDef->scene);
	}
	if (mUpdatePathGuidingBuffers) {
		recomputePathGuidingBuffers();
		mUpdatePathGuidingBuffers = false;
	}

	if (mPTRenderer)
	{
		// Without surfaces, only trace if a point or IES light is present
		// (presumably these are intersected as procedural spheres).
		bool shouldRender = true;
		if (aViewDef->surfaces.empty()) {
			shouldRender = std::any_of(aViewDef->lights.begin(), aViewDef->lights.end(), [](const auto& a) {
				return a->light->getType() == Light::Type::POINT || a->light->getType() == Light::Type::IES;
			});
		}

		if (shouldRender)
		{
			if (mTakeScreenshot)
			{
				// mScreenshotStartTime is stamped with steady_clock (see the UI "Start" button);
				// measure with the same clock — mixing clocks does not compile on toolchains
				// where high_resolution_clock is not steady_clock.
				const std::chrono::milliseconds elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - mScreenshotStartTime);
				if (mPTGlobalUBO.accumulatedFrames >= mScreenshotSpp || elapsed.count() > mScreenshotTimeLimit)
				{
					const time_t t = std::time(nullptr);
					std::string filename = std::to_string(t) +
						std::string("_spp-") + std::to_string(mPTGlobalUBO.accumulatedFrames) +
						"_time-" + std::to_string(std::min(mScreenshotTimeLimit, uint32_t(elapsed.count()))) +
						"_max_bounces-" + std::to_string(mPTGlobalUBO.max_depth) +
						"_cr-" + (mPTGlobalUBO.use_cached_radiance ? std::string("y") : std::string("n")) +
						"_pg-" + (mUsePathGuiding ? std::string("y") : std::string("n")) +
						"_pgd-" + std::to_string(mPTGlobalUBO.pg_hsh_subdivision_depth) +
						"_shord-" + std::to_string(LTOVars::shOrder) +
						".png";
					std::filesystem::path p(filename);
					// read back the previous frame's path-traced image as 4 x 8-bit channels
					const VkExtent3D extent = mFrameData[mRoot.previousIndex()].mPTImage.getExtent();
					std::vector<glm::u8vec4> data_uint8(static_cast<size_t>(extent.width) * extent.height);
					mFrameData[mRoot.previousIndex()].mPTImage.STC_DownloadData2D(&stc, extent.width, extent.height, 4, data_uint8.data());
					io::Export::save_image_png_8_bit(p.string(), extent.width, extent.height, 4, reinterpret_cast<uint8_t*>(data_uint8.data()));
					spdlog::info("Impl-Screenshot saved: {}", p.string());
					mTakeScreenshot = false;
				}
			}

			frameData.mPTImage.CMD_ClearColor(&cb, cc.x, cc.y, cc.z, 1.0f);
			frameData.mPTDebugImage.CMD_ClearColor(&cb, 0, 0, 0, 1.0f);

			// last left-click position, kept across frames for the debug pixel
			static glm::vec2 debugClickPos = glm::vec2{ 0, 0 };
			if (tamashii::InputSystem::getInstance().wasPressed(tamashii::Input::MOUSE_LEFT)) {
				debugClickPos = tamashii::InputSystem::getInstance().getMousePosAbsolute();
			}

			SingleTimeCommand stl = mRoot.singleTimeCommand();
			stl.begin();

			// ubo
			mPTGlobalUBO.viewMat = aViewDef->view_matrix;
			mPTGlobalUBO.projMat = aViewDef->projection_matrix;
			mPTGlobalUBO.inverseViewMat = aViewDef->inv_view_matrix;
			mPTGlobalUBO.inverseProjMat = aViewDef->inv_projection_matrix;
			mPTGlobalUBO.viewPos = glm::vec4{ aViewDef->view_pos, 1 };
			mPTGlobalUBO.viewDir = glm::vec4{ aViewDef->view_dir, 0 };
			mPTGlobalUBO.debugPixelPosition = debugClickPos;
			mPTGlobalUBO.cull_mode = mActiveCullMode;
			Common::getInstance().intersectionSettings().mCullMode = static_cast<CullMode>(mActiveCullMode);
			mPTGlobalUBO.bg[0] = mEnvLight[0] * mEnvLightIntensity; mPTGlobalUBO.bg[1] = mEnvLight[1] * mEnvLightIntensity; mPTGlobalUBO.bg[2] = mEnvLight[2] * mEnvLightIntensity; mPTGlobalUBO.bg[3] = 1;
			mPTGlobalUBO.size[0] = static_cast<float>(aViewDef->target_size.x); mPTGlobalUBO.size[1] = static_cast<float>(aViewDef->target_size.y);
			mPTGlobalUBO.frameIndex = static_cast<float>(aViewDef->frame_index);
			mPTGlobalUBO.light_count = mData->mGpuLd.getLightCount();
			// accumulation calc
			if (mPTGlobalUBO.accumulate) mPTGlobalUBO.accumulatedFrames += mPTGlobalUBO.pixelSamplesPerFrame;
			else mPTGlobalUBO.accumulatedFrames = mPTGlobalUBO.pixelSamplesPerFrame;
			frameData.mPTGlobalUniformBuffer.STC_UploadData(&stl, &mPTGlobalUBO, sizeof(PTGlobalBuffer));
			stl.end();

			if (!aViewDef->scene.refModels.empty()) {
				if (mUsePathGuiding) {
					mData->mPathGuidingPipeline.CMD_BindDescriptorSets(&cb, { mData->mGpuTd.getDescriptor(), &frameData.mPTGlobalDescriptor });
					mData->mPathGuidingPipeline.CMD_BindPipeline(&cb);
					mData->mPathGuidingPipeline.CMD_TraceRays(&cb, aViewDef->target_size.x, aViewDef->target_size.y, 1, mPTGlobalUBO.sampling_strategy);
				}
				else {
					mData->mPTPipeline.CMD_BindDescriptorSets(&cb, { mData->mGpuTd.getDescriptor(), &frameData.mPTGlobalDescriptor });
					mData->mPTPipeline.CMD_BindPipeline(&cb);
					mData->mPTPipeline.CMD_TraceRays(&cb, aViewDef->target_size.x, aViewDef->target_size.y, 1, mPTGlobalUBO.sampling_strategy);
				}
			}

			//if (mShowCache) rvk::swapchain::CMD_BlitImageToImage(&cb, &mData->cacheImage, &mRoot.currentImage(), VK_FILTER_LINEAR);
			//else if (mDebugImage) rvk::swapchain::CMD_BlitImageToImage(&cb, &frameData.debugImage, &mRoot.currentImage(), VK_FILTER_LINEAR);
			///*else */ rvk::swapchain::CMD_BlitImageToImage(&cb, &frameData.mPTImage, &mRoot.currentImage(), VK_FILTER_LINEAR);
		}
		rvk::swapchain::CMD_BlitImageToImage(&cb, &frameData.mPTImage, &frameData.mColor, VK_FILTER_LINEAR);
		cb.cmdBeginRendering(
			{ { &frameData.mColor,
			VkClearColorValue{} , RVK_L2, RVK_S2} },
			{ &frameData.mDepth,
			{0.0f, 0}, RVK_LC, RVK_S2 }
		);
		// TODO: figure out why this is needed, and why this works
		mData->mPipelineCullNone.CMD_SetViewport(&cb);
		mData->mPipelineCullNone.CMD_SetScissor(&cb);

		// bind vertex and index buffer for use by sh visualization
		mData->mGpuBlas.getVertexBuffer()->CMD_BindVertexBuffer(&cb, 0, 0);
		mData->mGpuBlas.getIndexBuffer()->CMD_BindIndexBuffer(&cb, VK_INDEX_TYPE_UINT32, 0);
	} else {
		cb.cmdBeginRendering(
			{ { &frameData.mColor,
			{{cc.x, cc.y, cc.z, 1.0f}} , RVK_LC, RVK_S2 } },
			{ &frameData.mDepth,
			{0.0f, 0}, RVK_LC, RVK_S2 }
		);
		// pick the rasterization pipeline for the active cull mode (0 none, 1 front, 2 back)
		Common::getInstance().intersectionSettings().mCullMode = static_cast<tamashii::CullMode>(mActiveCullMode);
		if (mActiveCullMode == 0) mData->mCurrentPipeline = &mData->mPipelineCullNone;
		else if (mActiveCullMode == 1) mData->mCurrentPipeline = &mData->mPipelineCullFront;
		else if (mActiveCullMode == 2) mData->mCurrentPipeline = &mData->mPipelineCullBack;

		if (!aViewDef->surfaces.empty()) {
			mData->mCurrentPipeline->CMD_BindDescriptorSets(&cb, { mData->mGpuTd.getDescriptor(), &mData->mDescriptor });
			mData->mCurrentPipeline->CMD_BindPipeline(&cb);
			mData->mGpuBlas.getVertexBuffer()->CMD_BindVertexBuffer(&cb, 0, 0);
			mData->mGpuBlas.getIndexBuffer()->CMD_BindIndexBuffer(&cb, VK_INDEX_TYPE_UINT32, 0);
		}

		mData->mCurrentPipeline->CMD_SetViewport(&cb);
		mData->mCurrentPipeline->CMD_SetScissor(&cb);
		for (const DrawSurf_s& ds : aViewDef->surfaces) {
			// if cull is enabled, don't cull no backface cull meshes
			if (!ds.refMesh->mesh->getMaterial()->getCullBackface()) mData->mPipelineCullNone.CMD_BindPipeline(&cb);
			else mData->mCurrentPipeline->CMD_BindPipeline(&cb);

			// push constants: model matrix (16 floats) followed by the material index
			int materialIndex = mData->mGpuMd.getIndex(ds.refMesh->mesh->getMaterial());
			mData->mCurrentPipeline->CMD_SetPushConstant(&cb, VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 0, 16 * sizeof(float), &ds.model_matrix[0][0]);
			mData->mCurrentPipeline->CMD_SetPushConstant(&cb, VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 16 * sizeof(float), 1 * sizeof(int), &materialIndex);

			if (ds.refMesh->mesh->hasIndices()) mData->mCurrentPipeline->CMD_DrawIndexed(&cb, static_cast<uint32_t>(ds.refMesh->mesh->getIndexCount()),
				mData->mGpuBlas.getOffset(ds.refMesh->mesh.get()).mIndexOffset, mData->mGpuBlas.getOffset(ds.refMesh->mesh.get()).mVertexOffset);
			else mData->mCurrentPipeline->CMD_Draw(&cb, static_cast<uint32_t>(ds.refMesh->mesh->getVertexCount()), mData->mGpuBlas.getOffset(ds.refMesh->mesh.get()).mVertexOffset);
		}
	}

#ifdef IALT_USE_SPHERICAL_HARMONICS
	if (mShowSH) {
		/*mSwapchain->getCurrentDepthImage()->CMD_ImageBarrier(cb,
			VK_PIPELINE_STAGE_EARLY_FRAGMENT_TESTS_BIT, VK_PIPELINE_STAGE_EARLY_FRAGMENT_TESTS_BIT,
			VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_WRITE_BIT, VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_READ_BIT, VK_DEPENDENCY_BY_REGION_BIT);*/
		mData->mPipelineSHVisNone.CMD_BindPipeline(&cb);
		mData->mPipelineSHVisNone.CMD_BindDescriptorSets(&cb, { &mData->mDescriptor });
		for (const DrawSurf_s& ds : aViewDef->surfaces) {
			mData->mPipelineSHVisNone.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_VERTEX_BIT, 0, 16 * sizeof(float), &ds.model_matrix[0][0]);
			mData->mPipelineSHVisNone.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 16 * sizeof(float), 1 * sizeof(float), &mShowSHSize);
			mData->mPipelineSHVisNone.CMD_Draw(&cb, static_cast<uint32_t>(ds.refMesh->mesh->getVertexCount()), mData->mGpuBlas.getOffset(ds.refMesh->mesh.get()).mVertexOffset);
		}
	}

	if (mShowInterpolatedSH) {
		const IntersectionSettings settings(CullMode::None, HitMask::Geometry);
		Intersection info = {};
		Common::getInstance().intersectScene(settings, &info);
		if (info.mHit && info.mHit->type == Ref::Type::Model) {
			// global vertex indices of the hit triangle: vertex counts of all preceding meshes
			// plus the triangle's local indices
			glm::uvec4 offsets = { 0,0,0,0 };
			bool stop = false;
			for (auto model : aViewDef->scene.refModels) {
				for (const auto& mesh : model->refMeshes) {
					if (mesh.get() != info.mRefMeshHit) offsets += mesh->mesh->getVertexCount();
					else {
						offsets.x += mesh->mesh->getIndicesArray()[3u * info.mPrimitiveIndex + 0];
						offsets.y += mesh->mesh->getIndicesArray()[3u * info.mPrimitiveIndex + 1];
						offsets.z += mesh->mesh->getIndicesArray()[3u * info.mPrimitiveIndex + 2];
						stop = true;
					}
					if (stop) break;
				}
				if (stop) break;
			}

			/*mSwapchain->getCurrentDepthImage()->CMD_ImageBarrier(cb,
				VK_PIPELINE_STAGE_EARLY_FRAGMENT_TESTS_BIT, VK_PIPELINE_STAGE_EARLY_FRAGMENT_TESTS_BIT,
				VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_WRITE_BIT, VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_READ_BIT, VK_DEPENDENCY_BY_REGION_BIT);*/
			mData->mPipelineSingleSHVisNone.CMD_BindPipeline(&cb);
			mData->mPipelineSingleSHVisNone.CMD_BindDescriptorSets(&cb, { &mData->mDescriptor });

			// NOTE(review): 16 bytes are pushed from info.mHitPos — confirm it is at least 16 bytes wide
			mData->mPipelineSingleSHVisNone.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_VERTEX_BIT, 0, 16u, &info.mHitPos);
			mData->mPipelineSingleSHVisNone.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_GEOMETRY_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 16u * 2u + 8u, 4u, &mShowSHSize);
			mData->mPipelineSingleSHVisNone.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_FRAGMENT_BIT, 16u, 16u, &offsets);
			mData->mPipelineSingleSHVisNone.CMD_SetPushConstant(&cb, VK_SHADER_STAGE_FRAGMENT_BIT, 16u * 2u, 8u, &info.mBarycentric);
			mData->mPipelineSingleSHVisNone.CMD_Draw(&cb, 1);
		}
	}
#endif
	cb.cmdEndRendering();
	rvk::swapchain::CMD_BlitImageToImage(&cb, &frameData.mColor, &mRoot.currentImage(), VK_FILTER_LINEAR);

	// derivative visualization overlays the final image when a gradient is selected
	if (!mShowGrad) mGradImageSelection.reset();
	if (!mShowFDGrad) mFDGradImageSelection.reset();
	if (mGradImageSelection.has_value() || mFDGradImageSelection.has_value()) {
		if (!aViewDef->scene.refModels.empty()) {
			mData->mDerivVisPipeline.CMD_BindDescriptorSets(&cb, { mData->mGpuTd.getDescriptor(), mLto.getAdjointDescriptor(), &frameData.mDerivVisDescriptor });
			mData->mDerivVisPipeline.CMD_BindPipeline(&cb);
			mData->mDerivVisPipeline.CMD_TraceRays(&cb, aViewDef->target_size.x, aViewDef->target_size.y, 1);
		}
		rvk::swapchain::CMD_BlitImageToImage(&cb, &frameData.mDerivVisImage, &mRoot.currentImage(), VK_FILTER_LINEAR);
	}
}

void InteractiveAdjointLightTracing::drawUI(tamashii::UiConf_s* aUiConf) {
	if (ImGui::Begin("Settings", nullptr, 0))
	{
		bool fwdPT = mLto.useForwardPT();
		bool bwdPT = mLto.useBackwardPT();
		ImGui::Separator();
		ImGui::PushItemWidth(110);
		if (!fwdPT || !bwdPT) {
			int xRays = LightTraceOptimizer::vars::numRaysXperLight.value();
			int yRays = LightTraceOptimizer::vars::numRaysYperLight.value();
			if (ImGui::SliderInt("##xRay", &xRays, 1, 10000, "xRays: %d")) LightTraceOptimizer::vars::numRaysXperLight.value(xRays);
			ImGui::SameLine();
			if (ImGui::SliderInt("##yRay", &yRays, 0, 10000, "yRays: %d")) LightTraceOptimizer::vars::numRaysYperLight.value(yRays);
		}
		if (fwdPT || bwdPT) {
			int triRays = LightTraceOptimizer::vars::numRaysPerTriangle.value();
			//int samRays = LightTraceOptimizer::vars::numSamples.getInt();
			if (ImGui::SliderInt("##triRays", &triRays, 1, 10000, "triRays: %d")) LightTraceOptimizer::vars::numRaysPerTriangle.value(triRays);
			//ImGui::SameLine();
			//if (ImGui::SliderInt("##samRays", &samRays, 1, 100, "samRays: %d")) LightTraceOptimizer::vars::numSamples.setInt(samRays);
		}
		ImGui::PopItemWidth();
		ImGui::PushItemWidth(140);
		if (ImGui::BeginCombo("##opti_combo", mOptimizerChoices[mOptimizerChoice].c_str())) {
			for (size_t i = 0; i < mOptimizerChoices.size(); i++) {
				const bool isSelected = (i == mOptimizerChoice);
				if (ImGui::Selectable(mOptimizerChoices[i].c_str(), isSelected)) mOptimizerChoice = i;
				if (isSelected) ImGui::SetItemDefaultFocus();
			}
			ImGui::EndCombo();
		}
		ImGui::PopItemWidth();
		ImGui::SameLine();

		if (mLto.optimizationRunning()) {
			ImGui::PushStyleColor(ImGuiCol_ButtonHovered, ImVec4{ 0.8f, 0.3f, 0.4f, 1.0f });
			ImGui::PushStyleColor(ImGuiCol_Button, ImVec4{ 0.6f, 0.3f, 0.4f, 1.0f });
			ImGui::PushStyleColor(ImGuiCol_ButtonActive, ImVec4{ 0.7f, 0.4f, 0.5f, 1.0f });

			if (ImGui::Button("Stop", ImVec2(ImGui::GetContentRegionAvail().x, 0.0f)))
			{
				mLto.optimizationRunning(false);
			}
		}
		else {
			ImGui::PushStyleColor(ImGuiCol_ButtonHovered, ImVec4{ 0.3f, 0.8f, 0.4f, 1.0f });
			ImGui::PushStyleColor(ImGuiCol_Button, ImVec4{ 0.3f, 0.6f, 0.4f, 1.0f });
			ImGui::PushStyleColor(ImGuiCol_ButtonActive, ImVec4{ 0.4f, 0.7f, 0.5f, 1.0f });
            if (ImGui::Button("Optimize", ImVec2(ImGui::GetContentRegionAvail().x, 0.0f)))
            {
				auto scene = aUiConf->scene;
				startOptimizerThread([scene](InteractiveAdjointLightTracing* ialt, LightTraceOptimizer* aLto, rvk::Buffer* aRadianceBufferOut, unsigned int aOpti, float aOptimizerStepSize, int aOptMaxIters) {
					aLto->optimize(aOpti, aRadianceBufferOut, aOptimizerStepSize, aOptMaxIters);
					scene->requestLightUpdate();
				});
            }
		}
		ImGui::PopStyleColor(3);
		int tmpInt = mOptMaxIters.value();
		if (ImGui::SliderInt("##maxIters", &tmpInt, 1, 5000, "max iters: %d")) mOptMaxIters.value(tmpInt);

		if (mLto.getCurrentHistoryIndex() != -1) {
			if (ImGui::SliderInt("##H", &mLto.getCurrentHistoryIndex(), 0, static_cast<int>(mLto.getHistorySize()) - 1, "History: %d"))
			{
				mLto.selectParamsFromHistory(mLto.getCurrentHistoryIndex());
				aUiConf->scene->requestLightUpdate();
			}
		}
		if (ImGui::Button("Copy Radiance to Target", ImVec2(ImGui::GetContentRegionAvail().x, 0.0f)))
		{
			mLto.copyRadianceToTarget();
			mLto.buildObjectiveFunction(Common::getInstance().getRenderSystem()->getMainScene().get()->getSceneData());
		}
		if (ImGui::Button("Copy Target to Mesh", ImVec2(ImGui::GetContentRegionAvail().x, 0.0f)))
		{
			mLto.copyTargetToMesh(Common::getInstance().getRenderSystem()->getMainScene().get()->getSceneData());
		}
		static float weight = 1.0f;
		ImGui::PushItemWidth(110);
		ImGui::DragFloat("##weight", &weight, 0.01f, 0.0, 1.0, "Weight: %.3g");
		ImGui::PopItemWidth();
		ImGui::SameLine();
		if (ImGui::Button("Set Weights", ImVec2(ImGui::GetContentRegionAvail().x, 0.0f)))
		{
			mLto.setTargetWeights(weight);
		}
		ImGui::DragFloat("##IS", &mOptimizerStepSize, 0.01f, 0.01f, 5.0f, "Optim step size: %.3g");
		if(ImGui::DragInt("##IB", &mLto.bounces(), 1, 0, 10, "Bounces: %d")) aUiConf->scene->requestLightUpdate();

		ImGui::Separator();
		ImGui::Checkbox("Show Wireframe Overlay", &mShowWireframeOverlay);
		if(mShowWireframeOverlay) ImGui::ColorEdit3("Wireframe Color", &mWireframeColor[0], ImGuiColorEditFlags_NoInputs);
		if (!mShowAdjointDeriv && !mShowGrad) ImGui::Checkbox("Show Target", &mShowTarget);
		if (mShowTarget) ImGui::Checkbox("Show Weights", &mShowAlpha);
		if (!mShowTarget && !mShowGrad) ImGui::Checkbox("Show Adjoint", &mShowAdjointDeriv);
		if (mShowAdjointDeriv)
		{
			ImGui::DragFloat("##adjrange", &mAdjointVisRange, 0.01f, 0.001f, 100.0f, "Adjoint: %f");
			ImGui::Checkbox("Log", &mAdjointVisLog);
		}
		if (!mShowTarget && !mShowAdjointDeriv && mShowGrad) {
			ImGui::Checkbox("Show Gradient", &mShowGrad);
			ImGui::DragFloat("##gradrange", &mGradRange, 0.01f, 0.001f, 100.0f, "Grad: %f");
			ImGui::Checkbox("Log", &mGradVisLog);
			ImGui::Checkbox("Accumulate", &mGradVisAcc);
			ImGui::Checkbox("Weights", &mGradVisWeights);
		}
		if (mShowFDGrad) {
			ImGui::Checkbox("Show Fd Gradient", &mShowFDGrad);
			ImGui::DragFloat("##gradrange", &mGradRange, 0.01f, 0.001f, 100.0f, "Grad: %f");
			ImGui::Checkbox("Log", &mGradVisLog);
			ImGui::Checkbox("Accumulate", &mGradVisAcc);
			ImGui::Checkbox("Weights", &mGradVisWeights);
		}
#ifdef IALT_USE_SPHERICAL_HARMONICS
		ImGui::Checkbox("Show SH", &mShowSH);
		ImGui::Checkbox("Show Interpolated SH", &mShowInterpolatedSH);
		if(mShowSH || mShowInterpolatedSH) ImGui::DragFloat("##SHS", &mShowSHSize, 0.001f, 0.001f, 5.0, "SH Radius: %.3f");

#endif
		if(ImGui::Checkbox("Use PT for Fwd", &fwdPT))
		{
			// TODO: fill irradiance buffer
			mLto.useForwardPT(fwdPT);
			mLto.forward(mLto.getCurrentParams(), &mData->mRadianceBufferCopy);
		}
		if (ImGui::Checkbox("Use PT for Bwd", &bwdPT))
		{
			mLto.useBackwardPT(bwdPT);
		}
		ImGui::Checkbox("Use PT renderer", &mPTRenderer);
		if (mPTRenderer) 
		{
			ImGui::Separator();
			constexpr std::array samplingStrategies = { "BRDF", "BRDF + Light (0.5)", "BRDF + Light (MIS)" };
			if (ImGui::BeginCombo("##combo", samplingStrategies[mPTGlobalUBO.sampling_strategy]))
			{
				for (uint32_t i = 0; i < samplingStrategies.size(); i++) {
					const bool isSelected = (i == mPTGlobalUBO.sampling_strategy);
					if (ImGui::Selectable(samplingStrategies[i], isSelected)) {
						mPTGlobalUBO.sampling_strategy = i;
						mRecalculate = true;
					}
					if (isSelected) ImGui::SetItemDefaultFocus();
				}
				ImGui::EndCombo();
			}

			if (ImGui::SliderInt("##max_depth", &mPTGlobalUBO.max_depth, -1, 20, "maximum bounces: %d")) 
				mRecalculate = true;

			if (ImGui::Checkbox("Accumulate", reinterpret_cast<bool*>(&mPTGlobalUBO.accumulate))) 
				mRecalculate = true;

			if (ImGui::DragInt("##spf", reinterpret_cast<int*>(&mPTGlobalUBO.pixelSamplesPerFrame), 1, 1, 1000, "Pixel Samples per Frame: %d"))
				mRecalculate = true;

			if (ImGui::Checkbox("Use cached radiance", reinterpret_cast<bool*>(&mPTGlobalUBO.use_cached_radiance))) 
				mRecalculate = true;

			if (ImGui::Checkbox("Use path guiding", &mUsePathGuiding)) 
				mRecalculate = true;

			if (mUsePathGuiding)
			{
				if (ImGui::SliderInt("##pg_depth", (int*)&mPTGlobalUBO.pg_hsh_subdivision_depth, 1, 9, "HSH subdivision depth: %d"))
				{
					mRecalculate = true;
					mUpdatePathGuidingBuffers = true;
				}
			}

			ImGui::SeparatorText("Screenshot");
			ImGui::BeginDisabled(mTakeScreenshot);
			if (ImGui::Button("Start"))
			{
				mTakeScreenshot = true;
				mRecalculate = true;
				mScreenshotStartTime = std::chrono::steady_clock::now();
			}
			ImGui::InputInt("spp", reinterpret_cast<int*>(&mScreenshotSpp), 1, 100);
			ImGui::DragScalar("time limit", ImGuiDataType_U32, &mScreenshotTimeLimit, 100, nullptr, nullptr, "%u ms");
			ImGui::EndDisabled();

			ImGui::Separator();
		} else {
			mRecalculate = true;
			mTakeScreenshot = false;
		}
		ImGui::SetNextItemWidth(160.0f);
		ImGui::DragFloat("##penaltyFactor", &mPenalty, 0.01f, -1.0f, 100.0f, "Penaltyfactor: %1.0f");
		ImGui::SameLine();
		if (ImGui::Button("Set", ImVec2(ImGui::GetContentRegionAvail().x, 0.0f)))
		{
			LightTraceOptimizer::vars::useWindowConstraint.tryStore(mPenalty);
			mLto.buildObjectiveFunction(Common::getInstance().getRenderSystem()->getMainScene().get()->getSceneData());
		}
		ImGui::Text("Cull Mode:"); ImGui::SameLine();
		ImGui::PushItemWidth(ImGui::GetContentRegionAvail().x);
		if (ImGui::BeginCombo("##cmcombo", mCullMode[mActiveCullMode].c_str(), ImGuiComboFlags_NoArrowButton)) {
			for (uint32_t i = 0; i < mCullMode.size(); i++) {
				const bool isSelected = (i == mActiveCullMode);
				if (ImGui::Selectable(mCullMode[i].c_str(), isSelected)) mActiveCullMode = i;
				if (isSelected) ImGui::SetItemDefaultFocus();
			}
			ImGui::EndCombo();
		}
		ImGui::PopItemWidth();

		ImGui::End();
	}

	if (aUiConf->scene->getSelection().reference && aUiConf->scene->getSelection().reference->type == Ref::Type::Light) {

		if( mLto.getLightOptParams().count(static_cast<tamashii::RefLight*>( aUiConf->scene->getSelection().reference.get() ))==0){ // add default selection state if not stored yet (could be new light source)
			mLto.getLightOptParams()[static_cast<tamashii::RefLight*>( aUiConf->scene->getSelection().reference.get())] = LightOptParams();
		}
		if (ImGui::Begin("Edit", nullptr, 0) && mLto.getLightOptParams().count(static_cast<tamashii::RefLight*>( aUiConf->scene->getSelection().reference.get()))>0 )
		{
			bool lightSettingsUpdate = false;
			//ToDo: while an optimization is running, do not allow the user to change the parameter selection here
			LightOptParams& lp = mLto.getLightOptParams()[static_cast<tamashii::RefLight*>(aUiConf->scene->getSelection().reference.get())];
			ImGui::Separator();
			ImGui::Text("Optimization Parameters");
			ImGui::Text("Position: ");
			ImGui::SameLine();
			bool xyz = lp[LightOptParams::POS_X] && lp[LightOptParams::POS_Y] && lp[LightOptParams::POS_Z];
			if (ImGui::Checkbox("all", &xyz)) {
				lp[LightOptParams::POS_X] = lp[LightOptParams::POS_Y] = lp[LightOptParams::POS_Z] = xyz;
				lightSettingsUpdate |= true;
			}
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("x", &lp[LightOptParams::POS_X]);
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("y", &lp[LightOptParams::POS_Y]);
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("z", &lp[LightOptParams::POS_Z]);

			ImGui::Text("Intensity:");
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("##intensity_opti", &lp[LightOptParams::INTENSITY]);

			if (dynamic_cast<tamashii::SpotLight*>(static_cast<tamashii::RefLight*>(aUiConf->scene->getSelection().reference.get())->light.get()) ||
				dynamic_cast<tamashii::SurfaceLight*>(static_cast<tamashii::RefLight*>(aUiConf->scene->getSelection().reference.get())->light.get()) ||
				dynamic_cast<tamashii::IESLight*>(static_cast<tamashii::RefLight*>(aUiConf->scene->getSelection().reference.get())->light.get())
				) {
				ImGui::Text("Rotation: ");
				ImGui::SameLine();
				bool rxyz = lp[LightOptParams::ROT_X] && lp[LightOptParams::ROT_Y] && lp[LightOptParams::ROT_Z];
				if (ImGui::Checkbox("rall", &rxyz)) {
					lp[LightOptParams::ROT_X] = lp[LightOptParams::ROT_Y] = lp[LightOptParams::ROT_Z] = rxyz;
					lightSettingsUpdate |= true;
				}
				ImGui::SameLine();
				lightSettingsUpdate |= ImGui::Checkbox("rx", &lp[LightOptParams::ROT_X]);
				ImGui::SameLine();
				lightSettingsUpdate |= ImGui::Checkbox("ry", &lp[LightOptParams::ROT_Y]);
				ImGui::SameLine();
				lightSettingsUpdate |= ImGui::Checkbox("rz", &lp[LightOptParams::ROT_Z]);
			}

			if (dynamic_cast<tamashii::SpotLight*>(static_cast<tamashii::RefLight*>(aUiConf->scene->getSelection().reference.get())->light.get())) {
				ImGui::Text("Cone angles: ");
				ImGui::SameLine();
				lightSettingsUpdate |= ImGui::Checkbox("inner", &lp[LightOptParams::CONE_INNER]);
				ImGui::SameLine();
				lightSettingsUpdate |= ImGui::Checkbox("edge", &lp[LightOptParams::CONE_EDGE]);
			}
			ImGui::Text("Color:    ");
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("rgb", &lp[LightOptParams::COLOR_R]);
			lp[LightOptParams::COLOR_G] = lp[LightOptParams::COLOR_B] = lp[LightOptParams::COLOR_R];

			if (lightSettingsUpdate) mLto.exportLightSettings();

			ImGui::Separator();
			ImGui::Text("Gradient Visualization");
			ImGui::Text("Parameter:");
			ImGui::SameLine();
			static const char* gradImageOptions[13] = { "X", "Y", "Z", "Intensity", "Rot X", "Rot Y", "Rot Z", "Cone Inner", "Cone Edge", "Color R", "Color G", "Color B", "Off"};
			//static const uint32_t choices = 8;
			{
				uint32_t currentGradImage = (sizeof(gradImageOptions) / sizeof(char*)) - 1;
				if (mGradImageSelection.has_value() && mGradImageSelection.value().first == aUiConf->scene->getSelection().reference.get()) currentGradImage = mGradImageSelection.value().second;
				if (ImGui::BeginCombo("##gradImgCombo", gradImageOptions[currentGradImage], ImGuiComboFlags_NoArrowButton)) {
					for (uint32_t i = 0; i < sizeof(gradImageOptions) / sizeof(char*); i++) {
						const bool isSelected = (i == currentGradImage);
						if (ImGui::Selectable(gradImageOptions[i], isSelected)) {
							currentGradImage = i;
							if (currentGradImage != (sizeof(gradImageOptions) / sizeof(char*)) - 1) {
								mGradImageSelection.emplace(static_cast<RefLight*>(aUiConf->scene->getSelection().reference.get()), static_cast<LightOptParams::PARAMS>(i));
								mShowGrad = true;
								mFDGradImageSelection.reset();
								mShowFDGrad = false;
								mGradVisAcc = false;
							}
							else {
								mGradImageSelection.reset();
								mShowGrad = false;
							}
						}
						if (isSelected) ImGui::SetItemDefaultFocus();
					}
					ImGui::EndCombo();
				}
			}

			ImGui::Separator();
			ImGui::Text("Finite Difference Visualization");
			ImGui::PushItemWidth(150);
			ImGui::Text("Parameter:");
			ImGui::SameLine();
			uint32_t currentGradImageFD = (sizeof(gradImageOptions) / sizeof(char*)) - 1;
			if (mFDGradImageSelection.has_value() && mFDGradImageSelection.value().first == aUiConf->scene->getSelection().reference.get()) currentGradImageFD = mFDGradImageSelection.value().second;
			if (ImGui::BeginCombo("##fdgradImgCombo", gradImageOptions[currentGradImageFD], ImGuiComboFlags_NoArrowButton)) {
				for (uint32_t i = 0; i < sizeof(gradImageOptions) / sizeof(char*); i++) {
					const bool isSelected = (i == currentGradImageFD);
					if (ImGui::Selectable(gradImageOptions[i], isSelected)) {
						currentGradImageFD = i;
						if (currentGradImageFD != (sizeof(gradImageOptions) / sizeof(char*)) - 1) {
							mFDGradImageSelection.emplace(static_cast<RefLight*>(aUiConf->scene->getSelection().reference.get()), static_cast<LightOptParams::PARAMS>(i));
							mLto.fillFiniteDiffRadianceBuffer(mFDGradImageSelection.value().first, mFDGradImageSelection.value().second, mFdGradH, &mData->mFdRadianceBufferCopy, &mData->mFd2RadianceBufferCopy);
							mShowFDGrad = true;
							mGradImageSelection.reset();
							mShowGrad = false;
							mGradVisAcc = false;
						}
						else {
							mFDGradImageSelection.reset();
							mShowFDGrad = false;
						}
					}
					if (isSelected) ImGui::SetItemDefaultFocus();
				}
				ImGui::EndCombo();
			}
			ImGui::PopItemWidth();

			ImGui::SameLine();
			ImGui::PushItemWidth(ImGui::GetContentRegionAvail().x);
			if(ImGui::DragFloat("##stepsize", &mFdGradH, 0.01f, 0.01f, 1.0f, "h: %.3g"))
			{
				if (mFDGradImageSelection.has_value())
				{
					mLto.fillFiniteDiffRadianceBuffer(mFDGradImageSelection.value().first, mFDGradImageSelection.value().second, mFdGradH, &mData->mFdRadianceBufferCopy, &mData->mFd2RadianceBufferCopy);
				}
			}
			ImGui::PopItemWidth();


			/*ImGui::Separator();
			static bool shoft = true;
			static float h = 0.1f;
			if (ImGui::Button("Capture", ImVec2(70, 0.0f)))
			{
			}
			ImGui::SameLine();
			ImGui::PushItemWidth(60);
			ImGui::DragFloat("##steps", &h, 0.01f, 0.0, 1.0, "h: %.3g");
			ImGui::PopItemWidth();
			ImGui::SameLine();
			ImGui::Checkbox("##showfd", &shoft);*/

			ImGui::End();
		}
	}
	if (aUiConf->scene->getSelection().reference && aUiConf->scene->getSelection().reference->type == Ref::Type::Model) {

		auto refModel = static_cast<RefModel*>(aUiConf->scene->getSelection().reference.get());
		auto front = refModel->refMeshes.begin();
		std::advance(front, aUiConf->scene->getSelection().meshOffset);

		if (mLto.getLightOptParams().count(front->get()) == 0) { // add default selection state if not stored yet (could be new light source)
			mLto.getLightOptParams()[front->get()] = LightOptParams();
		}
		if (ImGui::Begin("Edit", nullptr, 0) && mLto.getLightOptParams().count(front->get()) > 0)
		{
			bool lightSettingsUpdate = false;
			//ToDo: while an optimization is running, do not allow the user to change the parameter selection here
			LightOptParams& lp = mLto.getLightOptParams()[front->get()];
			ImGui::Separator();
			ImGui::Text("Optimization Parameters");

			ImGui::Text("Emission Strength:");
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("##emission_opti", &lp[LightOptParams::INTENSITY]);

			ImGui::Text("Emission Color:   ");
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("##emission_rgb_opti", &lp[LightOptParams::COLOR_R]);
			lp[LightOptParams::COLOR_G] = lp[LightOptParams::COLOR_B] = lp[LightOptParams::COLOR_R];

			ImGui::Text("Emission Texture:");
			ImGui::SameLine();
			lightSettingsUpdate |= ImGui::Checkbox("##emission_tex_opti", &lp[LightOptParams::EMISSIVE_TEXTURE]);

			if (lightSettingsUpdate) mLto.exportLightSettings();

			//ImGui::Separator();
			//ImGui::Text("Gradient Visualization");
			//ImGui::Text("Parameter:");
			//ImGui::SameLine();
			//static const char* gradImageOptions[13] = { "X", "Y", "Z", "Intensity", "Rot X", "Rot Y", "Rot Z", "Cone Inner", "Cone Edge", "Color R", "Color G", "Color B", "Off" };
			////static const uint32_t choices = 8;
			//{
			//	uint32_t currentGradImage = (sizeof(gradImageOptions) / sizeof(char*)) - 1;
			//	if (mGradImageSelection.has_value() && mGradImageSelection.value().first == aUiConf->scene->getSelection().reference.get()) currentGradImage = mGradImageSelection.value().second;
			//	if (ImGui::BeginCombo("##gradImgCombo", gradImageOptions[currentGradImage], ImGuiComboFlags_NoArrowButton)) {
			//		for (uint32_t i = 0; i < sizeof(gradImageOptions) / sizeof(char*); i++) {
			//			const bool isSelected = (i == currentGradImage);
			//			if (ImGui::Selectable(gradImageOptions[i], isSelected)) {
			//				currentGradImage = i;
			//				if (currentGradImage != (sizeof(gradImageOptions) / sizeof(char*)) - 1) {
			//					mGradImageSelection.emplace(static_cast<RefLight*>(aUiConf->scene->getSelection().reference.get()), static_cast<LightOptParams::PARAMS>(i));
			//					mShowGrad = true;
			//					mFDGradImageSelection.reset();
			//					mShowFDGrad = false;
			//					mGradVisAcc = false;
			//				}
			//				else {
			//					mGradImageSelection.reset();
			//					mShowGrad = false;
			//				}
			//			}
			//			if (isSelected) ImGui::SetItemDefaultFocus();
			//		}
			//		ImGui::EndCombo();
			//	}
			//}

			//ImGui::Separator();
			//ImGui::Text("Finite Difference Visualization");
			//ImGui::PushItemWidth(150);
			//ImGui::Text("Parameter:");
			//ImGui::SameLine();
			//uint32_t currentGradImageFD = (sizeof(gradImageOptions) / sizeof(char*)) - 1;
			//if (mFDGradImageSelection.has_value() && mFDGradImageSelection.value().first == aUiConf->scene->getSelection().reference.get()) currentGradImageFD = mFDGradImageSelection.value().second;
			//if (ImGui::BeginCombo("##fdgradImgCombo", gradImageOptions[currentGradImageFD], ImGuiComboFlags_NoArrowButton)) {
			//	for (uint32_t i = 0; i < sizeof(gradImageOptions) / sizeof(char*); i++) {
			//		const bool isSelected = (i == currentGradImageFD);
			//		if (ImGui::Selectable(gradImageOptions[i], isSelected)) {
			//			currentGradImageFD = i;
			//			if (currentGradImageFD != (sizeof(gradImageOptions) / sizeof(char*)) - 1) {
			//				mFDGradImageSelection.emplace(static_cast<RefLight*>(aUiConf->scene->getSelection().reference.get()), static_cast<LightOptParams::PARAMS>(i));
			//				mLto.fillFiniteDiffRadianceBuffer(mFDGradImageSelection.value().first, mFDGradImageSelection.value().second, mFdGradH, &mData->mFdRadianceBufferCopy, &mData->mFd2RadianceBufferCopy);
			//				mShowFDGrad = true;
			//				mGradImageSelection.reset();
			//				mShowGrad = false;
			//				mGradVisAcc = false;
			//			}
			//			else {
			//				mFDGradImageSelection.reset();
			//				mShowFDGrad = false;
			//			}
			//		}
			//		if (isSelected) ImGui::SetItemDefaultFocus();
			//	}
			//	ImGui::EndCombo();
			//}
			//ImGui::PopItemWidth();

			//ImGui::SameLine();
			//ImGui::PushItemWidth(ImGui::GetContentRegionAvail().x);
			//if (ImGui::DragFloat("##stepsize", &mFdGradH, 0.01f, 0.01f, 1.0f, "h: %.3g"))
			//{
			//	if (mFDGradImageSelection.has_value())
			//	{
			//		mLto.fillFiniteDiffRadianceBuffer(mFDGradImageSelection.value().first, mFDGradImageSelection.value().second, mFdGradH, &mData->mFdRadianceBufferCopy, &mData->mFd2RadianceBufferCopy);
			//	}
			//}
			//ImGui::PopItemWidth();

			ImGui::End();
	}
	}
	if (aUiConf->draw_info->mDrawMode) {
		aUiConf->draw_info->mTarget = DrawInfo::Target::CUSTOM;
#ifdef IALT_USE_SPHERICAL_HARMONICS
		if (ImGui::Begin("Draw", nullptr, 0))
		{
			ImGui::Checkbox("Set SH", &mDrawSetSH);
			ImGui::End();
		}
#endif
	}
}

void InteractiveAdjointLightTracing::waitForNextFrame()
{
	// Blocks the calling thread until mNextFrameCV is notified.
	// std::condition_variable::wait requires that all threads concurrently waiting on the
	// same condition variable pass locks on the *same* mutex. A fresh local mutex per call
	// violates that as soon as two threads wait at once, so all callers share one mutex here.
	// The lock is released while waiting, so sharing it does not serialize the waiters.
	static std::mutex mutex;
	std::unique_lock<std::mutex> lck(mutex);
	// NOTE(review): no predicate is checked, so a spurious wakeup makes this return before a
	// new frame has actually been produced. Callers must tolerate that.
	mNextFrameCV.wait(lck);
}

void InteractiveAdjointLightTracing::startOptimizerThread(std::function<void(InteractiveAdjointLightTracing*, LightTraceOptimizer*, rvk::Buffer*, unsigned int, float, int) > funcToRun)
{
	// At most one optimizer thread is kept: if a previous run is still joinable, block until it ends.
	const bool previousRunPending = mOptimizerThread.joinable();
	if (previousRunPending) {
		spdlog::warn("Optimizer still running. Waiting for it to stop...");
		mOptimizerThread.join();
	}

	// The worker receives this instance, the optimizer, the radiance copy buffer and the
	// current UI settings (optimizer choice, step size, iteration cap), captured by value now.
	auto* const radianceBuffer = &mData->mRadianceBufferCopy;
	const int iterationLimit = mOptMaxIters.value();
	mOptimizerThread = std::thread(funcToRun, this, &mLto, radianceBuffer, mOptimizerChoice, mOptimizerStepSize, iterationLimit);
}

void InteractiveAdjointLightTracing::runForward(const Eigen::Map<Eigen::VectorXd>& params)
{
	// Runs the optimizer's forward pass for the given parameter vector, writing into the
	// radiance and incident-radiance copy buffers. Throws if no data member is present.
	if (!mData) throw std::runtime_error{ "IALT has no data member set" };

	// forward() is handed an owning copy of the mapped (non-owning) parameter view.
	Eigen::VectorXd ownedParams(params);
	mLto.forward(ownedParams, &mData->mRadianceBufferCopy, &mData->mIncidentRadianceBufferCopy);
}

double InteractiveAdjointLightTracing::runBackward(Eigen::VectorXd& derivParams)
{
	// Thin delegate to LightTraceOptimizer::backward. derivParams is passed by non-const
	// reference; presumably it receives the parameter derivatives (verify in backward()).
	const double backwardResult = mLto.backward(derivParams);
	return backwardResult;
}

void InteractiveAdjointLightTracing::useCurrentRadianceAsTarget(const bool clearWeights)
{
	// Captures the current radiance as the optimization target, rebuilds the objective
	// and pushes the target back onto the mesh. Optionally resets all target weights to 1.
	auto&& renderSystem = Common::getInstance().getRenderSystem();

	if (clearWeights) {
		mLto.setTargetWeights(1);
	}
	mLto.copyRadianceToTarget();

	// The main scene's data is queried separately for each consumer.
	mLto.buildObjectiveFunction(renderSystem->getMainScene().get()->getSceneData());
	mLto.copyTargetToMesh(renderSystem->getMainScene().get()->getSceneData());
}

InteractiveAdjointLightTracing::OptimizerResult InteractiveAdjointLightTracing::runOptimizer(const LightTraceOptimizer::Optimizers optimizerType, const float stepSize, int maxIterations)
{
	// Runs the selected optimizer on the radiance copy buffer and returns the exported
	// optimization history together with the last objective value reported by optimize().
	// Throws std::runtime_error (same as runForward) instead of dereferencing an empty mData.
	if (!mData.has_value()) {
		throw std::runtime_error{ "IALT has no data member set" };
	}

	// TODO: should probably pass incident radiance buffer here also
	const auto result = mLto.optimize(optimizerType, &mData->mRadianceBufferCopy, stepSize, maxIterations);

	// exportHistory() is used directly so its result can be moved into the return value
	// (a const local would force a copy).
	return { .history = mLto.exportHistory(), .lastPhi = result.lastPhi };
}

void InteractiveAdjointLightTracing::clearGradVis()
{
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	stc.begin();
	mData->mDerivVisImageAccumulate.CMD_ClearColor(stc.buffer(), 0, 0, 0, 0);
	stc.end();
}

std::unique_ptr<tamashii::Image> InteractiveAdjointLightTracing::getFrameImage()
{
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	const VkExtent3D extent = mFrameData[mRoot.currentIndex()].mColor.getExtent();
	std::vector<glm::vec4> img_data(extent.width * extent.height);
	mFrameData[mRoot.currentIndex()].mColor.STC_DownloadData2D(&stc, extent.width, extent.height, 16, img_data.data());

	const auto img = new tamashii::Image("");
	img->init(extent.width, extent.height, tamashii::Image::Format::RGBA32_FLOAT, img_data.data());
	return std::unique_ptr<tamashii::Image>{img};
}

std::unique_ptr<tamashii::Image> InteractiveAdjointLightTracing::getGradImage()
{

	SingleTimeCommand stc = mRoot.singleTimeCommand();
	const VkExtent3D extent = mData.value().mDerivVisImageAccumulate.getExtent();

	std::vector<glm::vec4> img_data(extent.width * extent.height);
	mData->mDerivVisImageAccumulate.STC_DownloadData2D(&stc, extent.width, extent.height, 16, img_data.data());
	//std::vector<uint32_t> img_sample_count(extent.width * extent.height);
	//mData->mDerivVisImageAccumulateCount.STC_DownloadData2D(&stc, extent.width, extent.height, 4, img_sample_count.data());
	//for (size_t i = 0; i < img_data.size(); i++) if (img_sample_count[i]) img_data[i] /= img_sample_count[i];

	const auto img = new tamashii::Image("");
	img->init(extent.width, extent.height, tamashii::Image::Format::RGBA32_FLOAT, img_data.data());
	return std::unique_ptr<tamashii::Image>{img};
}

void InteractiveAdjointLightTracing::recomputePathGuidingBuffers()
{
	computeHSHIntegrals();
	computeNormalizationConstants();
	// update descriptors
	for (auto& frameData : mFrameData)
	{
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_PHI_MAP_BINDING, &mData->mPGPhiMapBuffer);
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_LEGENDRE_MAP_BINDING, &mData->mPGLegendreMapBuffer);
		frameData.mPTGlobalDescriptor.setBuffer(PATH_GUIDING_NORMALIZATION_CONSTANTS_BINDING, &mData->mPGNormalizationConstantsBuffer);
		frameData.mPTGlobalDescriptor.update();
	}
}

// The code below was adapted from the Mitsuba implementation
// Writes 2*shOrder+1 azimuthal integrals over [a, b] to out, indexed by out[m + shOrder]
// for m in [-shOrder, shOrder]:
//   m > 0  : integral of cos(m*phi)   = (sin(m*b) - sin(m*a)) / m
//   m == 0 : integral of 1            = b - a
//   m < 0  : integral of sin(|m|*phi) = (cos(|m|*b) - cos(|m|*a)) / m
// sin(k*phi) and cos(k*phi) are built with the recurrence f_k = 2*f_{k-1}*cos(phi) - f_{k-2}.
static void phiIntegrals(float* out, float a, float b) {
	const int shOrder = LTOVars::shOrder;
	const size_t tableSize = static_cast<size_t>(shOrder + 1);

	std::vector<float> sinA(tableSize), sinB(tableSize), cosA(tableSize), cosB(tableSize);

	// Seed the recurrences with k = 0 and k = 1.
	cosA[0] = cosB[0] = 1;
	sinA[0] = sinB[0] = 0;
	cosA[1] = std::cos(a);
	sinA[1] = std::sin(a);
	cosB[1] = std::cos(b);
	sinB[1] = std::sin(b);

	for (int k = 2; k <= shOrder; ++k) {
		sinA[k] = 2 * sinA[k - 1] * cosA[1] - sinA[k - 2];
		sinB[k] = 2 * sinB[k - 1] * cosB[1] - sinB[k - 2];

		cosA[k] = 2 * cosA[k - 1] * cosA[1] - cosA[k - 2];
		cosB[k] = 2 * cosB[k - 1] * cosB[1] - cosB[k - 2];
	}

	out[shOrder] = b - a;
	for (int k = 1; k <= shOrder; ++k) {
		out[shOrder + k] = (sinB[k] - sinA[k]) / k;
		out[shOrder - k] = (cosB[k] - cosA[k]) / -k;
	}
}

// Flat index of coefficient (l, m), 0 <= m <= l, in a table that stores band l
// (l + 1 entries) directly after the l * (l + 1) / 2 entries of bands 0..l-1.
static size_t I(size_t l, size_t m) {
	const size_t bandStart = (l * (l + 1)) >> 1;
	return bandStart + m;
}

// Associated Legendre polynomial P_l^m(x), including the Condon-Shortley phase (-1)^m
// (each factor of the P_m^m product is negated). Expects 0 <= m; returns 0 when m > l.
// Uses the standard recurrences:
//   P_m^m     = (-1)^m (2m-1)!! (1-x^2)^(m/2)
//   P_{m+1}^m = x (2m+1) P_m^m
//   (l-m) P_l^m = x (2l-1) P_{l-1}^m - (l+m-1) P_{l-2}^m
float legendreP(int l, int m, float x) {
	/* Evaluate the recurrence in double precision */
	const double xd = x;
	double p_mm = 1;

	if (m > 0) {
		// Computed in double (previously the radicand was formed in float). It is clamped at 0
		// so that |x| marginally above 1 from float rounding yields 0 instead of NaN.
		const double somx2 = std::sqrt(std::max(0.0, (1.0 - xd) * (1.0 + xd)));
		double fact = 1;
		for (int i = 1; i <= m; i++) {
			p_mm *= (-fact) * somx2;
			fact += 2;
		}
	}

	if (l == m)
		return static_cast<float>(p_mm);

	double p_mmp1 = xd * (2 * m + 1) * p_mm;
	if (l == m + 1)
		return static_cast<float>(p_mmp1);

	double p_ll = 0;
	for (int ll = m + 2; ll <= l; ++ll) {
		p_ll = ((2 * ll - 1) * xd * p_mmp1 - (ll + m - 1) * p_mm) / (ll - m);
		p_mm = p_mmp1;
		p_mmp1 = p_ll;
	}

	return static_cast<float>(p_ll);
}
// Writes, for every band l in [0, shOrder) and 0 <= m <= l, the integral over z in [a, b]
// of the associated Legendre polynomial P_l^m(z) to out[I(l, m)]. All values are finally
// scaled by 1/2 for the hemispherical change of variables (see the comment at the end).
// out must hold shOrder * (shOrder + 1) / 2 floats.
// Fix: previously shOrder == 1 returned before that 1/2 scaling, so out[0] was twice the
// value produced for shOrder >= 2. shOrder < 1 now writes nothing (it used to write out of bounds).
static void legendreIntegrals(float* out, float a, float b) {
	const int shOrder = LTOVars::shOrder;
	if (shOrder < 1) {
		return; // the per-interval table is empty, so there is nothing to write
	}
	const size_t bands = static_cast<size_t>(shOrder);
	const size_t valuesCount = bands * (bands + 1) / 2;

	// P_0^0 == 1, so its integral is the interval length.
	out[I(0, 0)] = b - a;

	if (bands > 1) {
		std::vector<float> Pa(valuesCount);
		std::vector<float> Pb(valuesCount);

		for (size_t l = 0; l < bands; ++l) {
			for (size_t m = 0; m <= l; ++m) {
				Pa[I(l, m)] = legendreP(static_cast<int>(l), static_cast<int>(m), a);
				Pb[I(l, m)] = legendreP(static_cast<int>(l), static_cast<int>(m), b);
			}
		}

		// Closed forms for band 1: P_1^0 = z, P_1^1 = -sqrt(1 - z^2).
		out[I(1, 0)] = (b * b - a * a) / 2;
		out[I(1, 1)] = .5f * (-b * std::sqrt(1 - b * b) - std::asin(b) + a * std::sqrt(1 - a * a) + std::asin(a));

		// Higher bands by recurrence on previously computed (still unscaled) integrals.
		for (size_t l = 2; l < bands; ++l) {
			for (size_t m = 0; m <= l - 2; ++m) {
				const float ga = (2 * l - 1) * (1 - a * a) * Pa[I(l - 1, m)];
				const float gb = (2 * l - 1) * (1 - b * b) * Pb[I(l - 1, m)];
				out[I(l, m)] = ((l - 2) * (l - 1 + m) * out[I(l - 2, m)] - gb + ga) / ((l + 1) * (l - m));
			}

			out[I(l, l - 1)] = (2 * l - 1) / (float)(l + 1) * ((1 - a * a) * Pa[I(l - 1, l - 1)] - (1 - b * b) * Pb[I(l - 1, l - 1)]);
			out[I(l, l)] = 1 / (float)(l + 1) * (l * (2 * l - 3) * (2 * l - 1) * out[I(l - 2, l - 2)] + b * Pb[I(l, l)] - a * Pa[I(l, l)]);
		}
	}

	// for hemispherical harmonics we need to make the change of variables $z = 2\cos(\theta) - 1$
	// with $\theta \in [0, \pi/2]$ instead of $z = \cos(\theta)$ with $\theta \in [0, \pi]$.
	// In code, the only change is this $1/2$ factor (applied to every band, including l = 0).
	for (size_t i = 0; i < valuesCount; ++i)
		out[i] /= float(2);
}

void InteractiveAdjointLightTracing::computeHSHIntegrals()
{
	const size_t totalSubdivisionCount = (2 << mPTGlobalUBO.pg_hsh_subdivision_depth) - 1;
	const size_t phiMapElementCount = totalSubdivisionCount * (2 * LTOVars::shOrder + 1);
	const size_t phiMapSize = phiMapElementCount * sizeof(float);
	mData->mPGPhiMapBuffer.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, phiMapSize, rvk::Buffer::Location::DEVICE);
	rvk::Buffer phiMapStagingBuffer(&mRoot.device);
	phiMapStagingBuffer.create(rvk::Buffer::Use::DOWNLOAD, phiMapSize, rvk::Buffer::Location::HOST_COHERENT);
	phiMapStagingBuffer.mapBuffer();
	float* phiMapPtr = (float*)phiMapStagingBuffer.getMemoryPointer();

	const size_t legendreMapElementCount = totalSubdivisionCount * (LTOVars::shOrder * (LTOVars::shOrder + 1) / 2);
	const size_t legendreMapSize = legendreMapElementCount * sizeof(float);
	mData->mPGLegendreMapBuffer.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, legendreMapSize, rvk::Buffer::Location::DEVICE);
	rvk::Buffer legendreMapStagingBuffer(&mRoot.device);
	legendreMapStagingBuffer.create(rvk::Buffer::Use::DOWNLOAD, legendreMapSize, rvk::Buffer::Location::HOST_COHERENT);
	legendreMapStagingBuffer.mapBuffer();
	float* legendreMapPtr = (float*)legendreMapStagingBuffer.getMemoryPointer();
	
	size_t phiMapPos = 0;
	size_t legendreMapPos = 0;
	for (size_t i = 0; i <= mPTGlobalUBO.pg_hsh_subdivision_depth; ++i) {
		size_t subdivisions = 1ull << i;
		float zStep = -2.f / (float)subdivisions;
		float phiStep = 2 * (float)M_PI / (float)subdivisions;
		for (size_t j = 0; j < subdivisions; ++j) {
			phiIntegrals(&phiMapPtr[phiMapPos], phiStep * j, phiStep * (j + 1));
			phiMapPos += LTOVars::shOrder * 2 + 1;
			legendreIntegrals(&legendreMapPtr[legendreMapPos], 1.f + zStep * j, 1.f + zStep * (j + 1));
			legendreMapPos += LTOVars::shOrder * (LTOVars::shOrder + 1) / 2;
		}
	}
	
	phiMapStagingBuffer.unmapBuffer();
	legendreMapStagingBuffer.unmapBuffer();

	SingleTimeCommand stc = mRoot.singleTimeCommand();
	stc.begin();
	phiMapStagingBuffer.CMD_CopyBuffer(stc.buffer(), &mData->mPGPhiMapBuffer, 0, phiMapSize);
	legendreMapStagingBuffer.CMD_CopyBuffer(stc.buffer(), &mData->mPGLegendreMapBuffer, 0, legendreMapSize);
	stc.end();
}

// computes (a-b)!/(a+b)!
// = 1 / ((a-b+1) * (a-b+2) * ... * (a+b)), evaluated by dividing factor by factor in
// ascending order. Requires b <= a.
static double ratio_of_factorials(size_t a, size_t b) {
	double quotient = 1.0;
	const size_t lastFactor = a + b;
	size_t factor = a - b + 1;
	while (factor <= lastFactor) {
		quotient /= static_cast<double>(factor);
		++factor;
	}
	return quotient;
}

void InteractiveAdjointLightTracing::computeNormalizationConstants()
{
	const size_t elementCount = (LTOVars::shOrder * (LTOVars::shOrder + 1) / 2);
	const size_t bufferSize = elementCount * sizeof(float);
	mData->mPGNormalizationConstantsBuffer.create(rvk::Buffer::Use::STORAGE | rvk::Buffer::Use::UPLOAD, bufferSize, rvk::Buffer::Location::DEVICE);
	rvk::Buffer stagingBuffer(&mRoot.device);
	stagingBuffer.create(rvk::Buffer::Use::DOWNLOAD, bufferSize, rvk::Buffer::Location::HOST_COHERENT);
	stagingBuffer.mapBuffer();
	float* out = (float*)stagingBuffer.getMemoryPointer();
	
	for (size_t l = 0; l < LTOVars::shOrder; ++l) {
		for (size_t m = 0; m <= l; ++m) {
			// divide by 2 * M_PI instead of 4 * M_PI because we use hemispherical harmonics
			out[I(l, m)] = std::sqrt((2.0 * double(l) + 1.0) / (2.0 * M_PI) * ratio_of_factorials(l, m));
		}
	}

	stagingBuffer.unmapBuffer();
	SingleTimeCommand stc = mRoot.singleTimeCommand();
	stc.begin();
	stagingBuffer.CMD_CopyBuffer(stc.buffer(), &mData->mPGNormalizationConstantsBuffer, 0, bufferSize);
	stc.end();
}


