// Single-precision pi.
#define M_PI    3.14159265358979323846f
// Convert an angle from degrees to radians.
#define DEGRAD(a)    ((a) * (M_PI / 180.f))
// Convert an angle from radians to degrees.
#define RADDEG(a)    ((a) * (180.f / M_PI))
// Large distance used as the initial "no hit yet" / far ray bound.
#define HUGE 1e10f

/***************************************************************************//**
 * Maps integer screen coordinates to a point on the image plane:
 * top + dx * x + dy * y.
/******************************************************************************/
struct sampler_t
{
    float3 top;
    float3 dx;
    float3 dy;

    // Returns the point for pixel (screen.x, screen.y).
    float3 map(int2 screen)
    {
        float3 along_x = dx * float(screen.x);
        float3 along_y = dy * float(screen.y);
        return top + along_x + along_y;
    }
};

/***************************************************************************//**
 * Camera description with trivial accessors.
 * fovx is stored as HALF the horizontal field of view, in radians; the
 * set_fovx / get_fovx pair converts from / to the full angle in degrees.
 * look_at is only declared here; its definition is not part of this view.
 * The member order is kept as-is because the struct is placed in a cbuffer.
/******************************************************************************/
struct camera_t
{
    float3 eye;
    float3 dir;
    float3 up;
    float3 right;
    float fovx;
    int world_up_index;
    sampler_t screen_sampler;

    void look_at(float3 target,  int up_idx = -1);

    // --- field of view -------------------------------------------------------
    void set_fovx(float degree)
    {
        float full_angle_rad = DEGRAD(degree);
        fovx = full_angle_rad * .5f;
    }

    float get_fovx()
    {
        float half_angle_deg = RADDEG(fovx);
        return half_angle_deg * 2.f;
    }

    // --- position and basis --------------------------------------------------
    void set_eye(float3 v)
    {
        eye = v;
    }

    float3 get_eye()    { return eye; }
    float3 get_dir()    { return dir; }
    float3 get_up()     { return up; }
    float3 get_right()  { return right; }

    // --- world up axis index -------------------------------------------------
    int get_world_up_index()
    {
        return world_up_index;
    }

    void set_world_up_index(int idx)
    {
        world_up_index = idx;
    }
};

// A ray given by its origin (pos) and direction (dir).
struct ray_t
{
    float3 pos;
    float3 dir;

    void set_pos(float3 o)
    {
        pos = o;
    }

    void set_dir(float3 d)
    {
        dir = d;
    }

    // Returns (origin, direction) restricted to one axis: .x holds the origin
    // component and .y the direction component. Axis 0 is x, axis 1 is y and
    // every other value selects z.
    float2 get_dir_pos(uint i)
    {
        float2 origin_and_dir;
        if (i == 0)
        {
            origin_and_dir = float2(pos.x, dir.x);
        }
        else if (i == 1)
        {
            origin_and_dir = float2(pos.y, dir.y);
        }
        else
        {
            origin_and_dir = float2(pos.z, dir.z);
        }
        return origin_and_dir;
    }
};

/***************************************************************************//**
 * Intersection record.
 * t is the ray parameter of the closest hit so far, u / v receive the two
 * barycentric values computed by the triangle test, and id is the triangle
 * id, or ~0u while nothing has been hit.
/******************************************************************************/
struct hit_t
{
    float t;
    float u;
    float v;
    uint id;

    // Reset to "no hit", accepting only intersections closer than tmax.
    // u and v are left untouched.
    void create(float tmax)
    {
        id = ~0u;
        t = tmax;
    }

    // Reset to "no hit" with the default far bound HUGE.
    void create()
    {
        id = ~0u;
        t = HUGE;
    }
};

/***************************************************************************//**
 * One deferred kd-tree traversal step: the index of the node to resume at
 * and the ray interval [tnear, tfar] to restore when resuming.
/******************************************************************************/
struct trace_t
{
    uint node;
    float tnear, tfar;
};

// Number of entries the traversal stack can hold.
#define TRACE_STACK_DEPTH 8

// Backing storage of the traversal stack. It is a file-level static instead of
// a struct member on purpose (original author's note follows).
static trace_t t[TRACE_STACK_DEPTH];	// nonstatic in struct -> Forced to unroll loop, but unrolling failed. <-- SUCKS!

// Fixed-capacity LIFO of deferred kd-tree nodes on top of the array above.
// idx is the number of stored entries; after a successful pop() it is the
// index of the entry that get() / node() / t_near() / t_far() read.
struct trace_stack_t 
{    
    int idx;

    // Entry at the current index; meaningful right after a successful pop().
    trace_t get()
    {
        return t[idx];
    }

    // Store one entry. When the stack is already full the entry is dropped
    // rather than written past the end of the array, so that deferred subtree
    // is not visited.
    void push(trace_t trace)
    {
        if (idx < TRACE_STACK_DEPTH)
        {
            t[idx] = trace;
            idx++;
        }
    }

    // Empty the stack.
    void reset()	{ idx = 0; }

    // Step back one entry. Returns non-zero when an entry was available; on
    // zero the index is negative and get() must not be called.
    int pop()		{ return --idx >= 0; }

    uint node()		{ return get().node; }
    float t_near()	{ return get().tnear; }
    float t_far()	{ return get().tfar; }
};

/***************************************************************************//**
 * The triangle used for the intersection computations
/******************************************************************************/
// Ray direction and origin with their components reordered by
// wald_tri_t::get_perm.
struct perm_t
{
    float3 dir;
    float3 pos;
};
// Triangle record packed into three uint4 words. Apart from k(), id() and
// matid(), every field is a float stored as raw bits and read back with
// asfloat().
// NOTE(review): the layout looks like the precomputed projection-based
// triangle test the struct is named after -- confirm against the builder.
struct wald_tri_t 
{
    uint4 internal0;
    uint4 internal1;
    uint4 internal2;

    // --- word 0: permutation axis and plane terms -----------------------------
    uint k()        { return internal0.x; }
    float n_u()     { return asfloat(internal0.y); }
    float n_v()     { return asfloat(internal0.z); }
    float n_d()     { return asfloat(internal0.w); }

    // --- word 1: reference vertex (permuted y / z) and first edge terms -------
    float vert_ku() { return asfloat(internal1.x); }
    float vert_kv() { return asfloat(internal1.y); }
    float b_nu()    { return asfloat(internal1.z); }
    float b_nv()    { return asfloat(internal1.w); }

    // --- word 2: second edge terms, triangle id and material id ---------------
    float c_nu()    { return asfloat(internal2.x); }
    float c_nv()    { return asfloat(internal2.y); }
    uint id()       { return internal2.z; }
    uint matid()    { return internal2.w; }

    // Reorders the ray components so that axis k() comes first:
    // k == 0 -> xyz, k == 1 -> yzx, anything else -> zxy.
    perm_t get_perm(ray_t ray) 
    {
        uint axis = k();
        perm_t reordered;
        if (axis == 0)
        {
            reordered.dir = ray.dir.xyz;
            reordered.pos = ray.pos.xyz;
        }
        else if (axis == 1)
        {
            reordered.dir = ray.dir.yzx;
            reordered.pos = ray.pos.yzx;
        }
        else
        {
            reordered.dir = ray.dir.zxy;
            reordered.pos = ray.pos.zxy;
        }
        return reordered;
    }
};

/***************************************************************************//**
 * The kd-tree structure
/******************************************************************************/
// Integer type for triangle ids. NOTE(review): not referenced in the visible source.
typedef uint tri_id_t;
// Leaf flag: bit 31 of the first node word (tested by node_t::is_leaf).
// NOTE(review): the cast applies to 1ul before the shift, so this shifts a
// signed int into its sign bit; every visible use casts the result to uint.
#define mask_leaf ((int)1ul<<31)
// Every bit except the leaf flag (used by node_t::get_list).
#define mask_list (~mask_leaf)
// Low two bits of the first node word (used by node_t::get_axis).
#define mask_axis (3)
// Every bit above the axis bits (used by node_t::get_offset and the traversal loops).
#define mask_children (~mask_axis)

namespace kdtree 
{
    // One kd-tree node packed into two 32-bit words. Both words are shared
    // between the leaf and the inner-node interpretation:
    //   internal.x : leaf flag (mask_leaf) plus either the list bits
    //                (mask_list) or the axis / offset bits
    //                (mask_axis / mask_children)
    //   internal.y : triangle count, or the split coordinate as float bits
    struct node_t 
    {
        uint2 internal;

        // Raw views of the two words.
        uint offset_flag()
        {
            return internal.x;
        }

        uint dim_offset_flag()
        {
            return internal.x;
        }

        uint tri_count()
        {
            return internal.y;
        }

        float split_coord()
        {
            return asfloat(internal.y);
        }

        // Non-zero when the leaf flag is set.
        uint is_leaf()
        {
            uint leaf_bit = (uint) mask_leaf;
            return offset_flag() & leaf_bit;
        }

        // 1 when the leaf flag is clear, 0 otherwise.
        uint is_node()
        {
            return !is_leaf();
        }

        // First word without the leaf flag; trace_kdtree uses it as the start
        // position of the leaf's entries in KdIds.
        uint get_list()
        {
            uint list_bits = (uint) mask_list;
            return offset_flag() & list_bits;
        }

        // Low two bits of the first word.
        uint get_axis()
        {
            uint axis_bits = (uint) mask_axis;
            return dim_offset_flag() & axis_bits;
        }

        // First word with the axis bits cleared (not shifted).
        uint get_offset()
        {
            uint child_bits = (uint) mask_children;
            return dim_offset_flag() & child_bits;
        }
    };
}

/***************************************************************************//**
 * Axis-aligned bounding box stored as six floats: the minimum corner
 * (xmin, ymin, zmin) followed by the maximum corner (xmax, ymax, zmax).
/******************************************************************************/
struct aabb_t 
{
    float xmin;
    float ymin;
    float zmin;

    float xmax;
    float ymax;
    float zmax;
};

// Buffer declarations
// Packed kd-tree nodes, one uint2 per node; traversal starts at element 0.
StructuredBuffer<uint2> KdNodes : register(t0);
// Triangle ids, read by the traversal starting at a leaf's list position.
StructuredBuffer<uint> KdIds : register(t1);
// Packed triangle records: three consecutive uint4 per triangle (id * 3 + 0..2).
StructuredBuffer<uint4> KdTris : register(t2);
// Normals, indexed with the id stored in a hit record.
StructuredBuffer<float3> KdNormals : register(t3);

// Output image; the kernels write element DTid.y * windowWidth + DTid.x.
RWStructuredBuffer<float4> RtImage : register(u0);

// Constants shared by all kernels.
// NOTE(review): the member order presumably has to match the host-side
// constant-buffer layout -- confirm before reordering anything.
cbuffer cbConstants : register( b0 )
{
	// Camera; the kernels use its eye position and screen sampler.
	camera_t cam;
	// Scene bounds tested before any kd-tree traversal.
	aabb_t aabb;
	// Light position used by the shaded kernels.
	float3 lightPos;
	// Row pitch, in pixels, used to address RtImage.
	uint windowWidth;
}
	
/***************************************************************************//**
 * Kay-Kajiya (slab) ray-box intersection.
 * Clips the incoming interval [tmin, tmax] against the three pairs of box
 * planes, one axis after the other, and writes the clipped interval back.
 * Returns non-zero when the clipped interval is non-empty and does not lie
 * entirely behind the ray origin.
/******************************************************************************/
static uint intersect_ray_box(aabb_t box, ray_t ray, in out float tmin, in out float tmax)
{
    /* Distances to the two planes of every slab, all three axes at once */
    float3 corner_lo = float3(box.xmin, box.ymin, box.zmin);
    float3 corner_hi = float3(box.xmax, box.ymax, box.zmax);
    float3 dist_lo = (corner_lo - ray.pos) / ray.dir;
    float3 dist_hi = (corner_hi - ray.pos) / ray.dir;
    float3 slab_in = min(dist_lo, dist_hi);
    float3 slab_out = max(dist_lo, dist_hi);

    /* Clip the interval slab by slab: x, then y, then z */
    tmin = max(slab_in.x, tmin);
    tmax = min(slab_out.x, tmax);
    tmin = max(slab_in.y, tmin);
    tmax = min(slab_out.y, tmax);
    tmin = max(slab_in.z, tmin);
    tmax = min(slab_out.z, tmax);

    bool not_empty = (tmax >= tmin);
    bool not_behind = (tmax >= 0.f);
    return (not_empty & not_behind);
}

/***************************************************************************//**
 * Intersect one ray with one triangle and keep the closest hit.
 *
 * id            index of the triangle record (three uint4 in KdTris)
 * ray           ray to test
 * hit           updated in place only when the triangle is hit inside
 *               [t_near, t_far] and closer than the current hit.t
 * t_near,t_far  ray interval in which hits are accepted
 *
 * A ray with a zero plane denominator (parallel to the triangle plane) is
 * rejected up front. Without that guard a ray lying exactly in the plane gives
 * t = 0 / 0 = NaN, every rejection comparison below is false for NaN, and the
 * NaN would be stored as the closest hit.
/******************************************************************************/
void intersect_ray_tri(uint id, ray_t ray, in out hit_t hit, float t_near, float t_far)
{
    wald_tri_t tri;
    tri.internal0 = KdTris[id * 3];
    perm_t p = tri.get_perm(ray);

    /* Ray parameter of the intersection with the triangle plane */
    float num = (tri.n_d() - p.pos.x - tri.n_u()*p.pos.y - tri.n_v()*p.pos.z);
    float denum = (p.dir.x + tri.n_u()*p.dir.y + tri.n_v()*p.dir.z);
    if (denum == 0.f) return;
    float t = num / denum;
    if ((hit.t <= t) | (t < t_near) | (t > t_far)) return;

    /* The remaining two words are fetched only after the distance test passed */
    tri.internal1 = KdTris[id * 3 + 1];
    tri.internal2 = KdTris[id * 3 + 2];

    /* Barycentric test in the permuted y / z plane */
    float hu = p.pos.y + t*p.dir.y - tri.vert_ku();
    float hv = p.pos.z + t*p.dir.z - tri.vert_kv();
    float beta = hv*tri.b_nu() + hu*tri.b_nv();
    float gamma = hu*tri.c_nu() + hv*tri.c_nv();
    if ((beta < 0.f) | (gamma < 0.f) | (beta + gamma > 1.f)) return;

    hit.t = t;
    hit.u = beta;
    hit.v = gamma;
    hit.id = tri.id();
}

/***************************************************************************//**
 * Occlusion test between one ray and one triangle.
 *
 * Same plane-distance and barycentric test as intersect_ray_tri, but nothing
 * is recorded: the function only reports whether the triangle is hit with a
 * ray parameter inside [t_near, t_far]. Rays with a zero plane denominator
 * (parallel to the triangle plane) never count as occluded.
 * Written with a single exit and no early returns.
/******************************************************************************/
bool shadow_ray_tri(uint id, ray_t ray, float t_near, float t_far)
{
    wald_tri_t tri;
    tri.internal0 = KdTris[id * 3];
    tri.internal1 = KdTris[id * 3 + 1];
    tri.internal2 = KdTris[id * 3 + 2];
    perm_t p = tri.get_perm(ray);

    /* Ray parameter of the intersection with the triangle plane */
    float num = (tri.n_d() - p.pos.x - tri.n_u()*p.pos.y - tri.n_v()*p.pos.z);
    float denum = (p.dir.x + tri.n_u()*p.dir.y + tri.n_v()*p.dir.z);
    bool crosses_plane = (denum != 0.f);
    /* Divide by 1 for parallel rays so that t stays finite; the result is
       discarded through crosses_plane anyway */
    float t = num / (crosses_plane ? denum : 1.f);
    bool in_range = crosses_plane && (t >= t_near) && (t <= t_far);

    /* Barycentric test in the permuted y / z plane */
    float hu = p.pos.y + t*p.dir.y - tri.vert_ku();
    float hv = p.pos.z + t*p.dir.z - tri.vert_kv();
    float beta = hv*tri.b_nu() + hu*tri.b_nv();
    float gamma = hu*tri.c_nu() + hv*tri.c_nv();
    bool inside = (beta >= 0.f) && (gamma >= 0.f) && (beta + gamma <= 1.f);

    return in_range && inside;
}

/***************************************************************************//**
 * Find the closest triangle hit by the ray.
 * The ray is clipped against the scene box first; then the kd-tree is
 * walked with a small stack of deferred nodes. The walk stops as soon as a
 * hit lies inside the interval of the current leaf or the stack runs empty.
 * hit is only modified when a closer intersection is found.
/******************************************************************************/
void trace_kdtree(ray_t ray, in out hit_t hit)
{
    /* Clip the ray against the bounding box of the scene */
    float t_near = 0.f;
    float t_far = HUGE;
    if (!intersect_ray_box(aabb, ray, t_near, t_far))
        return;

    /* Start the walk at the root node */
    trace_stack_t stack;
    stack.reset();
    kdtree::node_t node;
    node.internal = KdNodes[0];

    for (;;)
    {
        /* Descend until a leaf is reached */
        while (node.is_node())
        {
            uint packed = node.dim_offset_flag();
            uint axis = packed & 3u;
            uint first_child = (packed & (uint) mask_children) >> 3;

            /* Ray parameter at which the split plane is crossed */
            float2 pos_dir = ray.get_dir_pos(axis);
            float axis_dir = pos_dir.y;
            float t_split = (node.split_coord() - pos_dir.x) / axis_dir;

            /* The child entered first depends on the sign of the direction */
            uint forward = axis_dir >= 0.f;
            uint near_child = first_child + (forward ^ 1);
            uint far_child = first_child + (forward ^ 0);

            uint next = near_child;
            if (t_split < t_near)
            {
                /* The plane is crossed before the interval: far side only */
                next = far_child;
            }
            else if (t_split <= t_far)
            {
                /* The plane is crossed inside the interval: take the near
                   side now and keep the far side for later */
                trace_t deferred;
                deferred.node = far_child;
                deferred.tnear = t_split;
                deferred.tfar = t_far;
                stack.push(deferred);
                t_far = t_split;
            }
            node.internal = KdNodes[next];
        }

        /* Test every triangle referenced by the leaf */
        uint list_base = node.get_list();
        uint tri_total = node.tri_count();

        uint i = 0;
        for (; i < tri_total; ++i)
        {
            uint tri_index = KdIds[list_base + i];
            intersect_ray_tri(tri_index, ray, hit, t_near, t_far);
            // cannot make a return here - compiler bug
        }

        /* Resume at a deferred node unless a hit was found inside this leaf
           or nothing is left. pop() is evaluated in every case. */
        int has_deferred = stack.pop();
        if (has_deferred & (hit.t > t_far))
        {
            trace_t resume = stack.get();
            node.internal = KdNodes[resume.node];
            t_near = resume.tnear;
            t_far = resume.tfar;
        }
        else
        {
            break;
        }
    }
}

/***************************************************************************//**
 * Occlusion query: returns true when any triangle is hit by the ray inside
 * [t_min, t_max], false otherwise (including when the ray misses the scene
 * box).
 *
 * The traversal is the same as in trace_kdtree. The occlusion result of a
 * leaf is reported BEFORE the stack is popped: popping first drops out of the
 * loop when the stack is empty, which loses an occluder found in the last
 * leaf visited.
/******************************************************************************/
bool shadow_kdtree(ray_t ray, float t_min, float t_max)
{
    /* First, intersect the bounding box of the scene */
    float t_near = t_min, t_far = t_max;
    if(!intersect_ray_box(aabb, ray, t_near, t_far)) return false;

    /* Then, intersect the kd-tree */
    trace_stack_t stack;
    stack.reset();
    kdtree::node_t node;
    node.internal = KdNodes[0];

    for(;;) 
    {
        /* Process the non-leaf nodes */
        while(node.is_node()) 
        {
            uint bits = node.dim_offset_flag();
            uint axis = bits & 3u;
            uint off = (bits & (uint) mask_children) >> 3;
            float split = node.split_coord();
            float2 pos_dir = ray.get_dir_pos(axis);
            float dir = pos_dir.y;
            /* Ray parameter at which the split plane is crossed */
            float d = (split - pos_dir.x) / dir;
            uint sign = dir >= 0.f;
            /* Default: only the child on the near side of the plane */
            uint idx = off + (sign^1);
            if (d < t_near) 
                /* Plane crossed before the interval: far side only */
                idx = off + (sign^0);
            else 
                if(d <= t_far) 
                {
                    /* Plane crossed inside the interval: defer the far side */
                    trace_t trace;
                    trace.node = off + (sign^0);
                    trace.tnear = d;
                    trace.tfar = t_far;
                    stack.push(trace);
                    t_far = d;
                }
            node.internal = KdNodes[idx];
        }

        /* Test the triangles of the leaf, stopping at the first occluder */
        uint idx = node.get_list(), count = node.tri_count();

        uint i = 0;
        for (; i < count; ++i) 
        {
            uint id = KdIds[idx + i];
            if(shadow_ray_tri(id, ray, t_near, t_far))
                break;
        }

        /* The loop stopped early: an occluder was found */
        if(i < count)
            return true;

        /* Otherwise resume at a deferred node, or finish when none is left */
        if(stack.pop()) 
        {
            trace_t trace = stack.get();
            node.internal = KdNodes[trace.node];
            t_near = trace.tnear;
            t_far = trace.tfar;
        } else 
            break;
    }
    return false;
}

// Kernel: one primary ray per thread, shaded by the light and attenuated by a
// shadow query. Pixels without a hit are written as 0.
[numthreads(1, 1, 1)]
void CSRayTraceShadowed(uint3 DTid : SV_DispatchThreadID)
{
    // Build the primary ray through pixel (DTid.x, DTid.y)
    float px = float(DTid.x);
    float py = float(DTid.y);
    float3 through_pixel = cam.screen_sampler.top + cam.screen_sampler.dx * px + cam.screen_sampler.dy * py;
    ray_t ray;
    ray.set_dir(normalize(through_pixel));
    ray.set_pos(cam.get_eye());

    // Closest hit along the primary ray
    hit_t hit;
    hit.create();
    trace_kdtree(ray, hit);

    if (hit.id == ~0u)
    {
        RtImage[DTid.y * windowWidth + DTid.x] = 0.0f;
    }
    else
    {
        float3 n = KdNormals[hit.id];

        // Vector from the light to the hit point, which is pushed slightly
        // along the normal first
        float3 to_surface = ray.pos + ray.dir * hit.t + n * 1e-3f - lightPos;
        float dist = length(to_surface);
        float3 light_dir = to_surface / dist;

        // Shadow ray from the light towards the hit point
        ray_t sray;
        sray.set_pos(lightPos);
        sray.set_dir(light_dir);
        bool shadowed = shadow_kdtree(sray, 0.f, dist);

        // Occluded pixels keep 30% of the light; falloff is 1 / distance^2
        float visibility = shadowed ? 0.3f : 1.0f;
        RtImage[DTid.y * windowWidth + DTid.x] = visibility * saturate(-dot(n, light_dir)) / (dist*dist);
    }
}

// Kernel: one primary ray per thread, shaded by the light without any shadow
// query. Pixels without a hit are written as 0.
[numthreads(1, 1, 1)]
void CSRayTraceUnshadowed(uint3 DTid : SV_DispatchThreadID)
{
    // Build the primary ray through pixel (DTid.x, DTid.y)
    float px = float(DTid.x);
    float py = float(DTid.y);
    float3 through_pixel = cam.screen_sampler.top + cam.screen_sampler.dx * px + cam.screen_sampler.dy * py;
    ray_t ray;
    ray.set_dir(normalize(through_pixel));
    ray.set_pos(cam.get_eye());

    // Closest hit along the primary ray
    hit_t hit;
    hit.create();
    trace_kdtree(ray, hit);

    if (hit.id == ~0u)
    {
        RtImage[DTid.y * windowWidth + DTid.x] = 0.0f;
    }
    else
    {
        float3 n = KdNormals[hit.id];

        // Vector from the light to the hit point, which is pushed slightly
        // along the normal first
        float3 to_surface = ray.pos + ray.dir * hit.t + n * 1e-3f - lightPos;
        float dist = length(to_surface);
        float3 light_dir = to_surface / dist;

        // Cosine term with 1 / distance^2 falloff
        RtImage[DTid.y * windowWidth + DTid.x] = saturate(-dot(n, light_dir)) / (dist*dist);
    }
}

// Kernel: one primary ray per thread, coloured only by hit distance
// (d = 1.5 / t, scaled per channel). Pixels without a hit are written as 0.
[numthreads(1, 1, 1)]
void CSRayTraceNoshading(uint3 DTid : SV_DispatchThreadID)
{
    // Build the primary ray through pixel (DTid.x, DTid.y)
    float px = float(DTid.x);
    float py = float(DTid.y);
    float3 through_pixel = cam.screen_sampler.top + cam.screen_sampler.dx * px + cam.screen_sampler.dy * py;
    ray_t ray;
    ray.set_dir(normalize(through_pixel));
    ray.set_pos(cam.get_eye());

    // Closest hit along the primary ray
    hit_t hit;
    hit.create();
    trace_kdtree(ray, hit);

    float4 color = 0.0f;
    if (hit.id != ~0u)
    {
        float d = 1.5f / hit.t;
        color = float4(d * 0.25f, d * 0.5f, d, 1.0f);
    }
    RtImage[DTid.y * windowWidth + DTid.x] = color;
}