﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using SharpDX;
using SharpDX.D3DCompiler;
using SharpDX.Direct3D;
using SharpDX.Direct3D11;
using SharpDX.DXGI;
using Buffer = SharpDX.Direct3D11.Buffer;
using Device = SharpDX.Direct3D11.Device;

namespace VoluRen
{
    /// <summary>
    /// Slice-based volume renderer with a depth-of-field effect.
    /// </summary>
    /// <remarks>
    /// Each frame the volume's bounding box is transformed into view space and cut into slices
    /// <see cref="SAMPLINGDIST"/> apart. Slices between the far end and the focus plane are
    /// composited back-to-front into the eye buffers; slices between the near end and the focus
    /// plane are composited front-to-back into the intermediate buffers. Each pass ping-pongs
    /// between two buffers. The final results of both passes are blended onto
    /// <c>MainForm._renderView</c>.
    /// </remarks>
    public class SliceBasedDOFRenderer : IRenderer
    {
        /// <summary>Distance between two consecutive slices in view space.</summary>
        private const float SAMPLINGDIST = 0.01f;

        // Line-list geometry used to visualize the bounding box.
        private BufferWrapper<StructVertexColor> _vertexBufferWrapper2;
        private BufferWrapper<ushort> _indexBufferWrapper2;

        // Proxy geometry of the current slice and the shader constant buffers.
        private BufferWrapper<StructVertex> _vertexBufferWrapper;
        private BufferWrapper<Projections> _projectionsBufferWrapper;
        private BufferWrapper<PerFrameDOF> _perFrameDOFBufferWrapper;
        private BufferWrapper<PerFrameSlice> _perFrameSliceBufferWrapper;

        // Receives one rendered slice before it is combined by the DOF shader.
        private RenderTargetView _sliceBufferRTV;
        private ShaderResourceView _sliceBufferSRV;

        // Ping-pong pair for the front-to-back pass.
        private RenderTargetView[] _interBufferRTV;
        private ShaderResourceView[] _interBufferSRV;

        // Ping-pong pair for the back-to-front pass.
        private RenderTargetView[] _eyeBufferRTV;
        private ShaderResourceView[] _eyeBufferSRV;

        private Device _device;
        private Camera _camera;

        private VertexShader _sliceVS;
        private PixelShader _slicePS;
        private InputLayout _sliceLayout;

        private VertexShader _blendVS;
        private PixelShader _blendPS;

        private VertexShader _dofVS;
        private PixelShader _dofPS;

        private VertexShader _colorcubeVS;
        private PixelShader _colorcubePS;
        private InputLayout _colorcubeLayout;

        // Bounding box in model space and its view-space counterpart (recomputed every frame).
        private BoundingBox _boundingBox;
        private BoundingBox _boundingBoxView;

        private Slice _slice;

        // CPU-side vertex storage for one slice; sized to match the slice vertex buffer.
        private StructVertex[] _vertices;

        private Matrix _normalizeMatrix;

        private RasterizerState _cullNoneState;
        private SamplerState _trilinearSamperState;

        // Circle of confusion of the current slice, as computed by CalculateCircleOfConfusion.
        private Vector2 _coCTexSpace;

        private int _width;
        private int _height;

        private BlendState _opaqueBlendState;
        private BlendState _btfBlendState;

        private Color _clearColor;

        /// <summary>
        /// Initializes the renderer: compiles the shaders and creates the input layouts,
        /// buffers, sampler, rasterizer state and blend states.
        /// <see cref="CreateRenderTarget"/> must be called before <see cref="Draw"/>.
        /// </summary>
        /// <param name="device">Direct3D 11 device used to create all resources.</param>
        /// <param name="camera">Camera that supplies the view and projection matrices and the initial slice normal.</param>
        public void Init(Device device, Camera camera)
        {
            _device = device;
            _camera = camera;

            // Scale by 0.5 and translate by 0.5: maps [-1, 1] to [0, 1] on x, y and z.
            _normalizeMatrix = new Matrix(0.5f, 0, 0, 0.0f, 0, 0.5f, 0, 0.0f, 0, 0, 0.5f, 0.0f, 0.5f, 0.5f, 0.5f, 1.0f);

            _clearColor = new Color(0.0f, 0.0f, 0.0f, 0.0f);

            ShaderFlags flags = ShaderFlags.EnableStrictness;
#if DEBUG
            flags |= ShaderFlags.Debug;
#endif

            // Each compiled blob and input signature is disposed as soon as the objects
            // built from it exist, so no bytecode is left unreleased.
            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\Slice.hlsl", "VertexShaderFunction", "vs_5_0", flags, EffectFlags.None))
            using (var inputsig = ShaderSignature.GetInputSignature(blob))
            {
                _sliceVS = new VertexShader(device, blob);
                _sliceLayout = new InputLayout(device, inputsig, new[]
                {
                    new InputElement("POSITION", 0, Format.R32G32B32_Float, 0),
                });
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\Slice.hlsl", "PixelShaderFunction", "ps_5_0", flags, EffectFlags.None))
            {
                _slicePS = new PixelShader(device, blob);
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\FullscreenQuadBlend.hlsl", "VertexShaderFunction", "vs_5_0", flags, EffectFlags.None))
            {
                _blendVS = new VertexShader(device, blob);
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\FullscreenQuadBlend.hlsl", "PixelShaderFunction", "ps_5_0", flags, EffectFlags.None))
            {
                _blendPS = new PixelShader(device, blob);
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\DepthOfField.hlsl", "VertexShaderFunction", "vs_5_0", flags, EffectFlags.None))
            {
                _dofVS = new VertexShader(device, blob);
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\DepthOfField.hlsl", "PixelShaderFunction", "ps_5_0", flags, EffectFlags.None))
            {
                _dofPS = new PixelShader(device, blob);
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\ColorCube.hlsl", "VertexShaderFunction", "vs_5_0", flags, EffectFlags.None))
            using (var inputsig = ShaderSignature.GetInputSignature(blob))
            {
                _colorcubeVS = new VertexShader(device, blob);
                _colorcubeLayout = new InputLayout(device, inputsig, new[]
                {
                    new InputElement("POSITION", 0, Format.R32G32B32_Float, 0),
                    new InputElement("COLOR", 0, Format.R32G32B32A32_Float, 12, 0),
                });
            }

            using (var blob = ShaderBytecode.CompileFromFile("resources\\shaders\\ColorCube.hlsl", "PixelShaderFunction", "ps_5_0", flags, EffectFlags.None))
            {
                _colorcubePS = new PixelShader(device, blob);
            }

            // Shared description for the CPU-writable dynamic buffers; SizeInBytes is set per buffer.
            BufferDescription vertexbufferdesc = new SharpDX.Direct3D11.BufferDescription
            {
                BindFlags = BindFlags.VertexBuffer | BindFlags.IndexBuffer,
                CpuAccessFlags = CpuAccessFlags.Write,
                OptionFlags = ResourceOptionFlags.None,
                StructureByteStride = 0,
                Usage = ResourceUsage.Dynamic
            };

            // Bounding-box visualization: one gray, semi-transparent vertex per corner.
            _boundingBox = new BoundingBox();
            int cornerCount = (int)BoundingBox.Corner.CornerCOUNT;
            StructVertexColor[] vertices = new StructVertexColor[cornerCount];
            for (int i = 0; i < cornerCount; i++)
            {
                vertices[i] = new StructVertexColor();
                vertices[i].Position = _boundingBox.DrawVertices[i];
                vertices[i].Color = new SharpDX.Color4(0.7f, 0.7f, 0.7f, 0.2f);
            }
            vertexbufferdesc.SizeInBytes = Marshal.SizeOf(typeof(StructVertexColor)) * vertices.Length;

            _vertexBufferWrapper2 = new BufferWrapper<StructVertexColor>(device, vertexbufferdesc);
            _vertexBufferWrapper2.ArrayValue = vertices;

            vertexbufferdesc.SizeInBytes = Marshal.SizeOf(typeof(ushort)) * _boundingBox.LineIndices.Count();
            _indexBufferWrapper2 = new BufferWrapper<ushort>(device, vertexbufferdesc);
            _indexBufferWrapper2.ArrayValue = _boundingBox.LineIndices;

            _slice = new Slice(_camera.ViewDirection, 0.0f);

            // Room for up to 18 proxy vertices per slice.
            _vertices = new StructVertex[18];
            vertexbufferdesc.SizeInBytes = Marshal.SizeOf(typeof(StructVertex)) * 18;
            _vertexBufferWrapper = new BufferWrapper<StructVertex>(device, vertexbufferdesc);
            _projectionsBufferWrapper = new BufferWrapper<Projections>(device);
            _perFrameDOFBufferWrapper = new BufferWrapper<PerFrameDOF>(device);
            _perFrameSliceBufferWrapper = new BufferWrapper<PerFrameSlice>(device);

            _trilinearSamperState = new SamplerState(_device, StateManagement.Instance.SamplerLinearClamp);

            _cullNoneState = new RasterizerState(_device, StateManagement.Instance.RCullNoneSolid);

            device.ImmediateContext.InputAssembler.PrimitiveTopology = PrimitiveTopology.TriangleList;

            BlendStateDescription bdesc = StateManagement.Instance.Opaque;
            RenderTargetBlendDescription rtbdesc = StateManagement.Instance.RTOpaque;
            bdesc.RenderTarget[0] = rtbdesc;
            _opaqueBlendState = new BlendState(_device, bdesc);

            bdesc = StateManagement.Instance.BackToFront;
            rtbdesc = StateManagement.Instance.RTBackToFront;
            bdesc.RenderTarget[0] = rtbdesc;
            _btfBlendState = new BlendState(_device, bdesc);
        }

        /// <summary>
        /// (Re)creates the size-dependent render targets: the slice buffer plus the eye and
        /// intermediate ping-pong pairs. Views from a previous call are released first.
        /// Must be called whenever the window is resized so the result fills the viewport.
        /// </summary>
        /// <param name="width">Viewport width in pixels.</param>
        /// <param name="height">Viewport height in pixels.</param>
        public void CreateRenderTarget(int width, int height)
        {
            _width = width;
            _height = height;

            // Release the views of a previous call (e.g. before a resize).
            DisposeAll(_interBufferRTV);
            DisposeAll(_interBufferSRV);
            DisposeAll(_eyeBufferRTV);
            DisposeAll(_eyeBufferSRV);

            if (_sliceBufferRTV != null)
                _sliceBufferRTV.Dispose();
            if (_sliceBufferSRV != null)
                _sliceBufferSRV.Dispose();

            // One description for all targets: single mip, 32-bit float RGBA,
            // bindable both as a render target and as a shader input.
            Texture2DDescription desc = new Texture2DDescription();
            desc.Width = width;
            desc.Height = height;
            desc.MipLevels = 1;
            desc.Format = Format.R32G32B32A32_Float;
            desc.Usage = ResourceUsage.Default;
            desc.BindFlags = BindFlags.RenderTarget | BindFlags.ShaderResource;
            desc.CpuAccessFlags = CpuAccessFlags.None;
            desc.ArraySize = 1;
            desc.SampleDescription.Count = 1;

            // The views keep their own reference to the texture, so the local one can be released.
            using (Texture2D sliceTexture = new Texture2D(_device, desc))
            {
                _sliceBufferRTV = new RenderTargetView(_device, sliceTexture);
                _sliceBufferSRV = new ShaderResourceView(_device, sliceTexture);
            }

            _interBufferRTV = new RenderTargetView[2];
            _interBufferSRV = new ShaderResourceView[2];
            _eyeBufferRTV = new RenderTargetView[2];
            _eyeBufferSRV = new ShaderResourceView[2];

            for (int i = 0; i < 2; i++)
            {
                using (Texture2D interTexture = new Texture2D(_device, desc))
                {
                    _interBufferRTV[i] = new RenderTargetView(_device, interTexture);
                    _interBufferSRV[i] = new ShaderResourceView(_device, interTexture);
                }

                using (Texture2D eyeTexture = new Texture2D(_device, desc))
                {
                    _eyeBufferRTV[i] = new RenderTargetView(_device, eyeTexture);
                    _eyeBufferSRV[i] = new ShaderResourceView(_device, eyeTexture);
                }
            }
        }

        /// <summary>
        /// Transforms the bounding-box corners (model -> world -> view) and builds the
        /// view-space bounding box that the slicer intersects.
        /// </summary>
        /// <param name="nearPt">Corner with the smallest view-space Z (treated as nearest to the camera).</param>
        /// <param name="farPt">Corner with the largest view-space Z (treated as farthest from the camera).</param>
        /// <returns>Bounding box built from the view-space corners.</returns>
        private BoundingBox CalculateViewSpaceBB(out Vector3 nearPt, out Vector3 farPt)
        {
            int cornerCount = (int)BoundingBox.Corner.CornerCOUNT;
            Vector3[] bbviewcorners = new Vector3[cornerCount];

            nearPt = new Vector3(float.MaxValue, float.MaxValue, float.MaxValue);
            farPt = new Vector3(float.MinValue, float.MinValue, float.MinValue);

            for (int i = 0; i < cornerCount; i++)
            {
                Vector3 worldCorner = (Vector3)Vector3.Transform(_boundingBox.Vertices[i], VolumeManagement.Instance.ModelMatrix);
                Vector3 viewCorner = (Vector3)Vector3.Transform(worldCorner, _camera.ViewMatrix);
                bbviewcorners[i] = viewCorner;

                if (viewCorner.Z < nearPt.Z)
                    nearPt = viewCorner;

                if (viewCorner.Z > farPt.Z)
                    farPt = viewCorner;
            }

            return new BoundingBox(bbviewcorners);
        }

        /// <summary>
        /// Runs one compositing pass. For each slice it:
        /// <list type="number">
        /// <item>builds the proxy geometry;</item>
        /// <item>draws the slice into the slice buffer (slice shader);</item>
        /// <item>combines the slice with the result accumulated so far (depth-of-field shader).</item>
        /// </list>
        /// The pass ping-pongs between the two buffers of its pair.
        /// </summary>
        /// <param name="startZ">View-space Z of the first slice.</param>
        /// <param name="slicecnt">Number of slices to process; values &lt;= 0 render nothing.</param>
        /// <param name="btf">
        /// <c>true</c>: back-to-front pass into the eye buffers (Z decreases per slice).
        /// <c>false</c>: front-to-back pass into the intermediate buffers (Z increases per slice).
        /// </param>
        /// <param name="focus">View-space Z of the focus point, used for the circle of confusion.</param>
        /// <returns>
        /// Index (0 or 1) of the buffer in this pass's pair that holds the final accumulated result.
        /// </returns>
        private int CreateProxyAndDraw(float startZ, int slicecnt, bool btf, float focus)
        {
            RenderTargetView[] targets = btf ? _eyeBufferRTV : _interBufferRTV;
            ShaderResourceView[] resources = btf ? _eyeBufferSRV : _interBufferSRV;

            int sourceIndex = 0;
            int destIndex = 1;

            for (int i = 0; i < slicecnt; i++)
            {
                if (btf)
                    _slice.Distance = startZ - i * SAMPLINGDIST;
                else
                    _slice.Distance = startZ + i * SAMPLINGDIST;

                List<Vector3> intersections = _slice.CreateProxyGeometry(_boundingBoxView);

                // No proxy geometry for this slice (presumably the plane misses the volume).
                if (intersections == null)
                    continue;

                CalculateCircleOfConfusion(_slice.Distance, focus);
                DrawSlice(intersections);

                _device.ImmediateContext.OutputMerger.SetTargets(targets[destIndex]);
                _device.ImmediateContext.ClearRenderTargetView(targets[destIndex], _clearColor);

                _device.ImmediateContext.Rasterizer.State = _cullNoneState;

                // NOTE(review): BackToFront is 0 for the back-to-front pass and 1 otherwise;
                // confirm that this matches the convention in DepthOfField.hlsl.
                _perFrameDOFBufferWrapper.Value = new PerFrameDOF()
                {
                    CoCTexSpace = _coCTexSpace,
                    BackToFront = btf ? 0f : 1f,
                };

                _device.ImmediateContext.VertexShader.Set(_dofVS);
                _device.ImmediateContext.PixelShader.Set(_dofPS);
                _device.ImmediateContext.PixelShader.SetConstantBuffer(0, _perFrameDOFBufferWrapper.Buffer);
                _device.ImmediateContext.PixelShader.SetSampler(0, _trilinearSamperState);
                _device.ImmediateContext.PixelShader.SetShaderResource(2, resources[sourceIndex]);
                _device.ImmediateContext.PixelShader.SetShaderResource(3, _sliceBufferSRV);
                _device.ImmediateContext.Draw(3, 0);

                // Unbind so these textures can be bound as render targets again.
                _device.ImmediateContext.PixelShader.SetShaderResource(2, null);
                _device.ImmediateContext.PixelShader.SetShaderResource(3, null);

                destIndex = 1 - destIndex;
                sourceIndex = 1 - sourceIndex;
            }

            // After the swap, the buffer written last is the current source.
            return sourceIndex;
        }

        /// <summary>
        /// Computes the circle of confusion of a slice and stores it in <c>_coCTexSpace</c>.
        /// </summary>
        /// <remarks>
        /// The view-space radius grows linearly with the distance to the focus plane. It is
        /// projected at depth |z|, mapped to [0, 1] and divided by the render-target size.
        /// </remarks>
        /// <param name="z">View-space Z of the current slice.</param>
        /// <param name="zf">View-space Z of the focus point.</param>
        private void CalculateCircleOfConfusion(float z, float zf)
        {
            float radius = InputManagement.Instance.R * Math.Abs(z - zf);

            Vector4 coCViewSpace = new Vector4(radius, radius, Math.Abs(z), 1.0f);
            Vector4 coCClipSpace = Vector4.Transform(coCViewSpace, _camera.ProjectionMatrix);
            Vector4 tmp = Vector4.Transform(coCClipSpace, _normalizeMatrix);
            _coCTexSpace = new Vector2(tmp.X / tmp.W, tmp.Y / tmp.W);
            _coCTexSpace.X = _coCTexSpace.X / _width;
            _coCTexSpace.Y = _coCTexSpace.Y / _height;
        }

        /// <summary>
        /// Draws the current frame. Steps:
        /// <list type="number">
        /// <item>Clear the ping-pong buffers.</item>
        /// <item>Run the back-to-front pass (far end to the focus plane).</item>
        /// <item>Run the front-to-back pass (near end to the focus plane).</item>
        /// <item>Optionally draw the bounding box.</item>
        /// <item>Blend both results onto the main render target.</item>
        /// </list>
        /// </summary>
        /// <param name="args">Timing data (unused here, may be null).</param>
        public void Draw(TimeEventArgs args)
        {
            Vector3 nearPtView, farPtView;

            _device.ImmediateContext.ClearRenderTargetView(_interBufferRTV[0], _clearColor);
            _device.ImmediateContext.ClearRenderTargetView(_interBufferRTV[1], _clearColor);
            _device.ImmediateContext.ClearRenderTargetView(_eyeBufferRTV[0], _clearColor);
            _device.ImmediateContext.ClearRenderTargetView(_eyeBufferRTV[1], _clearColor);

            _boundingBoxView = CalculateViewSpaceBB(out nearPtView, out farPtView);

            // Focus point: FocusPlanePosition of the way from the near to the far corner.
            Vector3 dir = farPtView - nearPtView;
            Vector3 focusPt = nearPtView + Vector3.Normalize(dir) * InputManagement.Instance.FocusPlanePosition * dir.Length();

            float fulldistance = Math.Abs(farPtView.Z - nearPtView.Z);
            float frontdistance = Math.Abs(focusPt.Z - nearPtView.Z);

            int totalSliceCnt = (int)(Math.Round(fulldistance / SAMPLINGDIST));
            int frontSliceCnt = (int)(Math.Round(frontdistance / SAMPLINGDIST));
            int backSliceCnt = totalSliceCnt - frontSliceCnt;

            // Back-to-front pass: from the far corner towards the focus plane.
            int eyeResultIndex = CreateProxyAndDraw(farPtView.Z, backSliceCnt, true, focusPt.Z);

            // Front-to-back pass: from the near corner towards the focus plane.
            int interResultIndex = CreateProxyAndDraw(nearPtView.Z, frontSliceCnt, false, focusPt.Z);

            #region blend results
            _device.ImmediateContext.OutputMerger.SetTargets(MainForm._renderView);

            _device.ImmediateContext.OutputMerger.BlendState = _btfBlendState;

            if (InputManagement.Instance.DrawBB)
                DrawBoundingBox();

            // Each pass returned the buffer that holds its final result.
            _device.ImmediateContext.VertexShader.Set(_blendVS);
            _device.ImmediateContext.PixelShader.Set(_blendPS);
            _device.ImmediateContext.PixelShader.SetSampler(0, _trilinearSamperState);
            _device.ImmediateContext.PixelShader.SetShaderResource(2, _eyeBufferSRV[eyeResultIndex]);
            _device.ImmediateContext.PixelShader.SetShaderResource(3, _interBufferSRV[interResultIndex]);
            _device.ImmediateContext.Draw(3, 0);
            _device.ImmediateContext.PixelShader.SetShaderResource(2, null);
            _device.ImmediateContext.PixelShader.SetShaderResource(3, null);
            #endregion
        }

        /// <summary>
        /// Draws the bounding box as a line list.
        /// Afterwards the primitive topology is restored to a triangle list.
        /// </summary>
        private void DrawBoundingBox()
        {
            _device.ImmediateContext.InputAssembler.PrimitiveTopology = PrimitiveTopology.LineList;

            _device.ImmediateContext.InputAssembler.InputLayout = _colorcubeLayout;
            _device.ImmediateContext.InputAssembler.SetVertexBuffers(0, new VertexBufferBinding(_vertexBufferWrapper2.Buffer, Marshal.SizeOf(typeof(StructVertexColor)), 0));
            _device.ImmediateContext.InputAssembler.SetIndexBuffer(_indexBufferWrapper2.Buffer, Format.R16_UInt, 0);

            // NOTE(review): uses an identity model matrix, unlike DrawSlice, which uses
            // VolumeManagement.Instance.ModelMatrix; confirm this is intended.
            _projectionsBufferWrapper.Value = new Projections()
            {
                Model = Matrix.Transpose(Matrix.Identity),
                View = Matrix.Transpose(_camera.ViewMatrix),
                Projection = Matrix.Transpose(_camera.ProjectionMatrix),
            };
            _device.ImmediateContext.VertexShader.Set(_colorcubeVS);
            _device.ImmediateContext.VertexShader.SetConstantBuffer(0, _projectionsBufferWrapper.Buffer);
            _device.ImmediateContext.PixelShader.Set(_colorcubePS);

            // 24 indices = 12 line segments.
            _device.ImmediateContext.DrawIndexed(24, 0, 0);

            _device.ImmediateContext.InputAssembler.PrimitiveTopology = PrimitiveTopology.TriangleList;
        }

        /// <summary>
        /// Draws the proxy geometry of the current slice into the cleared slice buffer,
        /// using opaque blending.
        /// </summary>
        /// <param name="intersections">
        /// Slice vertices in view space, drawn as a triangle list. At most 18 fit
        /// (the size of <c>_vertices</c> and its vertex buffer).
        /// </param>
        private void DrawSlice(List<Vector3> intersections)
        {
            _device.ImmediateContext.OutputMerger.BlendState = _opaqueBlendState;

            _device.ImmediateContext.Rasterizer.State = _cullNoneState;
            _device.ImmediateContext.OutputMerger.SetTargets(_sliceBufferRTV);
            _device.ImmediateContext.ClearRenderTargetView(_sliceBufferRTV, _clearColor);

            // View space -> world space -> model space.
            for (int i = 0; i < intersections.Count; i++)
            {
                _vertices[i].Position = (Vector3)Vector3.Transform(intersections[i], _camera.InverseViewMatrix);
                _vertices[i].Position = (Vector3)Vector3.Transform(_vertices[i].Position, VolumeManagement.Instance.InverseModelMatrix);
            }

            _vertexBufferWrapper.ArrayValue = _vertices;
            _device.ImmediateContext.InputAssembler.SetVertexBuffers(0, new VertexBufferBinding(_vertexBufferWrapper.Buffer, Marshal.SizeOf(typeof(StructVertex)), 0));

            _projectionsBufferWrapper.Value = new Projections()
            {
                Model = Matrix.Transpose(VolumeManagement.Instance.ModelMatrix),
                View = Matrix.Transpose(_camera.ViewMatrix),
                Projection = Matrix.Transpose(_camera.ProjectionMatrix),
                Normalize = Matrix.Transpose(_normalizeMatrix),
            };

            _perFrameSliceBufferWrapper.Value = new PerFrameSlice()
            {
                Alpha = InputManagement.Instance.Alpha,
            };

            _device.ImmediateContext.InputAssembler.InputLayout = _sliceLayout;
            _device.ImmediateContext.VertexShader.Set(_sliceVS);
            _device.ImmediateContext.VertexShader.SetConstantBuffer(0, _projectionsBufferWrapper.Buffer);
            _device.ImmediateContext.PixelShader.Set(_slicePS);
            _device.ImmediateContext.PixelShader.SetSampler(0, _trilinearSamperState);
            _device.ImmediateContext.PixelShader.SetShaderResource(0, VolumeManagement.Instance.FilteredGradientTextureSrv);
            _device.ImmediateContext.PixelShader.SetShaderResource(1, VolumeManagement.Instance.TransferFunctionSrv);
            _device.ImmediateContext.PixelShader.SetConstantBuffer(1, _perFrameSliceBufferWrapper.Buffer);

            _device.ImmediateContext.Draw(intersections.Count, 0);
        }

        /// <summary>
        /// Disposes all DirectX objects created by this renderer, and the device.
        /// Null checks keep this safe when <see cref="Init"/> or
        /// <see cref="CreateRenderTarget"/> was never called.
        /// </summary>
        public void Dispose()
        {
            if (_vertexBufferWrapper2 != null)
                _vertexBufferWrapper2.Dispose();
            if (_indexBufferWrapper2 != null)
                _indexBufferWrapper2.Dispose();
            if (_vertexBufferWrapper != null)
                _vertexBufferWrapper.Dispose();
            if (_projectionsBufferWrapper != null)
                _projectionsBufferWrapper.Dispose();
            if (_perFrameDOFBufferWrapper != null)
                _perFrameDOFBufferWrapper.Dispose();
            if (_perFrameSliceBufferWrapper != null)
                _perFrameSliceBufferWrapper.Dispose();

            DisposeAll(new IDisposable[] { _sliceBufferRTV, _sliceBufferSRV });
            DisposeAll(_interBufferRTV);
            DisposeAll(_interBufferSRV);
            DisposeAll(_eyeBufferRTV);
            DisposeAll(_eyeBufferSRV);

            DisposeAll(new IDisposable[]
            {
                _sliceVS, _slicePS, _sliceLayout,
                _blendVS, _blendPS,
                _dofVS, _dofPS,
                _colorcubeVS, _colorcubePS, _colorcubeLayout,
                _cullNoneState, _trilinearSamperState,
                _opaqueBlendState, _btfBlendState,
            });

            // NOTE(review): the device is handed in via Init, so the caller presumably owns it.
            // Disposing it here is kept to preserve existing behavior; confirm ownership.
            if (_device != null)
                _device.Dispose();
        }

        /// <summary>
        /// Disposes every non-null element of <paramref name="items"/>; a null array is ignored.
        /// </summary>
        /// <param name="items">Objects to dispose (may be null or contain nulls).</param>
        private static void DisposeAll(IDisposable[] items)
        {
            if (items == null)
                return;

            foreach (IDisposable item in items)
            {
                if (item != null)
                    item.Dispose();
            }
        }
    }
}
