#include "Shared.h"
#include "Common.hlsli"

// Root constant buffers, bound at b0..b3 in register space 0.
// In this file: globalShaderData supplies the target size, invViewProj,
// view position, light and ambient parameters; debugData supplies the
// showMeshlets toggle; renderHandles supplies descriptor-heap indices.
// renderData is declared but not referenced in the code visible here.
ConstantBuffer<GlobalShaderData> globalShaderData : register(b0, space0);
ConstantBuffer<RenderData> renderData : register(b1, space0);
ConstantBuffer<DebugData> debugData : register(b2, space0);
ConstantBuffer<RenderHandles> renderHandles : register(b3, space0);

// Shared by the BRDF helpers and the diffuse term in ComputeMain.
static const float PI = 3.14159265359;

// GGX (Trowbridge-Reitz) normal distribution term.
//   N         - surface normal
//   H         - half vector between light and view directions
//   roughness - perceptual roughness; squared once to get alpha
// Returns alpha^2 / (PI * ((N.H)^2 * (alpha^2 - 1) + 1)^2), with the
// denominator clamped away from zero.
float DistributionGGX(float3 N, float3 H, float roughness) {
  const float alpha = roughness * roughness;
  const float alphaSq = alpha * alpha;

  const float cosNH = max(dot(N, H), 0.0);
  const float cosNHSq = cosNH * cosNH;

  const float base = (cosNHSq * (alphaSq - 1.0) + 1.0);
  const float divisor = PI * base * base;

  // Clamp guards the roughness == 0, N == H case where the divisor hits 0.
  return alphaSq / max(divisor, 0.0000001);
}

// Schlick-GGX geometry term for a single direction.
//   NdotV     - clamped cosine between the normal and the direction
//   roughness - perceptual roughness
// Uses k = (roughness + 1)^2 / 8 and returns NdotV / (NdotV * (1 - k) + k),
// with the denominator clamped away from zero.
float GeometrySchlickGGX(float NdotV, float roughness) {
  const float rPlusOne = (roughness + 1.0);
  const float k = (rPlusOne * rPlusOne) / 8.0;

  const float divisor = NdotV * (1.0 - k) + k;
  return NdotV / max(divisor, 0.0000001);
}

// Smith geometry term: product of the Schlick-GGX terms evaluated for the
// view direction V and the light direction L against normal N.
// Both cosines are clamped to be non-negative before evaluation.
float GeometrySmith(float3 N, float3 V, float3 L, float roughness) {
  const float viewTerm = GeometrySchlickGGX(max(dot(N, V), 0.0), roughness);
  const float lightTerm = GeometrySchlickGGX(max(dot(N, L), 0.0), roughness);

  return lightTerm * viewTerm;
}

// Schlick's Fresnel approximation.
//   cosTheta - cosine of the angle the reflectance is evaluated at
//   F0       - reflectance at normal incidence
// (1 - cosTheta) is clamped to be non-negative before being raised to the 5th power.
float3 FresnelSchlick(float cosTheta, float3 F0) {
  const float oneMinusCos = max(1.0 - cosTheta, 0.0);
  const float weight = pow(oneMinusCos, 5.0);
  return F0 + (1.0 - F0) * weight;
}

// Reconstructs a position from a depth sample and its screen UV.
//   depth - value used directly as clip-space z
//   uv    - texture coordinate; x maps 0..1 to -1..+1, y maps 0..1 to +1..-1
// The clip-space point is transformed by globalShaderData.invViewProj and
// then divided by w.
float3 worldPosFromDepth(float depth, float2 uv) {
    const float clipX = uv.x * 2.0 - 1.0;
    const float clipY = (uv.y * -2.0) + 1.0;  // flipped: UV y grows downward
    const float4 clipPoint = float4(clipX, clipY, depth, 1.0);

    const float4 unprojected = mul(globalShaderData.invViewProj, clipPoint);
    return unprojected.xyz / unprojected.w;
}

// Deferred lighting pass, one thread per pixel.
// Reads the packed G-buffer and depth, reconstructs the surface position,
// evaluates a single point light with a Cook-Torrance BRDF (GGX distribution,
// Smith geometry, Schlick Fresnel) plus a Lambert diffuse term, then writes
// shadowed direct light + ambient + emissive to the HDR target.
// All textures are fetched through ResourceDescriptorHeap using the indices
// in renderHandles.
[NumThreads(8,8,1)]
void ComputeMain(uint3 dispatchThreadID : SV_DispatchThreadID) {
  // Lower bound for vector lengths used as divisors; prevents 0/0 = NaN from
  // reaching the HDR target in degenerate configurations.
  const float kMinVectorLength = 1e-5;

  const uint2 pixel = dispatchThreadID.xy;
  // The dispatch is rounded up to 8x8 groups; drop threads outside the target.
  if (pixel.x >= globalShaderData.width || pixel.y >= globalShaderData.height) return;

  Texture2D<uint4> gbuffer0 = ResourceDescriptorHeap[renderHandles.gbuffer0Index];
  Texture2D<uint2> gbuffer1 = ResourceDescriptorHeap[renderHandles.gbuffer1Index];
  Texture2D<float> depthTex = ResourceDescriptorHeap[renderHandles.depthIndex];
  RWTexture2D<float4> hdrRt = ResourceDescriptorHeap[renderHandles.hdrRtIndex];
  Texture2D<float4> shadowMap = ResourceDescriptorHeap[renderHandles.shadowMapIndex];

  const uint4 rt0 = gbuffer0[pixel];
  const uint2 rt1 = gbuffer1[pixel];
  const float depth = depthTex[pixel];

  // G-buffer 0: x = RGBA8 albedo (sRGB), y = packed normal, z = packed emissive.
  const float3 albedo = SRGBToLinear(unpackRGBA8(rt0.x).rgb);
  const float3 n = unpackNormal11_11_10(rt0.y);
  const float3 emissive = unpackEmissive(rt0.z);

  // G-buffer 1: x = roughness and metallic packed as two 16-bit values.
  const float2 roughMetal = unpackFloat2_16(rt1.x);
  const float roughness = roughMetal.x;
  const float metallic = roughMetal.y;

  // Debug view: output the unlit albedo and skip lighting entirely.
  if (bool(debugData.showMeshlets)) {
      hdrRt[pixel] = float4(albedo, 1.0);
      return;
  }

  // Sample at the pixel centre.
  const float2 uv = (float2(pixel) + 0.5) / float2(globalShaderData.width, globalShaderData.height);
  const float3 worldPos = worldPosFromDepth(depth, uv);

  // Direction and distance to the light. The length is clamped so a surface
  // point that coincides with the light yields l = 0 instead of NaN.
  const float3 toLight = globalShaderData.lightPos - worldPos;
  const float lightDist = max(length(toLight), kMinVectorLength);
  const float3 l = toLight / lightDist;

  // 1 / (d + 1) falloff, windowed so it reaches zero at lightRadius.
  float attenuation = 1.0 / (lightDist + 1.0);
  const float distanceFactor = saturate(1.0 - (lightDist / globalShaderData.lightRadius));
  attenuation *= (distanceFactor * distanceFactor);

  const float3 v = normalize(globalShaderData.viewPos - worldPos);

  // Half vector. When l == -v the sum is zero and normalize() would produce
  // NaN; dividing by a clamped length gives h = 0 there instead.
  const float3 halfSum = l + v;
  const float3 h = halfSum / max(length(halfSum), kMinVectorLength);

  // Dielectrics use 4% base reflectance; metals take it from albedo.
  float3 F0 = float3(0.04, 0.04, 0.04);
  F0 = lerp(F0, albedo.rgb, metallic);

  const float3 radiance = float3(1.0, 1.0, 1.0) * globalShaderData.lightStrength;

  const float NdotV = max(dot(n, v), 0.0);
  const float NdotL = max(dot(n, l), 0.0);

  const float NDF = DistributionGGX(n, h, roughness);
  const float G   = GeometrySmith(n, v, l, roughness);
  const float3 F  = FresnelSchlick(max(dot(h, v), 0.0), F0);

  // Specular lobe; the small bias keeps the denominator non-zero at grazing angles.
  const float3 numerator = NDF * G * F;
  const float denom = 4.0 * NdotV * NdotL + 0.0001;
  const float3 specular = numerator / denom;

  // Energy not reflected specularly is available for diffuse; metals have none.
  const float3 kS = F;
  float3 kD = float3(1.0, 1.0, 1.0) - kS;
  kD *= 1.0 - metallic;

  const float3 Lo = (kD * albedo.rgb / PI + specular) * radiance * NdotL;

  // Shadow term is read per screen pixel from the red channel.
  const float shadow = shadowMap[pixel].r;

  const float3 ambient = albedo.rgb * globalShaderData.ambientFactor;
  const float3 color = (Lo * shadow) + ambient + emissive;

  hdrRt[pixel] = float4(color, 1.0);
}
