#include "Shared.h"
#include "Common.hlsli"

// Pixel-shader output: two integer render targets holding the packed G-buffer.
// Each channel is a bit-packed value produced by the pack* helpers called in
// PixelMain (those helpers are not defined in this file).
struct GBufferOutput {
    uint4 rt0 : SV_TARGET0; // x: albedo (packRGBA8), y: normal (packNormal11_11_10), z: emissive (packEmissive), w: unused (written as 0)
    uint2 rt1 : SV_TARGET1;  // x: roughness/metallic (packFloat2_16), y: reserved for velocity (not implemented yet; PixelMain writes a constant 0.5/0.5 pair)
};

// Per-vertex attributes emitted by MeshMain and consumed (interpolated) by
// PixelMain. Member order and semantics define the stage-to-stage signature.
struct VertexOut {
  float4 position : SV_POSITION;  // projection * view * world position, as computed in MeshMain
  float3 worldPos : POSITION1;    // position after the mesh transform only
  float4 tangent : TANGENT;       // xyz: tangent rotated by the mesh transform; w: passed through from the source vertex and used as the bitangent sign in PixelMain
  float3 normal : NORMAL;         // normal rotated by the upper 3x3 of the mesh transform, normalized per vertex
  float2 uv : TEXCOORD0;          // texture coordinates copied from the source vertex
  float4 meshletColor : COLOR0;   // debug colour derived from Hash(meshlet index); alpha is 1.0
  nointerpolation uint materialIndex : COLOR1;  // index into the material buffer; must not be interpolated
};

// Data handed from the amplification stage to the mesh stage via DispatchMesh.
// AmplificationMain stores the indices of the meshlets that survived culling,
// packed at the front of the array; MeshMain reads the entry at its SV_GroupID.
struct Payload {
  uint payloadData[AMPLIFICATION_SHADER_GROUP_SIZE];  // one slot per amplification thread
};

// Constant buffers (space0). The struct types are declared outside this file.
ConstantBuffer<GlobalShaderData> globalShaderData : register(b0, space0);  // view/projection matrices, frustum terms, camera position
ConstantBuffer<RenderData> renderData : register(b1, space0);              // descriptor-heap indices of the geometry/material buffers and the meshlet count
ConstantBuffer<DebugData> debugData : register(b2, space0);                // debug toggles (showMeshlets, disableBackfaceCulling)

// Unbounded texture array (space1). Materials store a 1-based texture id where
// 0 means "no texture"; SampleTex subtracts 1 before indexing this array.
Texture2D textures[] : register(t0, space1);
// Single sampler shared by every texture fetch in this file.
SamplerState samplr : register(s0);

// Amplification-stage payload storage: filled in AmplificationMain and passed
// to DispatchMesh. Note that MeshMain declares an input parameter with the same
// name, which hides this variable inside that function.
groupshared Payload s_Payload;
// World-space vertex positions of the current meshlet. MeshMain writes one
// entry per vertex and reads them back for per-triangle culling.
groupshared float3 s_WorldPos[MESHLET_MAX_VERTICES];

// Amplification stage: one thread per meshlet.
//
// Each thread tests its meshlet's bounding sphere against the frustum side
// planes and the near plane, and its normal cone against the view origin, all
// in the space of globalShaderData.fixatedView. Indices of surviving meshlets
// are packed to the front of the payload, and one mesh-shader group is
// launched per survivor.
//
// NOTE(review): the compaction relies on wave intrinsics only, so it presumably
// requires the whole thread group to execute as a single wave
// (AMPLIFICATION_SHADER_GROUP_SIZE <= wave size) - confirm against the targets.
[NumThreads(AMPLIFICATION_SHADER_GROUP_SIZE, 1, 1)]
void AmplificationMain(uint GroupID : SV_GroupID,
                       uint ThreadIndex : SV_GroupIndex,
                       uint DispatchThreadID : SV_DispatchThreadID) {
  StructuredBuffer<Mesh> meshBuffer =
      ResourceDescriptorHeap[renderData.meshBufferIndex];
  StructuredBuffer<Meshlet> meshletBuffer =
      ResourceDescriptorHeap[renderData.meshletBufferIndex];

  // Threads past the end of the meshlet list stay invisible but still reach
  // the wave operations and DispatchMesh below.
  bool visible = false;

  if (DispatchThreadID < renderData.meshletCount) {
    Meshlet meshlet = meshletBuffer[DispatchThreadID];
    Mesh mesh = meshBuffer[meshlet.meshIndex];

    // Bounding sphere: object -> world -> (fixated) view space.
    const float4 sphereWorld = mul(mesh.transform, float4(meshlet.boundingSphere.xyz, 1.0));
    const float3 center = mul(globalShaderData.fixatedView, sphereWorld).xyz;
    const float radius = meshlet.boundingSphere.w * mesh.scale;

    // The cone is four signed bytes packed into one word (axis x, y, z in the
    // upper three bytes, cutoff in the lowest). The "<< 24 >> 24" pair
    // sign-extends each byte through an arithmetic right shift.
    const int3 axisBits = int3(
        (int(meshlet.cone >> 24) & 0xFF) << 24 >> 24,
        (int(meshlet.cone >> 16) & 0xFF) << 24 >> 24,
        (int(meshlet.cone >> 8)  & 0xFF) << 24 >> 24);
    const int cutoffBits = (int(meshlet.cone >> 0)  & 0xFF) << 24 >> 24;

    const float coneCutoff = float(cutoffBits) / 127.0f;
    float3 coneAxis = normalize(float3(axisBits) / 127.0f);

    // Rotate the axis into the same space as the sphere centre.
    coneAxis = mul((float3x3)mesh.transform, coneAxis);
    coneAxis = normalize(mul((float3x3)globalShaderData.fixatedView, coneAxis));

    // True when the whole cone faces away from the view-space origin.
    const bool coneCull = dot(center, coneAxis) >= coneCutoff * length(center) + radius;

    // Symmetric left/right and top/bottom plane tests, then the near plane.
    const bool insideHorizontal =
        center.z * globalShaderData.viewFrustum.y - abs(center.x) * globalShaderData.viewFrustum.x > -radius;
    const bool insideVertical =
        center.z * globalShaderData.viewFrustum.w - abs(center.y) * globalShaderData.viewFrustum.z > -radius;
    const bool beyondNear = center.z + radius > globalShaderData.zNear;
    const bool facesViewer = !coneCull || bool(debugData.disableBackfaceCulling);

    visible = insideHorizontal && insideVertical && beyondNear && facesViewer;
  }

  if (visible) {
    // Inside this branch only visible lanes are active, so the prefix count is
    // the number of visible lanes before this one, i.e. its compacted slot.
    s_Payload.payloadData[WavePrefixCountBits(visible)] = DispatchThreadID;
  }

  // One mesh-shader group per surviving meshlet.
  const uint survivorCount = WaveActiveCountBits(visible);
  DispatchMesh(survivorCount, 1, 1, s_Payload);
}

// Scrambles a 32-bit integer through six add/xor/shift rounds and returns the
// mixed value. MeshMain uses it to turn a meshlet index into a debug colour.
// (The constants look like Robert Jenkins' 32-bit integer hash - unconfirmed.)
uint Hash(uint a) {
  a += 0x7ed55d16 + (a << 12);
  a ^= 0xc761c23c ^ (a >> 19);
  a += 0x165667b1 + (a << 5);
  a = (a + 0xd3a2646c) ^ (a << 9);
  a += 0xfd7046c5 + (a << 3);
  a ^= 0xb55a4f09 ^ (a >> 16);
  return a;
}

// Samples a bindless texture with the shared sampler.
//
// tex: 1-based texture id as stored in the material (0 means "no texture",
//      so callers must check for 0 before calling; 0 would wrap around here).
// uv:  texture coordinates.
// Returns the sampled RGBA value.
//
// The id comes from per-material data, and a single pixel-shader wave can
// cover pixels of several materials, so the index into the descriptor array
// is not guaranteed to be uniform. It must be wrapped in
// NonUniformResourceIndex; without it the result is undefined and some GPUs
// sample the wrong texture for part of the wave.
float4 SampleTex(uint tex, float2 uv) {
  return textures[NonUniformResourceIndex(tex - 1)].Sample(samplr, uv);
}

// Mesh stage: one thread group per meshlet that survived the amplification
// stage. Thread i emits vertex i (if any) and then triangle i (if any).
//
// The payload entry at SV_GroupID selects the meshlet. Vertices are transformed
// to world and clip space; world positions are also parked in groupshared
// memory so that every triangle can be tested for back-facing against the
// camera position and flagged through SV_CullPrimitive.
//
// NOTE(review): each thread handles at most one vertex and one triangle, so
// this presumably requires MESH_SHADER_GROUP_SIZE to be at least
// max(MESHLET_MAX_VERTICES, MESHLET_MAX_TRIANGLES) - confirm.
// NOTE(review): the per-triangle test does not consult
// debugData.disableBackfaceCulling, unlike the cone test in AmplificationMain -
// confirm that this is intended.
[OutputTopology("triangle")]
[NumThreads(MESH_SHADER_GROUP_SIZE, 1, 1)]
void MeshMain(
    uint GroupID : SV_GroupID,
    uint GroupThreadID : SV_GroupThreadID,
    in payload Payload s_Payload,
    out indices uint3 OutTriangles[MESHLET_MAX_TRIANGLES],
    out vertices VertexOut OutVertices[MESHLET_MAX_VERTICES],
    out primitives bool OutCull[MESHLET_MAX_TRIANGLES] : SV_CullPrimitive
) {
  const uint meshletIndex = s_Payload.payloadData[GroupID];

  // The index depends only on the group id, so the whole group leaves together.
  if (meshletIndex >= renderData.meshletCount) {
    return;
  }

  StructuredBuffer<Meshlet> meshletBuffer = ResourceDescriptorHeap[renderData.meshletBufferIndex];
  StructuredBuffer<Mesh> meshBuffer = ResourceDescriptorHeap[renderData.meshBufferIndex];
  StructuredBuffer<Vertex> vertexBuffer = ResourceDescriptorHeap[renderData.vertexBufferIndex];
  StructuredBuffer<uint> vertexRemap = ResourceDescriptorHeap[renderData.meshletVertexBufferIndex];
  StructuredBuffer<uint> triangleIndices = ResourceDescriptorHeap[renderData.meshletTriangleBufferIndex];

  Meshlet meshlet = meshletBuffer[meshletIndex];
  SetMeshOutputCounts(meshlet.vertexCount, meshlet.triangleCount);

  float4x4 objectToWorld = meshBuffer[meshlet.meshIndex].transform;

  // ---- Vertex phase ------------------------------------------------------
  if (GroupThreadID < meshlet.vertexCount) {
    // Meshlet-local vertex slot -> index into the shared vertex buffer.
    Vertex src = vertexBuffer[vertexRemap[meshlet.vertexOffset + GroupThreadID]];

    const float4 posWorld = mul(objectToWorld, float4(src.position, 1.0f));

    // Kept for the triangle phase below.
    s_WorldPos[GroupThreadID] = posWorld.xyz;

    // Same hash for every vertex of the meshlet -> one flat debug colour.
    const uint colorBits = Hash(meshletIndex);

    // Only the rotation/scale part of the transform is applied to directions.
    const float3x3 directionXform = (float3x3)objectToWorld;

    VertexOut result;
    result.position = mul(globalShaderData.projection, mul(globalShaderData.view, posWorld));
    result.worldPos = posWorld.xyz;
    result.uv = src.uv;
    result.meshletColor = float4(float(colorBits & 255), float((colorBits >> 8) & 255), float((colorBits >> 16) & 255), 255.0) / 255.0;
    result.materialIndex = meshlet.materialIndex;
    result.normal = normalize(mul(directionXform, src.normal));
    result.tangent = float4(normalize(mul(directionXform, src.tangent.xyz)), src.tangent.w);

    OutVertices[GroupThreadID] = result;
  }

  // Every world position must be written before any triangle reads it.
  GroupMemoryBarrierWithGroupSync();

  // ---- Triangle phase ----------------------------------------------------
  if (GroupThreadID < meshlet.triangleCount) {
    const uint firstIndex = meshlet.triangleOffset + GroupThreadID * 3;

    const uint3 corners = uint3(
      triangleIndices[firstIndex + 0],
      triangleIndices[firstIndex + 1],
      triangleIndices[firstIndex + 2]
    );

    OutTriangles[GroupThreadID] = corners;

    const float3 a = s_WorldPos[corners.x];
    const float3 b = s_WorldPos[corners.y];
    const float3 c = s_WorldPos[corners.z];

    const float3 faceNormal = normalize(cross(b - a, c - a));
    const float3 viewDir = normalize(a - globalShaderData.viewPos);

    // A face whose normal points along the camera-to-triangle direction is
    // back-facing and is culled.
    OutCull[GroupThreadID] = dot(faceNormal, viewDir) > 0.0;
  }
}

// Pixel stage: evaluates the material for one fragment and writes the packed
// G-buffer.
//
// input: interpolated VertexOut from MeshMain (materialIndex is flat).
// Returns the two packed render-target values; fragments whose alpha is below
// 0.1 are discarded.
//
// Material texture ids are 1-based; 0 means "no texture bound" and the
// corresponding factor is used on its own.
GBufferOutput PixelMain(VertexOut input) {
  StructuredBuffer<Material> materials = ResourceDescriptorHeap[renderData.materialBufferIndex];
  Material mat = materials[input.materialIndex];

  // Base colour: constant factor, optionally modulated by the albedo texture.
  float4 albedo = mat.diffuseFactor;
  if (mat.albedoTexture != 0) {
    albedo *= SampleTex(mat.albedoTexture, input.uv);
  }

  // Debug view: replace the albedo with the per-meshlet colour. Its alpha is
  // 1.0, so the alpha test below never discards in this mode.
  if (bool(debugData.showMeshlets)) {
    albedo = input.meshletColor;
  }

  // Alpha test.
  if (albedo.a < 0.1) discard;

  // Interpolation does not preserve unit length, so renormalize.
  float3 n = normalize(input.normal);

  if (mat.normalTexture != 0) {
    // Only the first two channels are read; z is rebuilt from the unit-length
    // constraint (clamped so rounding can never yield a negative radicand).
    float2 rg = SampleTex(mat.normalTexture, input.uv).rg * 2.0 - 1.0;

    float z = sqrt(max(0.0, 1.0 - dot(rg, rg)));
    float3 normalMap = normalize(float3(rg, z));

    float4 tangent = input.tangent;

    // tangent.w carries the handedness sign of the tangent frame.
    float3 bitangent = cross(n, tangent.xyz) * tangent.w;
    // Tangent space -> world space.
    n = normalize(normalMap.r * tangent.xyz + normalMap.g * bitangent + normalMap.b * n);
  }

  float metallic = mat.metallicFactor;
  float roughness = mat.roughnessFactor;

  if (mat.metallicRoughnessTexture != 0) {
    // Roughness is read from G and metallic from B.
    float4 mrSample = SampleTex(mat.metallicRoughnessTexture, input.uv);
    roughness *= mrSample.g;
    metallic *= mrSample.b;
  }

  float3 emissive = mat.emissiveFactor;
  if (mat.emissiveTexture != 0) {
    emissive *= SampleTex(mat.emissiveTexture, input.uv).rgb;
  }

  // Pack everything into the integer render targets. The pack* and
  // LinearToSRGB helpers are not defined in this file.
  GBufferOutput output;
  output.rt0.x = packRGBA8(float4(LinearToSRGB(albedo.rgb), 1.0));
  output.rt0.y = packNormal11_11_10(n);
  output.rt0.z = packEmissive(emissive);
  output.rt0.w = 0;

  output.rt1.x = packFloat2_16(float2(roughness, metallic));
  // Placeholder for the velocity channel, which is not implemented yet.
  output.rt1.y = packFloat2_16(float2(0.5, 0.5));

  return output;
}
