/*
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License version 2
 *   as published by the Free Software Foundation.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *   Copyright (C) 2007  Benjamin Segovia <bsegovia@liris.cnrs.fr>
 */

#include <memory>

#include "specifics.h"

#include "rt_bvh.h"
#include "bvhlib_internal.h"
#include "math_constants.h"

/* Bake one triangle into the projected (Badouel / Wald) form read by
 * intersect_bvh_ray_tri_wald().
 *
 * The triangle is projected along k, the axis of the largest absolute normal
 * component (first such axis on ties). The plane equation is divided by N[k],
 * and the projected edge data is divided by the 2D determinant of the edges.
 *
 * t     : source triangle, vertices A = verts[0], B = verts[1], C = verts[2]
 * w     : destination accelerated triangle, every field is overwritten
 * id    : value stored in w.id
 * matid : value stored in w.matid
 *
 * Returns non-zero when N[k] or the projected determinant is zero (degenerate
 * triangle). The divisions are performed anyway, so w may then hold inf/NaN. */
static bool_t compute_tri_acc(
    const triangle_t &t,
    rt::wald_tri_t &w,
    const uint_t id,
    const uint_t matid)
{
    const vec_t &A(t.verts[0]);
    const vec_t &B(t.verts[1]);
    const vec_t &C(t.verts[2]);
    const vec_t edge_ac(C - A);
    const vec_t edge_ab(B - A);
    const vec_t N(edge_ac.cross(edge_ab));

    /* Dominant axis of the normal; strict '>' keeps the first maximum */
    uint_t k = 0;
    for (uint_t axis = 1; axis < 3; ++axis) {
        if (fabsf(N[axis]) > fabsf(N[k]))
            k = axis;
    }
    const uint_t u = (k + 1) % 3;
    const uint_t v = (k + 2) % 3;

    /* 2D determinant of the projected edges, and the normal's k component */
    const float denom = (edge_ac[u]*edge_ab[v] - edge_ac[v]*edge_ab[u]);
    const float nk = N[k];

    /* Plane equation normalised so that its k coefficient is 1 */
    w.k       = k;
    w.n_u     = float(N[u] / nk);
    w.n_v     = float(N[v] / nk);
    w.n_d     = float(N.dot(A) / nk);

    /* Projected anchor vertex A */
    w.vert_ku = float(A[u]);
    w.vert_kv = float(A[v]);

    /* Projected edges pre-divided by the determinant (barycentric solve) */
    w.b_nu    = float( edge_ac[u] / denom);
    w.b_nv    = float(-edge_ac[v] / denom);
    w.c_nu    = float( edge_ab[v] / denom);
    w.c_nv    = float(-edge_ab[u] / denom);

    w.id      = id;
    w.matid   = matid;

    return (nk == 0.) | (denom == 0.);
}

/* Bake every input triangle into its accelerated (Badouel / Wald) form.
 *
 * tri     : num_tri source triangles
 * num_tri : number of triangles to convert
 * acc     : destination array, must hold at least num_tri entries
 *
 * Triangle i is written to acc[i] with id i and material 0. The degenerate
 * flag returned by compute_tri_acc() is deliberately ignored: such triangles
 * are still stored, possibly with inf/NaN coefficients. (The previous version
 * accumulated the flag into a counter that was never read.) */
static NOINLINE void bake_intersection(
    const triangle_t * __restrict const tri,
    const uint_t num_tri,
    rt::wald_tri_t * __restrict acc)
{
    for (uint_t i = 0; i < num_tri; ++i)
        (void) compute_tri_acc(tri[i], acc[i], i, 0);
}

/* C callback style to compile a BVH.
 *
 * t       : array of tri_n input triangles
 * tri_n   : number of triangles
 * bvhtree : receives the node array pointer (root, taken from the compiler)
 *           and the accelerated triangles (acc, resized to tri_n).
 *           bvhlib_free() later releases root with delete[].
 *
 * Returns false when the compiler rejects the input (injection() < 0),
 * true otherwise. */
bool_t bvhlib_compile(
    const triangle_t * __restrict const t,
    const uint32_t tri_n,
    bvh::descriptor_t &bvhtree)
{
    /* Build the tree here. std::auto_ptr is deprecated since C++11 and
     * removed in C++17; unique_ptr provides the same scoped ownership. */
    std::unique_ptr<bvhlib::compiler_t> c(new bvhlib::compiler_t());
    if (c->injection(t, tri_n, bvhtree) < 0)
        return false;
    c->compile();
    bvhtree.root = c->root;

    /* Build the accelerated triangles here. Taking &acc[0] of an empty
     * std::vector-like container is undefined behaviour, so only bake when
     * there is something to convert. */
    bvhtree.acc.resize(tri_n);
    if (tri_n > 0)
        bake_intersection(t, tri_n, &bvhtree.acc[0]);

    return true;
}

/* C callback style to destroy a BVH.
 *
 * Releases the node array (delete[] bvhtree->root) and then the descriptor
 * itself (delete bvhtree), so the descriptor must come from new.
 * A NULL descriptor is ignored. */
void bvhlib_free(bvh::descriptor_t *bvhtree)
{
    if (bvhtree != NULL) {
        /* delete[] on a null pointer has no effect, so no separate test */
        delete[] bvhtree->root;
        delete bvhtree;
    }
}

/* Fixed-capacity stack of BVH nodes still to be visited by trace_bvh().
 * idx counts the live entries; traces[idx - 1] is the top. Nothing here
 * checks idx against max_lvl. */
struct trace_bvh_stack_t {
    enum { max_lvl = 64 };          /* capacity, in entries */
    struct trace_t {
        const bvh::node_t * __restrict node;
    };

    trace_t traces[max_lvl];
    int_t idx;

    trace_bvh_stack_t() : idx(0) {}

    /* Empty the stack (entries are left untouched) */
    void reset() { idx = 0; }

    /* Zero every slot of the storage */
    void wipe() { std::memset(&traces[0], 0, sizeof(trace_t) * max_lvl); }

    /* Drop the top entry; true while idx still designates a valid slot,
     * which node() then returns */
    bool_t pop() { idx -= 1; return idx >= 0; }

    /* Node stored at slot idx (the entry just popped) */
    const bvh::node_t * __restrict node() const { return traces[idx].node; }
};

/* Slab test between a ray and an axis-aligned box.
 *
 * box : box to test (pmin / pmax corners)
 * ray : ray origin pos; rcp_dir is multiplied in place of dividing by the
 *       direction (presumably 1/dir per axis)
 * t   : current closest hit distance
 *
 * Returns non-zero when the entry/exit interval across the three slabs is
 * non-empty, the exit is not behind the origin, and the entry is strictly
 * closer than t. */
static FINLINE bool_t intersect_ray_box2(const aabb_t &box, const bvh_ray_t &ray, float t)
{
    /* X slab opens the interval */
    const float x_lo = (box.pmin.x - ray.pos.x) * ray.rcp_dir.x;
    const float x_hi = (box.pmax.x - ray.pos.x) * ray.rcp_dir.x;
    float enter = minf(x_lo, x_hi);
    float leave = maxf(x_lo, x_hi);

    /* Y slab narrows it */
    const float y_lo = (box.pmin.y - ray.pos.y) * ray.rcp_dir.y;
    const float y_hi = (box.pmax.y - ray.pos.y) * ray.rcp_dir.y;
    enter = maxf(minf(y_lo, y_hi), enter);
    leave = minf(maxf(y_lo, y_hi), leave);

    /* Z slab narrows it again */
    const float z_lo = (box.pmin.z - ray.pos.z) * ray.rcp_dir.z;
    const float z_hi = (box.pmax.z - ray.pos.z) * ray.rcp_dir.z;
    enter = maxf(minf(z_lo, z_hi), enter);
    leave = minf(maxf(z_lo, z_hi), leave);

    return ((leave >= enter) & (leave >= 0.f) & (enter < t));
}

/*****************************************************************************
 * Intersect a ray with one accelerated (projected) triangle.
 *
 * tri : triangle baked by compute_tri_acc()
 * ray : ray (pos, dir)
 * hit : in/out; hit.t is the current closest distance. When the plane
 *       distance t satisfies 0 <= t < hit.t and the barycentrics pass the
 *       inside test, hit.t / hit.u (beta) / hit.v (gamma) / hit.id are
 *       overwritten.
 *****************************************************************************/
static FINLINE void intersect_bvh_ray_tri_wald(
    const rt::wald_tri_t &tri,
    const bvh_ray_t &ray, bvh_hit_t &hit)
{
    /* Projection axis and the two axes of the projection plane */
    const uint_t axis = tri.k;
    const uint_t axis_u = cst::wald_modulo[axis];
    const uint_t axis_v = cst::wald_modulo[axis + 1];
    const float org_u = ray.pos[axis_u], org_v = ray.pos[axis_v];
    const float dir_u = ray.dir[axis_u], dir_v = ray.dir[axis_v];

    /* Distance along the ray to the triangle's plane */
    const float dist = (tri.n_d - ray.pos[axis] - tri.n_u*org_u - tri.n_v*org_v) /
        (ray.dir[axis] + tri.n_u*dir_u + tri.n_v*dir_v);

    /* Both comparisons are false for a NaN distance, so it is rejected */
    if ((hit.t > dist) & (dist >= 0.f)) {
        /* Hit point relative to the anchor vertex, in the projection plane */
        const float rel_u = org_u + dist*dir_u - tri.vert_ku;
        const float rel_v = org_v + dist*dir_v - tri.vert_kv;
        const float beta  = rel_v*tri.b_nu + rel_u*tri.b_nv;
        const float gamma = rel_u*tri.c_nu + rel_v*tri.c_nv;

        /* NOTE(review): a NaN beta/gamma makes every comparison false and is
         * therefore accepted as a hit */
        if (!((beta < 0.f) | (gamma < 0.f) | ((beta + gamma) > 1.f))) {
            hit.t = dist;
            hit.u = beta;
            hit.v = gamma;
            hit.id = tri.id;
        }
    }
}

/*****************************************************************************
 * Cast a ray into a BVH, keeping the closest triangle hit in hit.
 *
 * bvhtree : compiled tree (root node array + accelerated triangles acc)
 * ray     : ray to trace
 * hit     : in/out; hit.t bounds the search and only ever decreases
 *
 * Depth-first traversal with an explicit stack. At an interior node split on
 * axis d, the child at offset_flag + (1 - sign(dir[d])) is descended into
 * first (presumably the one nearer along the ray) and the other is pushed.
 * Each leaf references exactly one triangle (tri_id). A tree without a root
 * node is treated as empty.
 *****************************************************************************/
void trace_bvh(
    const bvh::descriptor_t &bvhtree,
    const bvh_ray_t &ray, bvh_hit_t &hit)
{
    /* Empty tree: nothing to traverse (root was dereferenced unchecked) */
    if (bvhtree.root == NULL)
        return;

    /* 1 when the ray travels towards +axis, 0 otherwise */
    const int32_t signs_all[3] = {
        (ray.dir[0] >= 0.f) & 1,
        (ray.dir[1] >= 0.f) & 1,
        (ray.dir[2] >= 0.f) & 1,
    };
    trace_bvh_stack_t trace_stack;

    trace_stack.reset();
    trace_stack.traces[0].node = bvhtree.root;
    ++trace_stack.idx;

    /* NOTE(review): pushes are not bounds-checked; this relies on the tree
     * depth never exceeding trace_bvh_stack_t::max_lvl (64) — confirm the
     * compiler enforces that limit. */
pop_label:
    while(EXPECT_TAKEN(trace_stack.pop())) {
        const bvh::node_t *__restrict node = trace_stack.node();
        while(EXPECT_TAKEN(!node->is_leaf())) {
            /* Subtree missed, or only reachable beyond hit.t: abandon this
             * descent and resume with the next pending node */
            if(!intersect_ray_box2(node->aabb, ray, hit.t))
                goto pop_label;
            /* The two children are stored at offset_flag and offset_flag + 1 */
            const int32_t deferred = signs_all[node->d];
            const int32_t visited = ~deferred & 1;
            trace_stack.traces[trace_stack.idx++].node =
                &bvhtree.root[node->offset_flag + deferred];
            node = &bvhtree.root[node->offset_flag + visited];
        }
        intersect_bvh_ray_tri_wald(bvhtree.acc[node->tri_id], ray, hit);
    }
}
