/*
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License version 2
 *   as published by the Free Software Foundation.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *   Copyright (C) 2008  Benjamin Segovia <segovia.benjamin@gmail.com>
 */

#ifdef _WIN32
    #define WINDOWS_LEAN_AND_MEAN
    #define NOMINMAX
    #include <windows.h>
#pragma warning(disable:4996)
#endif

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <assert.h>

#include "specifics.h"
#include "sys_clock.h"
#include "sys_log.h"
#include "sys_map.h"
#include "bvhlib.h"
#include "kdlib.h"
#include "kdlib_blockifier.h"
#include "rt_kdtree.h"
#include "rt_camera.h"
#include "sys_mem.h"

#include <d3d11.h>
#include <d3dcompiler.h>
#include <d3dx11.h>

#include "DXUT.h"

// Compute shaders, one per display mode (see display_mode_t); created in
// OnD3D11CreateDevice from Shader\RayTrace.hlsl and released in OnD3D11DestroyDevice.
ID3D11ComputeShader*        g_pCSNoShading = NULL;
ID3D11ComputeShader*        g_pCSUnshadowed = NULL;
ID3D11ComputeShader*        g_pCSShadowed = NULL;

// Read-only scene data uploaded once as structured buffers: kd-tree nodes,
// kd-tree triangle ids, precomputed intersection triangles (app_kd_tree.acc)
// and per-triangle normals.
ID3D11Buffer*               g_pBufKdNodes = NULL;
ID3D11Buffer*               g_pBufKdIds = NULL;
ID3D11Buffer*               g_pBufRtTris = NULL;
ID3D11Buffer*               g_pBufNormals = NULL;

// Shader resource views over the four scene buffers above.
ID3D11ShaderResourceView*   g_pBufKdNodesSRV = NULL;
ID3D11ShaderResourceView*   g_pBufKdIdsSRV = NULL;
ID3D11ShaderResourceView*   g_pBufRtTrisSRV = NULL;
ID3D11ShaderResourceView*   g_pBufNormalsSRV = NULL;

// Non-owning copies of the four SRVs above (no AddRef), in the order they are
// bound to compute-shader slots 0..3.
ID3D11ShaderResourceView*	g_pSRVs[4];

// Dynamic constant buffer of the ray-tracing compute shaders (176 bytes).
ID3D11Buffer*               g_pBufConstant = NULL;

// Output buffer: four floats per back-buffer pixel. The compute shader writes it
// through the UAV and the screen-quad pixel shader reads it through the SRV.
// Recreated whenever the swap chain is resized.
ID3D11Buffer*               g_pBufDest = NULL;
ID3D11ShaderResourceView*   g_pBufDestSRV = NULL;
ID3D11UnorderedAccessView*	g_pBufDestUAV = NULL;

// Full-screen quad used to display the output buffer: vertex buffer, input
// layout, shaders, sampler and a small pixel-shader constant buffer.
ID3D11Buffer*               g_pBufQuad = NULL;
ID3D11InputLayout*          g_pLayoutQuad = NULL;
ID3D11VertexShader*         g_pVSQuad = NULL;
ID3D11PixelShader*          g_pPSQuad = NULL;
ID3D11SamplerState*         g_pSampleStateLinear = NULL;
ID3D11Buffer*               g_pBufConstantQuad = NULL;

// Depth-stencil state with depth and stencil tests disabled; set at the start
// of every frame before the screen quad is drawn.
ID3D11DepthStencilState*	g_pDss;

/* Vertex of the full-screen quad. The layout must match `quadlayout` in
   OnD3D11CreateDevice: POSITION (4 floats) at byte offset 0, TEXCOORD
   (2 floats) at byte offset 16. */
struct SCREEN_VERTEX
{
    D3DXVECTOR4 pos;   // position; the quad corners use x, y = +/-1, z = 0.5, w = 1
    D3DXVECTOR2 tex;   // texture coordinate in [0, 1]
};

/* Viewport covering the whole back buffer; filled in OnD3D11ResizedSwapChain */
D3D11_VIEWPORT g_VP;

/* The camera */
rt::camera_t cam;

/* The display modes, selected with keys 1, 2 and 3 in OnFrameMove; each one
   runs a different compute shader */
enum display_mode_t 
{
	shadowed = 0, 
	unshadowed, 
	noshading 
};

/* The current display mode (holds a display_mode_t value) */
uint32_t display_mode = unshadowed;

/* The light position, animated every frame in OnFrameMove */
float lpos[3];

/* The normals of the triangles: one normalized vector per triangle, filled by
   scene_compile */
std::vector<vec_t> normals;

/* The bounding box of the scene, computed by kdlib_compile in scene_compile */
aabb_t aabb;

/* The initial camera position (moved with W/A/S/D/R/F) and the horizontal
   field of view passed to cam.set_fovx (presumably in degrees) */
vec_t eye(0.f, 0.5f, 2.4f);
float fov = 90.f;

/* The kd-tree we want to intersect on the GPU */
kdtree::descriptor_t app_kd_tree;

/* Application body, defined at the end of this file */
extern void run(int argc, char** argv);

/*******************************************************************************
 * Compute the accelerated ("projected") triangle used by the intersector.
 *
 * With N = (C - A) x (B - A), k is the axis on which |N| is largest and (u, v)
 * are the two other axes. The plane coefficients are divided by N[k] and the
 * projected edge coefficients by the 2D determinant of the projected edges.
 *
 * @param t       source triangle (vertices A, B, C)
 * @param w       record to fill
 * @param id      triangle index stored in the record
 * @param mat_id  material index stored in the record
 * @return true when the triangle is degenerate (N[k] == 0 or the determinant
 *         is 0); the stored coefficients may then be inf/NaN.
 ******************************************************************************/
static bool_t compute_tri_acc(const triangle_t &t, rt::wald_tri_t &w, const uint_t id, const uint_t mat_id)
{
    const vec_t &A(t.verts[0]), &B(t.verts[1]), &C(t.verts[2]);
    const vec_t edge_ac(C - A), edge_ab(B - A);
    const vec_t N(edge_ac.cross(edge_ab));

    /* Dominant axis of the geometric normal */
    uint_t k = 0;
    for (uint_t axis = 1; axis < 3; ++axis)
        if (fabsf(N[axis]) > fabsf(N[k]))
            k = axis;
    const uint_t u = (k + 1) % 3;
    const uint_t v = (k + 2) % 3;

    const float norm_k = N[k];
    const float det    = (edge_ac[u]*edge_ab[v] - edge_ac[v]*edge_ab[u]);

    /* Plane equation normalized by N[k] */
    const float plane_u = N[u] / norm_k;
    const float plane_v = N[v] / norm_k;
    const float plane_d = N.dot(A) / norm_k;

    /* Projected edges normalized by the 2D determinant */
    const float ac_u =  edge_ac[u] / det;
    const float ac_v = -edge_ac[v] / det;
    const float ab_u =  edge_ab[v] / det;
    const float ab_v = -edge_ab[u] / det;

    w.k       = k;
    w.n_u     = float(plane_u);
    w.n_v     = float(plane_v);
    w.n_d     = float(plane_d);
    w.vert_ku = float(A[u]);
    w.vert_kv = float(A[v]);
    w.b_nu    = float(ac_u);
    w.b_nv    = float(ac_v);
    w.c_nu    = float(ab_u);
    w.c_nv    = float(ab_v);
    w.id      = id;
    w.matid   = mat_id;

    return (norm_k == 0.) | (det == 0.);
}

/*******************************************************************************
 * Compile the custom triangles used for the intersections: fill acc[i] from
 * tri[i] for every i in [0, tri_n) (material id 0) and log how many of them
 * are degenerate.
 ******************************************************************************/
static NOINLINE void bake_intersection( const triangle_t * __restrict const tri, const uint_t tri_n, rt::wald_tri_t * __restrict acc)
{
    int degenerated = 0;
    uint_t tid = 0;
    while (tid < tri_n) {
        if (compute_tri_acc(tri[tid], acc[tid], tid, 0))
            ++degenerated;
        ++tid;
    }
    sys::log("bake_intersection: %d triangles, %d degenerated.\n", tri_n, degenerated);
}

/*******************************************************************************
 * Load the geometry and compile the kd-tree and the custom triangles.
 *
 * The file is memory-mapped as a flat array of triangle_t. The kd-tree is built
 * (which also fills the global `aabb`), the intersection triangles are baked, and
 * the tree is then re-laid out in blocks. The unblocked arrays are freed and
 * replaced by the blocked tree. Finally the global `normals` receives one
 * normalized normal per triangle.
 *
 * @return the number of triangles in the file.
 ******************************************************************************/
static NOINLINE int scene_compile(kdtree::descriptor_t &kd_tree, const char *file)
{
    sys::log("loading %s...\n", file);

    /* Map the triangle soup */
    sys::map_t scene_map;
    scene_map.open(file);
    if (!scene_map.is_mapped())
        fatal("failed to mmap scene data.");
    const triangle_t * __restrict const soup = (const triangle_t * __restrict const) scene_map.begin();
    const uint32_t tri_n = scene_map.get_size<triangle_t>();

    /* Build the kd-tree, then the accelerated triangles */
    kdlib_compile(soup, tri_n, kd_tree, aabb);
    rt::wald_tri_t *baked = (rt::wald_tri_t *) sys::mem::allocate(tri_n * sizeof(rt::wald_tri_t));
    kd_tree.acc = baked;
    bake_intersection(soup, tri_n, baked);

    /* Blockify the kd-tree and drop the original (unblocked) arrays */
    kdtree::descriptor_t blocked;
    kdlib::do_blockify(blocked, kd_tree, 4);
    sys::mem::liberate((void *) kd_tree.acc);
    sys::mem::liberate((void *) kd_tree.ids);
    sys::mem::liberate((void *) kd_tree.root);
    kd_tree = blocked;

    /* One normalized normal per triangle: (B - A) x (C - A) */
    normals.resize(tri_n);
    for (uint32_t tri = 0; tri < tri_n; ++tri) {
        const triangle_t &t = soup[tri];
        const vec_t edge_ac = t.verts[2] - t.verts[0];
        const vec_t edge_ab = t.verts[1] - t.verts[0];
        vec_t n = edge_ab.cross(edge_ac);
        normals[tri] = n.normalize();
    }
    return tri_n;
}

/*******************************************************************************
 * Main program: all the work (scene loading, window and device creation, render
 * loop) is done by run(); the process always reports EXIT_SUCCESS.
 ******************************************************************************/
int main(int argc, char** argv)
{
    run(argc, argv);
    return EXIT_SUCCESS;
}


//--------------------------------------------------------------------------------------
// Compile the compute-shader entry point `pFunctionName` of the HLSL file `pSrcFile`
// for profile `pProfile` and create it on `pDevice`.
// On success *ppShaderOut receives the new shader; the caller owns that reference.
// Compiler diagnostics are written to the log. Returns the first failing HRESULT.
//--------------------------------------------------------------------------------------
HRESULT CreateComputeShader( LPCWSTR pSrcFile, LPCSTR pFunctionName, LPCSTR pProfile,
                             ID3D11Device* pDevice, ID3D11ComputeShader** ppShaderOut )
{
    ID3DBlob* pErrorBlob = NULL;
    ID3DBlob* pBlob = NULL;

    // For debugging, use D3D10_SHADER_SKIP_OPTIMIZATION | D3D10_SHADER_DEBUG instead.
    HRESULT hr = D3DX11CompileFromFile( pSrcFile, NULL, NULL, pFunctionName, pProfile, D3D10_SHADER_OPTIMIZATION_LEVEL3, NULL, NULL, &pBlob, &pErrorBlob, NULL );
    if ( FAILED(hr) )
    {
        // The compiler output is arbitrary text (it may contain '%'), so never use it
        // as a format string.
        if ( pErrorBlob )
            sys::log( "%s", (const char*)pErrorBlob->GetBufferPointer() );
        else
            sys::log( "CreateComputeShader: compilation of %s failed (hr = 0x%08lx).\n", pFunctionName, (unsigned long)hr );

        SAFE_RELEASE( pErrorBlob );
        SAFE_RELEASE( pBlob );
        return hr;
    }

    hr = pDevice->CreateComputeShader( pBlob->GetBufferPointer(), pBlob->GetBufferSize(), NULL, ppShaderOut );

    SAFE_RELEASE( pErrorBlob );
    SAFE_RELEASE( pBlob );
    return hr;
}

//--------------------------------------------------------------------------------------
// Compile the vertex-shader entry point `pFunctionName` of `pSrcFile` for profile
// `pProfile`, then create the shader and an input layout built from `layout`
// (`numElements` entries) and validated against the compiled shader's bytecode.
// Either both objects are created, or neither: if the input layout fails, the freshly
// created shader is released and *ppShaderOut reset.
// Compiler diagnostics are written to the log. Returns the first failing HRESULT.
//--------------------------------------------------------------------------------------
HRESULT CreateVertexShader( LPCWSTR pSrcFile, LPCSTR pFunctionName, LPCSTR pProfile,
                             ID3D11Device* pDevice, D3D11_INPUT_ELEMENT_DESC *layout, UINT numElements, ID3D11VertexShader** ppShaderOut, ID3D11InputLayout** ppLayoutOut )
{
    ID3DBlob* pErrorBlob = NULL;
    ID3DBlob* pBlob = NULL;

    HRESULT hr = D3DX11CompileFromFile( pSrcFile, NULL, NULL, pFunctionName, pProfile, D3D10_SHADER_ENABLE_STRICTNESS, NULL, NULL, &pBlob, &pErrorBlob, NULL );
    if ( FAILED(hr) )
    {
        // The compiler output is arbitrary text (it may contain '%'), so never use it
        // as a format string.
        if ( pErrorBlob )
            sys::log( "%s", (const char*)pErrorBlob->GetBufferPointer() );
        else
            sys::log( "CreateVertexShader: compilation of %s failed (hr = 0x%08lx).\n", pFunctionName, (unsigned long)hr );

        SAFE_RELEASE( pErrorBlob );
        SAFE_RELEASE( pBlob );
        return hr;
    }

    hr = pDevice->CreateVertexShader( pBlob->GetBufferPointer(), pBlob->GetBufferSize(), NULL, ppShaderOut );
    if ( SUCCEEDED(hr) )
    {
        // The input layout is validated against the shader input signature, hence the bytecode.
        hr = pDevice->CreateInputLayout( layout, numElements, pBlob->GetBufferPointer(), pBlob->GetBufferSize(), ppLayoutOut );
        if ( FAILED(hr) )
        {
            SAFE_RELEASE( *ppShaderOut );
        }
    }

    SAFE_RELEASE( pErrorBlob );
    SAFE_RELEASE( pBlob );
    return hr;
}

//--------------------------------------------------------------------------------------
// Compile the pixel-shader entry point `pFunctionName` of `pSrcFile` for profile
// `pProfile` and create it on `pDevice`.
// On success *ppShaderOut receives the new shader; the caller owns that reference.
// Compiler diagnostics are written to the log. Returns the first failing HRESULT.
//--------------------------------------------------------------------------------------
HRESULT CreatePixelShader( LPCWSTR pSrcFile, LPCSTR pFunctionName, LPCSTR pProfile,
                             ID3D11Device* pDevice, ID3D11PixelShader** ppShaderOut )
{
    ID3DBlob* pErrorBlob = NULL;
    ID3DBlob* pBlob = NULL;

    HRESULT hr = D3DX11CompileFromFile( pSrcFile, NULL, NULL, pFunctionName, pProfile, D3D10_SHADER_ENABLE_STRICTNESS, NULL, NULL, &pBlob, &pErrorBlob, NULL );
    if ( FAILED(hr) )
    {
        // The compiler output is arbitrary text (it may contain '%'), so never use it
        // as a format string.
        if ( pErrorBlob )
            sys::log( "%s", (const char*)pErrorBlob->GetBufferPointer() );
        else
            sys::log( "CreatePixelShader: compilation of %s failed (hr = 0x%08lx).\n", pFunctionName, (unsigned long)hr );

        SAFE_RELEASE( pErrorBlob );
        SAFE_RELEASE( pBlob );
        return hr;
    }

    hr = pDevice->CreatePixelShader( pBlob->GetBufferPointer(), pBlob->GetBufferSize(), NULL, ppShaderOut );

    SAFE_RELEASE( pErrorBlob );
    SAFE_RELEASE( pBlob );
    return hr;
}

//--------------------------------------------------------------------------------------
// Create a structured buffer bindable both as SRV and UAV (default usage, no CPU access).
//   uElementSize : size in bytes of one element of pInitData
//   uStride      : StructureByteStride of the buffer (may differ from uElementSize)
//   uCount       : number of elements; the buffer is uElementSize * uCount bytes
//   pInitData    : optional initial content (uElementSize * uCount bytes), or NULL
// *ppBufOut is reset to NULL first and receives the buffer on success.
// Returns E_INVALIDARG when uElementSize * uCount does not fit in a UINT.
//--------------------------------------------------------------------------------------
HRESULT CreateStructuredBufferOnGPU( ID3D11Device* pDevice, UINT uElementSize, UINT uStride, UINT uCount, const VOID* pInitData, ID3D11Buffer** ppBufOut )
{
    *ppBufOut = NULL;

    // ByteWidth is a 32-bit UINT: refuse sizes whose product would wrap around.
    if ( uElementSize != 0 && uCount > 0xFFFFFFFFu / uElementSize )
        return E_INVALIDARG;

    D3D11_BUFFER_DESC desc;
    ZeroMemory( &desc, sizeof(desc) );   // Usage = D3D11_USAGE_DEFAULT, CPUAccessFlags = 0
    desc.BindFlags = D3D11_BIND_UNORDERED_ACCESS | D3D11_BIND_SHADER_RESOURCE;
    desc.ByteWidth = uElementSize * uCount;
    desc.MiscFlags = D3D11_RESOURCE_MISC_BUFFER_STRUCTURED;
    desc.StructureByteStride = uStride;

    if ( pInitData == NULL )
        return pDevice->CreateBuffer( &desc, NULL, ppBufOut );

    // The pitch fields are unused for buffers, but must not be left uninitialized.
    D3D11_SUBRESOURCE_DATA InitData;
    ZeroMemory( &InitData, sizeof(InitData) );
    InitData.pSysMem = pInitData;
    return pDevice->CreateBuffer( &desc, &InitData, ppBufOut );
}

//--------------------------------------------------------------------------------------
// Create a BUFFEREX shader resource view covering the whole of `pBuffer`.
// Raw buffers are viewed as R32_TYPELESS words; structured buffers as elements of
// their StructureByteStride. Any other buffer kind yields E_INVALIDARG.
//--------------------------------------------------------------------------------------
HRESULT CreateBufferSRV( ID3D11Device* pDevice, ID3D11Buffer* pBuffer, ID3D11ShaderResourceView** ppSRVOut )
{
    D3D11_BUFFER_DESC bufDesc;
    ZeroMemory( &bufDesc, sizeof(bufDesc) );
    pBuffer->GetDesc( &bufDesc );

    const bool bRaw        = ( bufDesc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_ALLOW_RAW_VIEWS ) != 0;
    const bool bStructured = ( bufDesc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_STRUCTURED ) != 0;
    if ( !bRaw && !bStructured )
        return E_INVALIDARG;

    D3D11_SHADER_RESOURCE_VIEW_DESC srvDesc;
    ZeroMemory( &srvDesc, sizeof(srvDesc) );
    srvDesc.ViewDimension = D3D11_SRV_DIMENSION_BUFFEREX;
    srvDesc.BufferEx.FirstElement = 0;

    if ( bRaw )
    {
        // Raw buffer: addressed as 32-bit words
        srvDesc.Format = DXGI_FORMAT_R32_TYPELESS;
        srvDesc.BufferEx.Flags = D3D11_BUFFEREX_SRV_FLAG_RAW;
        srvDesc.BufferEx.NumElements = bufDesc.ByteWidth / 4;
    }
    else
    {
        // Structured buffer: the element layout comes from the stride
        srvDesc.Format = DXGI_FORMAT_UNKNOWN;
        srvDesc.BufferEx.NumElements = bufDesc.ByteWidth / bufDesc.StructureByteStride;
    }

    return pDevice->CreateShaderResourceView( pBuffer, &srvDesc, ppSRVOut );
}
//--------------------------------------------------------------------------------------
// DXUT device filter: accept a device only if it supports compute shaders with raw
// and structured buffers through shader model 4.x; all other inputs are ignored.
//--------------------------------------------------------------------------------------
bool CALLBACK IsD3D11DeviceAcceptable( const CD3D11EnumAdapterInfo *AdapterInfo, UINT Output, const CD3D11EnumDeviceInfo *DeviceInfo,
                                       DXGI_FORMAT BackBufferFormat, bool bWindowed, void* pUserContext )
{
    const bool bHasCompute4x = ( DeviceInfo->ComputeShaders_Plus_RawAndStructuredBuffers_Via_Shader_4_x != FALSE );
    return bHasCompute4x;
}

//--------------------------------------------------------------------------------------
// (Re)create the resources sized after the back buffer: the full-window viewport and
// the per-pixel output buffer (four floats per pixel) with its SRV and UAV.
// Returns the first failing HRESULT, S_OK otherwise.
//--------------------------------------------------------------------------------------
HRESULT CALLBACK OnD3D11ResizedSwapChain( ID3D11Device* pd3dDevice, IDXGISwapChain* pSwapChain,
                                          const DXGI_SURFACE_DESC* pBackBufferSurfaceDesc, void* pUserContext )
{
    const UINT uWidth      = pBackBufferSurfaceDesc->Width;
    const UINT uHeight     = pBackBufferSurfaceDesc->Height;
    const UINT uPixelCount = uWidth * uHeight;

    // Viewport covering the whole back buffer
    g_VP.TopLeftX = 0;
    g_VP.TopLeftY = 0;
    g_VP.Width    = (float)uWidth;
    g_VP.Height   = (float)uHeight;
    g_VP.MinDepth = 0.0f;
    g_VP.MaxDepth = 1.0f;

    // Output buffer: one float4 element per pixel
    const UINT uPixelBytes = sizeof(float) * 4;
    HRESULT hr = CreateStructuredBufferOnGPU(pd3dDevice, uPixelBytes, uPixelBytes, uPixelCount, NULL, &g_pBufDest);
    if (FAILED(hr))
        return hr;

    hr = CreateBufferSRV(pd3dDevice, g_pBufDest, &g_pBufDestSRV);
    if (FAILED(hr))
        return hr;

    // UAV over the same buffer (structured, so the format comes from the stride)
    D3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc;
    ZeroMemory( &uavDesc, sizeof(uavDesc) );
    uavDesc.Format              = DXGI_FORMAT_UNKNOWN;
    uavDesc.ViewDimension       = D3D11_UAV_DIMENSION_BUFFER;
    uavDesc.Buffer.FirstElement = 0;
    uavDesc.Buffer.NumElements  = uPixelCount;
    hr = pd3dDevice->CreateUnorderedAccessView( g_pBufDest, &uavDesc, &g_pBufDestUAV );

    return FAILED(hr) ? hr : S_OK;
}

//--------------------------------------------------------------------------------------
// Release the D3D11 resources created in OnD3D11ResizedSwapChain: the per-pixel
// output buffer and its two views (the views are released before their buffer).
//--------------------------------------------------------------------------------------
void CALLBACK OnD3D11ReleasingSwapChain( void* pUserContext )
{
	SAFE_RELEASE(g_pBufDestSRV);
	SAFE_RELEASE(g_pBufDestUAV);
	SAFE_RELEASE(g_pBufDest);
}

//--------------------------------------------------------------------------------------
// Run a compute shader:
//  - optionally upload dwNumDataBytes of pCSData into pCBCS (only when both are
//    given; pCBCS must then be a dynamic, CPU-writable buffer),
//  - bind the shader, nNumViews SRVs at slots 0.., the UAV at slot 0 and pCBCS at
//    constant-buffer slot 0,
//  - dispatch X*Y*Z thread groups,
//  - unbind everything bound above, so the resources can be used by other stages.
//--------------------------------------------------------------------------------------
void RunComputeShader( ID3D11DeviceContext* pd3dImmediateContext,
                      ID3D11ComputeShader* pComputeShader,
                      UINT nNumViews, ID3D11ShaderResourceView** pShaderResourceViews, 
                      ID3D11Buffer* pCBCS, void* pCSData, DWORD dwNumDataBytes,
                      ID3D11UnorderedAccessView* pUnorderedAccessView,
                      UINT X, UINT Y, UINT Z )
{
    if ( pCBCS && pCSData )
    {
        D3D11_MAPPED_SUBRESOURCE MappedResource;
        if ( SUCCEEDED( pd3dImmediateContext->Map( pCBCS, 0, D3D11_MAP_WRITE_DISCARD, 0, &MappedResource ) ) )
        {
            memcpy( MappedResource.pData, pCSData, dwNumDataBytes );
            pd3dImmediateContext->Unmap( pCBCS, 0 );
        }
    }

    // Initial counts only matter for append/consume and counter UAVs; -1 keeps the
    // current counter value.
    UINT uInitialCount = (UINT)-1;

    pd3dImmediateContext->CSSetShader( pComputeShader, NULL, 0 );
    pd3dImmediateContext->CSSetShaderResources( 0, nNumViews, pShaderResourceViews );
    pd3dImmediateContext->CSSetUnorderedAccessViews( 0, 1, &pUnorderedAccessView, &uInitialCount );
    if ( pCBCS )
    {
        ID3D11Buffer* ppCB[1] = { pCBCS };
        pd3dImmediateContext->CSSetConstantBuffers( 0, 1, ppCB );
    }

    pd3dImmediateContext->Dispatch( X, Y, Z );

    // Unbind the UAV so the buffer can be read through an SRV afterwards.
    ID3D11UnorderedAccessView* ppUAViewNULL[1] = { NULL };
    pd3dImmediateContext->CSSetUnorderedAccessViews( 0, 1, ppUAViewNULL, &uInitialCount );

    // Unbind every SRV slot bound above (not just a fixed number of them).
    ID3D11ShaderResourceView* ppSRVNULL[D3D11_COMMONSHADER_INPUT_RESOURCE_SLOT_COUNT] = { NULL };
    const UINT nViewsToClear = nNumViews < D3D11_COMMONSHADER_INPUT_RESOURCE_SLOT_COUNT
                             ? nNumViews : D3D11_COMMONSHADER_INPUT_RESOURCE_SLOT_COUNT;
    pd3dImmediateContext->CSSetShaderResources( 0, nViewsToClear, ppSRVNULL );

    // CSSetConstantBuffers(0, 0, ...) would be a no-op: clear slot 0 explicitly.
    if ( pCBCS )
    {
        ID3D11Buffer* ppCBNULL[1] = { NULL };
        pd3dImmediateContext->CSSetConstantBuffers( 0, 1, ppCBNULL );
    }
}

//--------------------------------------------------------------------------------------
// Create the D3D11 resources that do not depend on the back buffer:
//  - the three ray-tracing compute shaders,
//  - the compute-shader and screen-quad constant buffers,
//  - the scene buffers (kd-tree nodes, kd-tree ids, intersection triangles, normals)
//    uploaded from app_kd_tree and normals, with their SRVs (aliased in g_pSRVs),
//  - the screen quad (shaders, input layout, vertex buffer, sampler) and a
//    depth-stencil state with depth and stencil tests disabled.
// scene_compile() must have filled app_kd_tree and normals beforehand.
// Returns the first failing HRESULT.
//--------------------------------------------------------------------------------------
HRESULT CALLBACK OnD3D11CreateDevice( ID3D11Device* pd3dDevice, const DXGI_SURFACE_DESC* pBackBufferSurfaceDesc,
                                      void* pUserContext )
{
    // Ray-tracing compute shaders, one per display mode
    HRESULT res = CreateComputeShader(L"Shader\\RayTrace.hlsl", "CSRayTraceNoshading", "cs_4_0", pd3dDevice, &g_pCSNoShading);
    if(FAILED(res)) return res;

    res = CreateComputeShader(L"Shader\\RayTrace.hlsl", "CSRayTraceUnshadowed", "cs_4_0", pd3dDevice, &g_pCSUnshadowed);
    if(FAILED(res)) return res;

    res = CreateComputeShader(L"Shader\\RayTrace.hlsl", "CSRayTraceShadowed", "cs_4_0", pd3dDevice, &g_pCSShadowed);
    if(FAILED(res)) return res;

    // Dynamic, CPU-writable constant buffers (rewritten every frame with WRITE_DISCARD)
    D3D11_BUFFER_DESC Desc;
    ZeroMemory( &Desc, sizeof(Desc) );   // also clears StructureByteStride
    Desc.Usage = D3D11_USAGE_DYNAMIC;
    Desc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
    Desc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
    Desc.MiscFlags = 0;

    // Compute-shader constants; the size comes from the cbuffer of RayTrace.hlsl
    Desc.ByteWidth = 176;
    res = pd3dDevice->CreateBuffer(&Desc, NULL, &g_pBufConstant );
    if(FAILED(res)) return res;

    // Screen-quad pixel-shader constants (one 16-byte register; constant buffer
    // sizes must be multiples of 16)
    Desc.ByteWidth = 16;
    res = pd3dDevice->CreateBuffer(&Desc, NULL, &g_pBufConstantQuad );
    if(FAILED(res)) return res;

    // The scene buffers need at least one triangle (&normals[0] is invalid otherwise)
    if (normals.empty()) return E_FAIL;

    // Upload the scene. The intersection triangles are viewed with a 16-byte stride,
    // independently of sizeof(rt::wald_tri_t).
    res = CreateStructuredBufferOnGPU(pd3dDevice, sizeof(app_kd_tree.root[0]), sizeof(app_kd_tree.root[0]), app_kd_tree.node_n, &app_kd_tree.root[0], &g_pBufKdNodes);
    if(FAILED(res)) return res;
    res = CreateStructuredBufferOnGPU(pd3dDevice, sizeof(app_kd_tree.ids[0]), sizeof(app_kd_tree.ids[0]), app_kd_tree.id_n, &app_kd_tree.ids[0], &g_pBufKdIds);
    if(FAILED(res)) return res;
    res = CreateStructuredBufferOnGPU(pd3dDevice, sizeof(app_kd_tree.acc[0]), sizeof(UINT) * 4, app_kd_tree.tri_n, &app_kd_tree.acc[0], &g_pBufRtTris);
    if(FAILED(res)) return res;
    res = CreateStructuredBufferOnGPU(pd3dDevice, sizeof(normals[0]), sizeof(normals[0]), app_kd_tree.tri_n, &normals[0], &g_pBufNormals);
    if(FAILED(res)) return res;

    res = CreateBufferSRV(pd3dDevice, g_pBufKdNodes, &g_pBufKdNodesSRV);
    if(FAILED(res)) return res;
    res = CreateBufferSRV(pd3dDevice, g_pBufKdIds, &g_pBufKdIdsSRV);
    if(FAILED(res)) return res;
    res = CreateBufferSRV(pd3dDevice, g_pBufRtTris, &g_pBufRtTrisSRV);
    if(FAILED(res)) return res;
    res = CreateBufferSRV(pd3dDevice, g_pBufNormals, &g_pBufNormalsSRV);
    if(FAILED(res)) return res;

    // Compute-shader SRV slots 0..3 (non-owning aliases)
    g_pSRVs[0] = g_pBufKdNodesSRV;
    g_pSRVs[1] = g_pBufKdIdsSRV;
    g_pSRVs[2] = g_pBufRtTrisSRV;
    g_pSRVs[3] = g_pBufNormalsSRV;

    // Screen quad shaders; the layout must match SCREEN_VERTEX
    D3D11_INPUT_ELEMENT_DESC quadlayout[] =
    {
        { "POSITION", 0, DXGI_FORMAT_R32G32B32A32_FLOAT, 0, 0, D3D11_INPUT_PER_VERTEX_DATA, 0 },
        { "TEXCOORD", 0, DXGI_FORMAT_R32G32_FLOAT, 0, 16, D3D11_INPUT_PER_VERTEX_DATA, 0 },
    };
    const UINT numQuadElements = (UINT)(sizeof(quadlayout) / sizeof(quadlayout[0]));

    res = CreateVertexShader(L"Shader\\ScreenQuad.hlsl", "VSQuad", "vs_4_0", pd3dDevice, quadlayout, numQuadElements, &g_pVSQuad, &g_pLayoutQuad);
    if(FAILED(res)) return res;

    res = CreatePixelShader(L"Shader\\ScreenQuad.hlsl", "PSQuad", "ps_4_0", pd3dDevice, &g_pPSQuad);
    if(FAILED(res)) return res;

    // Full-screen quad as a 4-vertex triangle strip
    SCREEN_VERTEX svQuad[4];
    svQuad[0].pos = D3DXVECTOR4( -1.0f, 1.0f, 0.5f, 1.0f );
    svQuad[0].tex = D3DXVECTOR2( 0.0f, 0.0f );
    svQuad[1].pos = D3DXVECTOR4( 1.0f, 1.0f, 0.5f, 1.0f );
    svQuad[1].tex = D3DXVECTOR2( 1.0f, 0.0f );
    svQuad[2].pos = D3DXVECTOR4( -1.0f, -1.0f, 0.5f, 1.0f );
    svQuad[2].tex = D3DXVECTOR2( 0.0f, 1.0f );
    svQuad[3].pos = D3DXVECTOR4( 1.0f, -1.0f, 0.5f, 1.0f );
    svQuad[3].tex = D3DXVECTOR2( 1.0f, 1.0f );

    D3D11_BUFFER_DESC vbdesc =
    {
        4 * sizeof( SCREEN_VERTEX ),
        D3D11_USAGE_DEFAULT,
        D3D11_BIND_VERTEX_BUFFER,
        0,
        0
    };
    D3D11_SUBRESOURCE_DATA InitData;
    InitData.pSysMem = svQuad;
    InitData.SysMemPitch = 0;
    InitData.SysMemSlicePitch = 0;
    res = pd3dDevice->CreateBuffer(&vbdesc, &InitData, &g_pBufQuad);
    if(FAILED(res)) return res;

    D3D11_SAMPLER_DESC SamplerDesc;
    ZeroMemory( &SamplerDesc, sizeof(SamplerDesc) );
    SamplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_CLAMP;
    SamplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_CLAMP;
    SamplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_CLAMP;
    SamplerDesc.Filter = D3D11_FILTER_MIN_MAG_MIP_LINEAR;
    res = pd3dDevice->CreateSamplerState( &SamplerDesc, &g_pSampleStateLinear );
    if(FAILED(res)) return res;

    // Depth and stencil tests disabled. The compare functions and stencil ops are
    // still set to valid values: a zeroed desc leaves them out of range.
    D3D11_DEPTH_STENCIL_DESC dsDesc;
    ZeroMemory(&dsDesc, sizeof(dsDesc));
    dsDesc.DepthEnable = FALSE;
    dsDesc.DepthWriteMask = D3D11_DEPTH_WRITE_MASK_ZERO;
    dsDesc.DepthFunc = D3D11_COMPARISON_ALWAYS;
    dsDesc.StencilEnable = FALSE;
    dsDesc.StencilReadMask = D3D11_DEFAULT_STENCIL_READ_MASK;
    dsDesc.StencilWriteMask = D3D11_DEFAULT_STENCIL_WRITE_MASK;
    dsDesc.FrontFace.StencilFailOp = D3D11_STENCIL_OP_KEEP;
    dsDesc.FrontFace.StencilDepthFailOp = D3D11_STENCIL_OP_KEEP;
    dsDesc.FrontFace.StencilPassOp = D3D11_STENCIL_OP_KEEP;
    dsDesc.FrontFace.StencilFunc = D3D11_COMPARISON_ALWAYS;
    dsDesc.BackFace = dsDesc.FrontFace;

    res = pd3dDevice->CreateDepthStencilState(&dsDesc, &g_pDss);
    return res;
}


//--------------------------------------------------------------------------------------
// Release the D3D11 resources created in OnD3D11CreateDevice. Views are released
// before the buffers they reference. g_pSRVs only holds non-owning copies of the
// scene SRVs, so it is cleared rather than released, leaving no dangling pointers.
//--------------------------------------------------------------------------------------
void CALLBACK OnD3D11DestroyDevice( void* pUserContext )
{
    // Ray-tracing compute shaders and their constant buffer
    SAFE_RELEASE(g_pCSNoShading);
    SAFE_RELEASE(g_pCSUnshadowed);
    SAFE_RELEASE(g_pCSShadowed);

    SAFE_RELEASE(g_pBufConstant);

    // Scene data: views first, then buffers
    SAFE_RELEASE(g_pBufKdNodesSRV);
    SAFE_RELEASE(g_pBufKdIdsSRV);
    SAFE_RELEASE(g_pBufRtTrisSRV);
    SAFE_RELEASE(g_pBufNormalsSRV);
    ZeroMemory(g_pSRVs, sizeof(g_pSRVs));

    SAFE_RELEASE(g_pBufKdNodes);
    SAFE_RELEASE(g_pBufKdIds);
    SAFE_RELEASE(g_pBufRtTris);
    SAFE_RELEASE(g_pBufNormals);

    // Screen quad
    SAFE_RELEASE(g_pBufQuad);
    SAFE_RELEASE(g_pLayoutQuad);

    SAFE_RELEASE(g_pVSQuad);
    SAFE_RELEASE(g_pPSQuad);

    SAFE_RELEASE(g_pBufConstantQuad);

    SAFE_RELEASE(g_pSampleStateLinear);

    SAFE_RELEASE(g_pDss);
}

//--------------------------------------------------------------------------------------
// Per-frame update:
//  - refresh the camera from the current eye position and field of view,
//  - move the light on a circle of radius 0.5 around (1, 2.5, 0) in the XZ plane,
//  - move the eye with W/S (z), A/D (x) and R/F (y); the camera sees the new position
//    on the next frame,
//  - select the display mode with keys 1/2/3 (the highest-numbered held key wins).
//--------------------------------------------------------------------------------------
void CALLBACK OnFrameMove( double fTime, float fElapsedTime, void* pUserContext )
{
    /* Camera */
    cam.open();
    cam.set_eye(eye);
    cam.set_fovx(fov);
    cam.update(point_t((int)g_VP.Width, (int)g_VP.Height));

    /* Light */
    const float t = (float) sys::laps_t::to_time(sys::laps_t::get()) * 0.001f;
    lpos[0] = 1.f + .5f * sinf(t);
    lpos[1] = 2.5f;
    lpos[2] = .5f * cosf(t);

    /* Eye movement: a negative GetKeyState value means the key is down */
    const float step = 0.1f;
    if (GetKeyState('W') < 0) eye.z -= step;
    if (GetKeyState('S') < 0) eye.z += step;
    if (GetKeyState('D') < 0) eye.x += step;
    if (GetKeyState('A') < 0) eye.x -= step;
    if (GetKeyState('R') < 0) eye.y += step;
    if (GetKeyState('F') < 0) eye.y -= step;

    /* Display mode */
    if (GetKeyState('3') < 0)
        display_mode = noshading;
    else if (GetKeyState('2') < 0)
        display_mode = unshadowed;
    else if (GetKeyState('1') < 0)
        display_mode = shadowed;
}


//--------------------------------------------------------------------------------------
// Render one frame:
//  1. fill the compute-shader constant buffer from the camera, scene bounds and light,
//  2. ray trace into the output buffer with the compute shader of the current
//     display mode (Dispatch(width, height, 1) thread groups),
//  3. draw the full-screen quad whose pixel shader reads the output buffer.
// If the compute constant buffer cannot be mapped, the frame is skipped.
//--------------------------------------------------------------------------------------
void CALLBACK OnD3D11FrameRender( ID3D11Device* pd3dDevice, ID3D11DeviceContext* pd3dImmediateContext,
                                  double fTime, float fElapsedTime, void* pUserContext )
{
    pd3dImmediateContext->OMSetDepthStencilState(g_pDss, 0);

    // Fill the 176-byte compute-shader constant buffer. C++ struct packing does not
    // match HLSL cbuffer packing, so every field is copied at an explicit byte offset:
    //   0 eye | 16 dir | 32 up | 48 right.xyz | 60 fovx | 64 world-up index |
    //   80 sampler.top | 96 sampler.dx | 112 sampler.dy | 128 aabb.pmin.xyz |
    //   140 aabb.pmax | 160 light position | 172 viewport width
    D3D11_MAPPED_SUBRESOURCE MappedResource;
    if (FAILED(pd3dImmediateContext->Map(g_pBufConstant, 0, D3D11_MAP_WRITE_DISCARD, 0, &MappedResource)))
        return;
    BYTE* constBuf = ( BYTE* )MappedResource.pData;

    memcpy(constBuf, &cam.get_eye(), sizeof(vec_t));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &cam.get_dir(), sizeof(vec_t));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &cam.get_up(), sizeof(vec_t));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &cam.get_right(), sizeof(vec_t));
    constBuf += sizeof(float) * 3;              // fovx shares the register of right.xyz
    float tempf = cam.get_fovx();
    memcpy(constBuf, &tempf, sizeof(float));
    constBuf += sizeof(float);
    int tempi = cam.get_world_up_index();
    memcpy(constBuf, &tempi, sizeof(int));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &cam.sampler.top, sizeof(vec_t));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &cam.sampler.dx, sizeof(vec_t));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &cam.sampler.dy, sizeof(vec_t));
    constBuf += sizeof(float) * 4;
    memcpy(constBuf, &aabb.pmin, sizeof(vec_t));
    constBuf += sizeof(float) * 3;
    memcpy(constBuf, &aabb.pmax, sizeof(vec_t));
    constBuf += sizeof(float) * 5;
    // lpos is a plain float[3]: copy exactly its size (sizeof(vec_t) may read past it)
    memcpy(constBuf, lpos, sizeof(lpos));
    constBuf += sizeof(float) * 3;
    UINT w = (int)g_VP.Width;
    memcpy(constBuf, &w, sizeof(UINT));

    pd3dImmediateContext->Unmap(g_pBufConstant, 0);

    // Ray trace with the compute shader of the current display mode
    ID3D11ComputeShader* pCS = NULL;
    switch(display_mode)
    {
    case shadowed:   pCS = g_pCSShadowed;   break;
    case unshadowed: pCS = g_pCSUnshadowed; break;
    case noshading:  pCS = g_pCSNoShading;  break;
    }
    if (pCS)
        RunComputeShader(pd3dImmediateContext, pCS, (UINT)(sizeof(g_pSRVs) / sizeof(g_pSRVs[0])), g_pSRVs,
                         g_pBufConstant, NULL, 0, g_pBufDestUAV, (int)g_VP.Width, (int)g_VP.Height, 1);

    // Draw the full-screen quad that displays the output buffer
    UINT strides = sizeof( SCREEN_VERTEX );
    UINT offsets = 0;

    pd3dImmediateContext->RSSetViewports(1, &g_VP);
    pd3dImmediateContext->IASetInputLayout(g_pLayoutQuad);
    pd3dImmediateContext->IASetVertexBuffers( 0, 1, &g_pBufQuad, &strides, &offsets );
    pd3dImmediateContext->IASetPrimitiveTopology( D3D11_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP );

    pd3dImmediateContext->VSSetShader( g_pVSQuad, NULL, 0 );
    pd3dImmediateContext->PSSetShader( g_pPSQuad, NULL, 0 );
    pd3dImmediateContext->PSSetShaderResources(0, 1, &g_pBufDestSRV );
    pd3dImmediateContext->PSSetSamplers(0, 1, &g_pSampleStateLinear);

    // Only the viewport width is uploaded to the quad constant buffer
    if (SUCCEEDED(pd3dImmediateContext->Map(g_pBufConstantQuad, 0, D3D11_MAP_WRITE_DISCARD, 0, &MappedResource)))
    {
        memcpy(MappedResource.pData, &w, sizeof(UINT));
        pd3dImmediateContext->Unmap(g_pBufConstantQuad, 0);
    }
    ID3D11Buffer* ppCB[1] = { g_pBufConstantQuad };
    pd3dImmediateContext->PSSetConstantBuffers( 0, 1, ppCB );

    pd3dImmediateContext->Draw( 4, 0 );

    // Unbind the output buffer SRV, so it is not still bound for reading when the
    // compute shader binds the same buffer as a UAV next frame.
    ID3D11ShaderResourceView* ppSRVNULL[1] = { NULL };
    pd3dImmediateContext->PSSetShaderResources(0, 1, ppSRVNULL);
    ppCB[0] = NULL;
    pd3dImmediateContext->PSSetConstantBuffers( 0, 1, ppCB );
}



void run(int argc, char** argv)
{
    /* Load the data file and compile the bvh tree */
    scene_compile(app_kd_tree, "Data/FairyForestF160.ra2");
	
	// Set general DXUT callbacks
    DXUTSetCallbackFrameMove( OnFrameMove );

	// Set the D3D11 DXUT callbacks.
    DXUTSetCallbackD3D11DeviceAcceptable( IsD3D11DeviceAcceptable );
    DXUTSetCallbackD3D11DeviceCreated( OnD3D11CreateDevice );
    DXUTSetCallbackD3D11SwapChainResized( OnD3D11ResizedSwapChain );
    DXUTSetCallbackD3D11FrameRender( OnD3D11FrameRender );
    DXUTSetCallbackD3D11SwapChainReleasing( OnD3D11ReleasingSwapChain );
    DXUTSetCallbackD3D11DeviceDestroyed( OnD3D11DestroyDevice );

    /* Init the timers */
    sys::laps_t::bootstrap();

	DXUTInit( true, true, NULL ); // Parse the command line, show msgboxes on error, no extra command line params
    DXUTSetCursorSettings( true, true ); // Show the cursor and clip it when in full screen
    DXUTCreateWindow( L"DX11 RayTracer" );

    // Only require 10-level hardware
	DXUTCreateDevice( D3D_FEATURE_LEVEL_10_0, true, 512, 512 );

    /* Set a camera here */
    cam.open();
    cam.set_eye(eye);
    cam.set_fovx(fov);
    cam.update(point_t((int)g_VP.Width, (int)g_VP.Height));
    
    DXUTMainLoop(); // Enter into the DXUT render loop
}