/*
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License version 2 
    as published by the Free Software Foundation.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA


    Copyright (C) 2006  Thierry Berger-Perrin <tbptbp@gmail.com>
*/
/*
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License version 2
    as published by the Free Software Foundation.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA


    Copyright (C) 2006  Thierry Berger-Perrin <tbptbp@gmail.com>
*/
#include "specifics.h"
#include "rt_render.h"

#include "math_constants.h"
#include "math_vec.h"
#include "math_aabb.h"
#include "rt_bvh.h"
#include "float.h"
#include <memory>

namespace rt {
    namespace mono {
        // Closest intersection recorded for a ray so far.
        //   t    : ray parameter of the hit; doubles as the current upper bound
        //          when testing further triangles.
        //   u, v : the two barycentric weights written by the triangle test.
        //   id   : id of the triangle that was hit; all bits set means "no hit".
        struct hit_t {
            float    t;
            float    u;
            float    v;
            uint_t    id;

            // No hit yet: unbounded distance, invalid id.
            FINLINE hit_t()
                : t(cst::section.plus_inf), u(0), v(0), id(~0u)
            {}

            // No hit yet, but only intersections closer than tmax will be accepted.
            FINLINE explicit hit_t(const float tmax)
                : t(tmax), u(0), v(0), id(uint_t(-1))
            {}
        };

        // Parametric interval [t_near, t_far] along a ray.
        struct ray_segment_t {
            float t_near;
            float t_far;
        };

        struct ray_t  {
                explicit ray_t(const vec_t &o) : pos(o), dir(vec_t::zero), rcp_dir(vec_t::zero) {}
                const vec_t pos;
                const vec_t dir;
                const vec_t rcp_dir;

            // ok, i'm lying a bit there.
            void set_dir(const vec_t &d) const {
                (vec_t &)dir = d;
                (vec_t &)rcp_dir = vec_t(1.f/d.x, 1.f/d.y, 1.f/d.z);
            }
        };

        // Fixed-capacity stack of kd-tree subtrees that still have to be visited,
        // each with the ray segment to traverse it with.
        // Entries are pushed by writing traces[idx] and incrementing idx directly;
        // nothing in this struct checks idx against max_lvl.
        struct trace_stack_t {
            enum { max_lvl = 32 /*64*/ };

            struct trace_t {
                const kdtree::node_t * __restrict node;
                rt::mono::ray_segment_t    rs;
                //int pad;
                float pad;
            };

            trace_t traces[max_lvl];
            int_t idx;    // index of the next free slot

            trace_stack_t() : idx(0) {}

            // Forget all entries (storage is left as is).
            void reset() {
                idx = 0;
            }

            // Zero the whole storage; idx is not touched.
            void wipe() {
                std::memset(traces, 0, sizeof(traces));
            }

            // Step down to the previous entry; true when there was one.
            // On success node() and rs() refer to that entry.
            bool_t pop() {
                --idx;
                return idx >= 0;
            }

            // Node of the entry selected by the last successful pop().
            const kdtree::node_t * __restrict node() const {
                const trace_t &top = traces[idx];
                return top.node;
            }

            // Ray segment of the entry selected by the last successful pop().
            const ray_segment_t    &rs() const {
                const trace_t &top = traces[idx];
                return top.rs;
            }
        };

        // Fixed-capacity stack of BVH nodes that still have to be visited.
        // Entries are pushed by writing traces[idx] and incrementing idx directly;
        // nothing in this struct checks idx against max_lvl.
        struct trace_bvh_stack_t {
            enum { max_lvl = 64 };

            struct trace_t {
                const bvh::node_t * __restrict node;
            };

            trace_t traces[max_lvl];
            int_t idx;    // index of the next free slot

            trace_bvh_stack_t() : idx(0) {}

            // Forget all entries (storage is left as is).
            void reset() {
                idx = 0;
            }

            // Zero the whole storage; idx is not touched.
            void wipe() {
                std::memset(traces, 0, sizeof(traces));
            }

            // Step down to the previous entry; true when there was one.
            bool_t pop() {
                --idx;
                return idx >= 0;
            }

            // Node of the entry selected by the last successful pop().
            const bvh::node_t * __restrict node() const {
                const trace_t &top = traces[idx];
                return top.node;
            }
        };


        // Slab test of a ray against an axis-aligned box.
        // rs receives the parametric interval over which the ray overlaps the box
        // (it is written even when the test fails). Returns true when that interval
        // is non-empty and does not lie entirely behind the ray origin.
        // Not robust, but it's a bit of a pain to fix without using intrinsics.
        static bool_t intersect_ray_box(const aabb_t &box, const ray_t &ray, ray_segment_t &rs) {
            // x slabs seed the interval.
            const float x0 = (box.pmin.x - ray.pos.x) * ray.rcp_dir.x;
            const float x1 = (box.pmax.x - ray.pos.x) * ray.rcp_dir.x;
            rs.t_near = minf(x0, x1);
            rs.t_far = maxf(x0, x1);

            // y slabs narrow it.
            const float y0 = (box.pmin.y - ray.pos.y) * ray.rcp_dir.y;
            const float y1 = (box.pmax.y - ray.pos.y) * ray.rcp_dir.y;
            rs.t_near = maxf(minf(y0, y1), rs.t_near);
            rs.t_far = minf(maxf(y0, y1), rs.t_far);

            // z slabs narrow it further.
            const float z0 = (box.pmin.z - ray.pos.z) * ray.rcp_dir.z;
            const float z1 = (box.pmax.z - ray.pos.z) * ray.rcp_dir.z;
            rs.t_near = maxf(minf(z0, z1), rs.t_near);
            rs.t_far = minf(maxf(z0, z1), rs.t_far);

            // Both conditions are always evaluated (bitwise and, no branch).
            const bool non_empty = rs.t_far >= rs.t_near;
            const bool not_behind = rs.t_far >= 0.f;
            return (non_empty & not_behind);
        }

        // Slab test of a ray against an axis-aligned box, bounded by a maximum distance.
        // Returns true when the overlap interval is non-empty, does not lie entirely
        // behind the ray origin, and starts before t.
        static bool_t intersect_ray_box2(const aabb_t &box, const ray_t &ray, float t)
        {
            // x slabs seed the interval.
            const float x0 = (box.pmin.x - ray.pos.x) * ray.rcp_dir.x;
            const float x1 = (box.pmax.x - ray.pos.x) * ray.rcp_dir.x;
            float enter = minf(x0, x1);
            float leave = maxf(x0, x1);

            // y slabs narrow it.
            const float y0 = (box.pmin.y - ray.pos.y) * ray.rcp_dir.y;
            const float y1 = (box.pmax.y - ray.pos.y) * ray.rcp_dir.y;
            enter = maxf(minf(y0, y1), enter);
            leave = minf(maxf(y0, y1), leave);

            // z slabs narrow it further.
            const float z0 = (box.pmin.z - ray.pos.z) * ray.rcp_dir.z;
            const float z1 = (box.pmax.z - ray.pos.z) * ray.rcp_dir.z;
            enter = maxf(minf(z0, z1), enter);
            leave = minf(maxf(z0, z1), leave);

            // All three conditions are always evaluated (bitwise and, no branch).
            const bool non_empty = leave >= enter;
            const bool not_behind = leave >= 0.f;
            const bool within_reach = enter < t;
            return (non_empty & not_behind & within_reach);
        }


        // this one has a few 'early' exits, works ok for scalars.
        static void intersect_ray_tri_wald(
            const wald_tri_t &tri,
            const ray_t &ray, hit_t &hit)
        {
            const uint_t k = tri.k, ku = cst::wald_modulo[k], kv = cst::wald_modulo[k+1];
            const float dir_ku = ray.dir[ku], dir_kv = ray.dir[kv];
            const float pos_ku = ray.pos[ku], pos_kv = ray.pos[kv];
            const float t = (tri.n_d - ray.pos[k] - tri.n_u*pos_ku - tri.n_v*pos_kv) /
                (ray.dir[k] + tri.n_u*dir_ku + tri.n_v*dir_kv);
            if (!((hit.t > t) & (t >= 0.f))) return;
            const float hu = pos_ku + t*dir_ku - tri.vert_ku;
            const float hv = pos_kv + t*dir_kv - tri.vert_kv;
            const float beta = hv*tri.b_nu + hu*tri.b_nv;
            const float gamma = hu*tri.c_nu + hv*tri.c_nv;
            if ((beta < 0.f) | (gamma < 0.f) | ((beta + gamma) > 1.f))
                return;
            hit.t = t;
            hit.u = beta;
            hit.v = gamma;
            hit.id = tri.id;
        }

        // Single-ray kd-tree traversal.
        // Clips the ray against rt.scene.bounding_box, then walks rt.scene.kdtree from
        // its root, running intersect_ray_tri_wald on every triangle of each leaf that is
        // reached. hit is updated in place whenever a closer intersection is found.
        // trace_stack is scratch space for deferred subtrees; it is reset here, but only
        // when the ray actually overlaps the bounding box (otherwise nothing is touched).
        // NOTE(review): pushes write trace_stack.traces[idx] without checking idx against
        // trace_stack_t::max_lvl.
        // copy & paste, picked up in a myriad of alternatives, seems to work at least for gcc & icc.
        static FINLINE void trace(const rt::raytracer_t &rt, const ray_t &ray, hit_t &hit, trace_stack_t &trace_stack) {
            ray_segment_t rs;
            if (intersect_ray_box(rt.scene.bounding_box, ray, rs)) {
                trace_stack.reset();
                // per-axis flag: 1 when the direction component is >= 0 (fourth slot is filler).
                const bool_t signs_all[4] = { ray.dir[0] >= 0.f, ray.dir[1] >= 0.f, ray.dir[2] >= 0.f, 0 };

                const kdtree::node_t * __restrict node = rt.scene.kdtree.root;

                while(true) {
                    {
                        // work on local copies of the current segment while descending.
                        float // hello icc, that's a hint.
                            t_near    = rs.t_near,
                            t_far    = rs.t_far;

                        // never start the segment behind the ray origin.
                        t_near = maxf(t_near, 0.f);

                        // descend until a leaf is reached.
                        while (EXPECT_TAKEN(!node->is_leaf())) {
                            // low two bits give the split axis; the bits kept by mask_children
                            // are used below as a byte offset from this node to its child pair.
                            const uint_t
                                bits = node->inner.dim_offset_flag,
                                axis = bits & 3u,
                                off     = bits & kdtree::node_t::mask_children;
                                //axis = node->get_axis();
                            // d: ray parameter at which the split plane is crossed.
                            const float
                                split = node->inner.split_coord,
                                d = (split - ray.pos[axis]) * ray.rcp_dir[axis];

                            const bool_t sign = signs_all[axis];

                            // the two children sit next to each other at base; which one is
                            // called left/right depends on the direction sign. Despite the names,
                            // 'right' is the child traversed first and 'left' the one that may be
                            // deferred.
                            const kdtree::node_t
                                * __restrict const base = (const kdtree::node_t * __restrict) ((const char * __restrict)node + off),
                                * __restrict const left = base + (sign^0),
                                * __restrict const right = base + (sign^1);
                                //* __restrict const left  = node->get_back() + (sign^0);
                                //* __restrict const right = node->get_back() + (sign^1);
                            //node = node->get_back() + (sign^1);
                            node = right;

                            const bool_t
                                frontside    = d > t_far,    // case two, t_near <= t_far <= d -> cull back side
                                backside    = d < t_near;    // case one, d <= t_near <= t_far -> cull front side

                            if (frontside)
                                ;    // plane lies beyond the segment: stay with 'right' only.
                            else if (backside)
                                node = left;    // plane lies before the segment: 'left' only.
                            else {
                                // segment straddles the plane: defer 'left' with [d, t_far]
                                // and continue into 'right' with the segment cut at d.
                                const int_t idx = trace_stack.idx;

                                trace_stack.traces[idx].node        = left;
                                trace_stack.traces[idx].rs.t_near    = d;
                                trace_stack.traces[idx].rs.t_far    = t_far;
                                t_far = d;
                                ++trace_stack.idx;
                            }
                        }
                        // publish the segment that reached the leaf; t_far is needed below.
                        rs.t_near = t_near;
                        rs.t_far = t_far;
                    }


                    // leaf: 'count' entries of rt.scene.kdtree.ids starting at 'idx', each an
                    // index into rt.scene.kdtree.acc.
                    const uint_t
                        idx    = node->get_list(),
                        count    = node->leaf.count;

                    if (count) { // <-- not really needed.
                        uint_t i=0;
                        for (uint_t dec=count; dec; --dec) {
                            intersect_ray_tri_wald(rt.scene.kdtree.acc[rt.scene.kdtree.ids[idx+i]], ray, hit);
                            ++i;
                        }
                    }

                    // continue with a deferred subtree only if there is one and the best hit
                    // so far lies beyond the segment just traversed. Bitwise '&': pop() always runs.
                    if (trace_stack.pop() & (hit.t > rs.t_far)) {
                        node = trace_stack.node();
                        rs = trace_stack.rs();
                    }
                    else
                        break;    // early termination
                }
            }
        }

        // Single-ray BVH traversal.
        // Starts at rt.scene.bvhtree.root and keeps deferred nodes on trace_stack (reset
        // on entry). Inner nodes are culled with intersect_ray_box2 against the current
        // hit.t; a leaf contributes one triangle, rt.scene.bvhtree.acc[node->tri_id].
        // hit is updated in place whenever a closer intersection is found.
        // NOTE(review): pushes do not check trace_stack.idx against trace_bvh_stack_t::max_lvl.
        static FINLINE void trace_bvh(
            const rt::raytracer_t &rt,
            const ray_t &ray, hit_t &hit,
            trace_bvh_stack_t &trace_stack)
        {
            // per-axis flag: 1 when the direction component is >= 0.
            const int32_t signs_all[3] = {
                (ray.dir[0] >= 0.f) ? 1 : 0,
                (ray.dir[1] >= 0.f) ? 1 : 0,
                (ray.dir[2] >= 0.f) ? 1 : 0,
            };

            // seed the stack with the root.
            trace_stack.reset();
            trace_stack.traces[trace_stack.idx].node = rt.scene.bvhtree.root;
            ++trace_stack.idx;

            while (EXPECT_TAKEN(trace_stack.pop())) {
                const bvh::node_t *__restrict node = trace_stack.node();

                // walk down to a leaf, deferring one child at every inner node.
                bool culled = false;
                while (EXPECT_TAKEN(!node->is_leaf())) {
                    if (!intersect_ray_box2(node->aabb, ray, hit.t)) {
                        culled = true;    // box missed or too far: drop this subtree.
                        break;
                    }

                    // the children are root[offset_flag] and root[offset_flag + 1];
                    // the sign along node->d picks which one is deferred.
                    const int32_t deferred = signs_all[node->d];
                    const int32_t visited = deferred ^ 1;

                    trace_stack.traces[trace_stack.idx].node = &rt.scene.bvhtree.root[node->offset_flag + deferred];
                    ++trace_stack.idx;

                    node = &rt.scene.bvhtree.root[node->offset_flag + visited];
                }
                if (culled)
                    continue;

                intersect_ray_tri_wald(rt.scene.bvhtree.acc[node->tri_id], ray, hit);
            }
        }
    }
}

namespace rt {
    namespace mono {
        // Pixel plotted where the ray hits nothing.
        static const rt::pixel_t pixel_zero(0,0,0,0);

        // Renders one tile of tile_size_x by tile_size_y pixels whose first pixel is
        // 'corner', casting one ray per pixel from the camera eye through the kd-tree.
        // A hit is shaded from its distance only (scale_depth / hit.t, with scale_depth
        // a quarter of extent.get_max() of the scene bounding box); a miss gets pixel_zero.
        static FINLINE void do_render_tile(const renderer_t &renderer, const point_t &corner) {
            //BREAKPOINT();
            trace_stack_t trace_stack;
            trace_bvh_stack_t trace_bstack;

            // direction through the tile's first pixel.
            const float cx = float(corner.x), cy = float(corner.y);
            vec_t row_dir(renderer.rt.camera.sampler.top + renderer.rt.camera.sampler.dx*cx + renderer.rt.camera.sampler.dy*cy);

            // one ray object for the whole tile, only its direction changes.
            const ray_t ray(renderer.rt.camera.get_eye());

            const point_t &res(renderer.framebuffer.get_resolution());
            uint_t row_pixel = corner.y*res.x + corner.x;

            const vec_t extent(renderer.rt.scene.bounding_box.get_extent());
            const float scale_depth = extent.get_max() * 0.25f;

            int rows_left = tile_size_y;
            while (rows_left) {
                vec_t dir(row_dir);
                uint_t pixel = row_pixel;

                int cols_left = tile_size_x;
                while (cols_left) {
                    rt::mono::hit_t hit;
                    ray.set_dir(dir);

                    trace(renderer.rt, ray, hit, trace_stack);
                    //trace_bvh(renderer.rt, ray, hit, trace_bstack);

                    const float d = scale_depth / hit.t;
                    if (hit.id != ~0u)
                        renderer.framebuffer.plot(pixel, pixel_t(d*0.25f, d*0.5f, d, 0));
                    else
                        renderer.framebuffer.plot(pixel, pixel_zero);

                    // next pixel on this row.
                    dir = dir + renderer.rt.camera.sampler.dx;
                    ++pixel;
                    --cols_left;
                }

                // next row of the tile.
                row_dir = row_dir + renderer.rt.camera.sampler.dy;
                row_pixel += res.x;
                --rows_left;
            }
        }
    }
}

#include "horde.h"

namespace horde {

    // Describes the job and whatever parameters can be factored out of it:
    // job_type selects the kind of work, workload carries the matching arguments.
    struct job_t {
        enum {
            die = 0,
            render_mono = 1,
            render_packet = 2
        } job_type;

        union {
            // for 'die': value returned by grunt_t::run().
            struct job_die_t {
                int rc;
            } job_die;

            // for 'render_mono' / 'render_packet': the renderer to draw with.
            struct job_render_t {
                const rt::renderer_t    *renderer;
            } job_render;
        } workload;
    };

}
/*
namespace rt {
    namespace packet {
        static FINLINE void do_render_tile(const renderer_t &renderer, const point_t &corner);
    }
}
*/

namespace horde {
    // Globals of the dispatching side.
    // NOTE(review): presumably these are the definitions for declarations in horde.h,
    // and how the semaphores are used is not visible here - confirm against that header.
    namespace captain {
        int        num_grunts;
        #ifdef WINDOWS
            // Win32 handles; there is one dock handle per grunt slot.
            HANDLE deployphore;
            HANDLE dockphore[max_grunts];
        #else
            // POSIX semaphores; a single dock semaphore is shared.
            sem_t deployphore, dockphore;
        #endif
    }
    // One slot per possible grunt.
    grunt_t grunts[max_grunts];

    // The one job in flight: filled in by rt::render_horde, read by grunt_t::run.
    static job_t job;

    // Mono rendering work loop: keeps popping tile corners from horde::lifo and
    // renders each one, until the lifo has nothing left to hand out.
    //note: we'll inline it all at this point.
    static NOINLINE void grunt_render_tiles_mono(const rt::renderer_t &renderer) {
        point_t corner;
        for (;;) {
            // food fighting
            if (!lifo.pop(corner))
                break;
            rt::mono::do_render_tile(renderer, corner);
        }
    }

    extern NOINLINE void grunt_render_tiles_packet(const rt::renderer_t &renderer);


    int grunt_t::run() {
        /*
            We'll wake-up with a job of some type possibly chunked.
        */
        bool_t i_will_survive = true;
        while (i_will_survive) {
            wait();

            switch (job.job_type) {
                case job_t::die:            i_will_survive = false; break;
                case job_t::render_mono:    grunt_render_tiles_mono(*job.workload.job_render.renderer); break;
                case job_t::render_packet:    grunt_render_tiles_packet(*job.workload.job_render.renderer); break;
            }

            signal();
        }

        return job.workload.job_die.rc;
    }
}

namespace rt {
    namespace mono {
    }

    /*
        Renders a frame with the horde: the framebuffer is cut into tiles of
        tile_size_x by tile_size_y, the tile corners are stacked on horde::lifo,
        the workers are woken up and we wait until they are done.

        mode_mono selects the render_mono job, otherwise render_packet is used.
        Only whole tiles are queued: the tile counts come from an integer division
        of the resolution by the tile size.
    */
    //void render_horde(const raytracer_t &rt, const framebuffer_t &framebuffer) {
    void render_horde(const rt::renderer_t &renderer, const bool_t mode_mono) {
        // one job at any given time.
        if (mode_mono)
            horde::job.job_type = horde::job_t::render_mono;
        else
            horde::job.job_type = horde::job_t::render_packet;
        horde::job.workload.job_render.renderer = &renderer;

        // bake tiles
        horde::lifo.reset();

        const point_t &res(renderer.framebuffer.get_resolution());
        const int tiles_across = res.x / tile_size_x;
        const int tiles_down = res.y / tile_size_y;

        if ((tiles_across*tiles_down) > horde::lifo64_t<point_t>::max_capacity) {
            fatal("rt::mono::render_horde: lifo capacity exceeded, too many tiles.");
        }

        // we pile them from top-left to bottom-right, they will be popped
        // in reverse order and... that's faster.
        point_t corner(0,0);
        #ifdef __ICC__
            // auto-vectorization bug. yeepee.
            #pragma novector
        #endif
        for (int rows_left = tiles_down; rows_left; --rows_left) {
            corner.x = 0;
            for (int cols_left = tiles_across; cols_left; --cols_left) {
                horde::lifo.push(corner);
                corner.x += tile_size_x;
            }
            corner.y += tile_size_y;
        }

        // render those tiles.
        horde::captain::signal();
        horde::captain::wait();
    }
}
