clc; clear variables;
obj = read_wobj('tilted_plane.obj');
co = obj.vertices;                                              % node coordinates, one row per node
el = obj.objects( [obj.objects(:).type]=='f' ).data.vertices;   % triangle connectivity of the face object
% trimesh(el,co(:,1),co(:,2),co(:,3),'FaceAlpha',0,'EdgeAlpha',0.2,'EdgeColor','k');

% opt.specular=0.6; % ratio of specular reflection (diffuse will be 1-s)
% opt.specularDrawScale=0.06; % size of specular balls in plots
% [opt.usph_el, opt.usph_co] = readIcoSphereFromMesh();

opt.debugPlot = 0;
opt.n_th = 15; % number of polar-angle ray samples inside the light cone
opt.n_rh = 36; % number of azimuthal ray samples around the light axis

% target: nodal power produced by a displaced light
opt.lightDisplacement = [-0.5 0 0.5];
pw_target = mytestAdjoint(el,co, opt, []);
title('target'); drawnow;
sum(pw_target) % total power deposited on the mesh nodes (displayed)

% [x,z] = meshgrid(-0.6:0.1:0.6); y=0;
% sampleObjFuncOnGrid(x,y,z, @(p) sampleEvalFunc(el,co, opt,pw_target, p));


%%
close all;
pw_target = 0*pw_target + 0.5;
% A constant 0.5 target produces wrong gradients (discontinuous brightness changes as rays
% cross element edges) around iteration 150 (gradient-descent step size 1.0).
% Small FD steps agree with the gradient computation - they can't "see" these discontinuous jumps locally!

%%
p = [0 0 0];      % light displacement being optimized
phi_iters = [];   % objective history (re-running the next cell keeps appending)
%%
n_iter    = 150;  % number of gradient-descent iterations
step_size = 0.2;  % gradient-descent step size
fd_h      = 1e-7; % forward finite-difference step for the gradient check
for j=1:n_iter
    opt.lightDisplacement = p;

    [pw, dpwdx,dpwdy,dpwdz] = mytestDirect(el,co, opt);
    % remember the axes drawn by mytestDirect: the checks below open further
    % figures (via mytestAdjoint), so gca/gcf would no longer point here
    ax = gca;
    [phi, dphidpw] = mytestObjective(pw, pw_target);
    dphidp = (dphidpw(:)') * [dpwdx(:), dpwdy(:), dpwdz(:)]; % chain rule: dphi/dp = dphi/dpw * dpw/dp

    if  j==1
        [pw_adj, phi_adj, dphidp_adj] = mytestAdjoint(el,co,opt, pw_target);

        % compare direct and adjoint calculations
        disp('(pw, phi, dphidp) direct-adjoint abs err:');
        disp([max(abs(pw(:)-pw_adj(:)))  max(abs(phi(:)-phi_adj(:)))  max(abs(dphidp(:)-dphidp_adj(:)))]);
    end

    % run FD check
    dphidp_fd = [0 0 0];
    if  j==1 % tested for j == 1 and 3
        disp('FD-check ...');
        for l=1:length(dphidp_fd)
            opt.lightDisplacement = p;
            opt.lightDisplacement(l) = opt.lightDisplacement(l) + fd_h;
            pw_fd  = mytestAdjoint(el,co, opt, []);
            phi_fd = mytestObjective(pw_fd, pw_target);
            dphidp_fd(l) = (phi_fd - phi) / fd_h;
        end
        opt.lightDisplacement = p; % undo the last FD perturbation
        disp('abs err (grad-fd):');
        disp(abs(dphidp - dphidp_fd));
        disp('grad, fd:');
        disp(dphidp);
        disp(dphidp_fd);
        disp('... done');
    end

    phi_iters = [phi_iters; phi]; %#ok

    title(ax, ['step ' num2str(j,'%3u') ', p = [' num2str(p,'%.3f, ') ']']);  drawnow;
%     saveas(ancestor(ax,'figure'),['_img/grad_desc_' num2str(j,'%03u') '.png']);

    p = p - step_size*dphidp; % basic gradient descent step
end

figure; plot(phi_iters);
%
%
%% functions ...

function [phi, dphidpw] = mytestObjective(pw, pw_target)
% MYTESTOBJECTIVE Scaled least-squares mismatch between pw and pw_target.
%   phi     = 0.5 * sum( (4*(pw - pw_target)).^2 )      (scalar)
%   dphidpw = d phi / d pw = 16*(pw - pw_target)        (same shape as pw)
% phi and dphidpw must stay consistent: the analytic gradient built from
% dphidpw is checked against finite differences of phi.
    phi = 0.5*sum(((pw(:)-pw_target(:)).*4).^2);
    dphidpw = (pw-pw_target).*16;
end


function [pw, phi, dphidp] = mytestAdjoint(el,co, opt, pw_target)
% MYTESTADJOINT Trace a cone of light rays onto a triangle mesh and accumulate
% the received power per node. When a target is given, also evaluate the
% objective and its gradient w.r.t. the light position in a second
% (adjoint) pass over the same rays.
%
% Inputs
%   el        - triangle connectivity; row k holds three node indices into co
%   co        - node coordinates, one row [x y z] per node
%   opt       - options: n_th, n_rh (ray sampling), debugPlot, optional
%               lightDisplacement (1x3, added to the base light position),
%               optional specular (uses projectSpecularReflection and, for
%               plotting, usph_el, usph_co, specularDrawScale)
%   pw_target - target nodal power for mytestObjective; [] = forward pass only
%
% Outputs
%   pw     - power per node; one column, or one column per value returned by
%            projectSpecularReflection when opt.specular is set
%   phi    - objective value (-1 if pw_target is empty)
%   dphidp - 1x3 gradient of phi w.r.t. light position x,y,z (0 if pw_target is empty)
%
% Always plots pw on the mesh (in a new figure unless opt.debugPlot is set).
% NOTE(review): each ray deposits power on the first element (in el order)
% that it hits, not necessarily the closest one - fine for a single surface,
% but occlusion is not handled for general meshes.
    isSpecular = isfield(opt, 'specular');
    dataPerNode = 1;
    if  isSpecular
        tmp = projectSpecularReflection([],[],-1,[],[],[],[],opt);
        dataPerNode = length(tmp);
    end
    pw = zeros(size(co,1),dataPerNode);

    if  opt.debugPlot
        figure; trimesh(el,co(:,1),co(:,2),co(:,3),'FaceAlpha',0,'EdgeAlpha',0.2,'EdgeColor','k'); axis equal; hold on;
    end

    lp = [0.1,3,-0.2]; % base light position
    if  isfield(opt,'lightDisplacement')
        lp = lp + opt.lightDisplacement;
    end
    [ldx,ldy,ldz] = sph2cart(-pi/2,0,1); ld = [ldx,ldy,ldz]; clear ldx ldy ldz; % light axis, approx. [0 -1 0]
    lth= 15; % cone half-angle [deg]

    if  opt.debugPlot
        quiver3(lp(:,1),lp(:,2),lp(:,3), ld(:,1),ld(:,2),ld(:,3));
    end

    h1=simplePerpendicularVector(ld);
    h1=h1./norm(h1); % unit axis used to tilt rays away from ld

    th_vals = linspace(0,lth,opt.n_th+1); th_vals = th_vals(2:end);   % polar angles (0 excluded)
    rh_vals = linspace(0,360,opt.n_rh+1); rh_vals = rh_vals(2:end);   % azimuths around ld
    avg_th = sum(th_vals) / opt.n_th;

    % ---- forward pass: deposit ray power on the nodes ----
    for th = th_vals
        R1 = axang2rotm([h1 th/180*pi]);              % tilt by th away from the axis
        % per-ray power, proportional to th; summed over all rays it equals 1
        ph = 1/(opt.n_th*opt.n_rh) * (th / avg_th);
        for rh = rh_vals
            R2 = axang2rotm([ld rh/180*pi]);          % rotate by rh around the axis
            rd = (R2*R1*ld')'; % ray direction

            if  opt.debugPlot
                quiver3(lp(:,1),lp(:,2),lp(:,3), rd(:,1),rd(:,2),rd(:,3), 'g');
            end

            for k = 1:size(el,1)
                [isHit, u, v, d] = rayTriangleIntersection(lp, rd, co(el(k,1),:), co(el(k,2),:), co(el(k,3),:));
                if  isHit && d>0
                    if  opt.debugPlot
                        hx = co(el(k,1),:) + u*(co(el(k,2),:)-co(el(k,1),:)) + v*(co(el(k,3),:)-co(el(k,1),:));
                        plot3(hx(1),hx(2),hx(3),'bo');
                    end

                    if  isSpecular
                        ps = projectSpecularReflection(el,co, k,u,v,rd,ph, opt);
                    else
                        ps = ph;
                    end
                    % distribute to the element nodes with barycentric weights
                    w = [1-u-v, u, v];
                    for a = 1:3
                        pw(el(k,a),:) = pw(el(k,a),:) + w(a)*ps;
                    end
                    break; % only the first hit element receives this ray's power
                end
            end
        end
    end

    % ---- forward step done, now adjoint step ----
    if  isempty(pw_target)
        phi=-1;dphidp=0; % no target - stop after forward step
    else
        % evaluate objective
        [phi, dphidpw] = mytestObjective(pw,pw_target);

        % adjoint derivative: dphi/dp = sum over hits of dphi/dpw(node) * d(weight)/dp * ps
        % NOTE(review): ps is treated as independent of p; if projectSpecularReflection
        % depends on u,v, that dependence is not differentiated here.
        dphidp = [0 0 0]; % 3 parameters (light position x,y,z)
        for th = th_vals
            R1 = axang2rotm([h1 th/180*pi]);
            ph = 1/(opt.n_th*opt.n_rh) * (th / avg_th);
            for rh = rh_vals
                R2 = axang2rotm([ld rh/180*pi]);
                rd = (R2*R1*ld')'; % ray direction

                for k = 1:size(el,1)
                    [isHit, u, v, d, dudp,dvdp] = rayTriangleIntersectionWithDerivative(lp, rd, co(el(k,1),:), co(el(k,2),:), co(el(k,3),:));
                    if  isHit && d>0
                        if  isSpecular
                            ps = projectSpecularReflection(el,co, k,u,v,rd,ph, opt);
                        else
                            ps = ph;
                        end
                        % d[1-u-v, u, v]/dp: rows = element nodes, columns = light x,y,z
                        dw = [-dudp(1)-dvdp(1), -dudp(2)-dvdp(2), -dudp(3)-dvdp(3)
                               dudp(1),          dudp(2),          dudp(3)
                               dvdp(1),          dvdp(2),          dvdp(3)];
                        for a = 1:3
                            for c = 1:3
                                dphidp(c) = dphidp(c) + dphidpw(el(k,a),:)*dw(a,c)*ps';
                            end
                        end
                        % mirror the forward pass: only the first hit element
                        % received power, so only it contributes to the gradient
                        break;
                    end
                end
            end
        end
    end

    if ~opt.debugPlot,  figure; end
    trisurf(el,co(:,1),co(:,2),co(:,3),pw(:,1)); axis equal; shading interp;
    colormap('hot');
    cl = [min(pw(:)) max(pw(:))];
    if  cl(2) > cl(1) % caxis needs increasing limits (pw is constant e.g. when no ray hits)
        caxis(cl);
    end

    % optional specular visualisation: one small sphere per node and per element centroid
    if  isSpecular
        hold on;
        for i = 1:size(co,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+co(i,1), ...
                opt.usph_co(:,2)*opt.specularDrawScale+co(i,2), ...
                opt.usph_co(:,3)*opt.specularDrawScale+co(i,3), ...
                pw(i,1)+pw(i,2:end),'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        for k = 1:size(el,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+(co(el(k,1),1)+co(el(k,2),1)+co(el(k,3),1))./3, ...
                opt.usph_co(:,2)*opt.specularDrawScale+(co(el(k,1),2)+co(el(k,2),2)+co(el(k,3),2))./3, ...
                opt.usph_co(:,3)*opt.specularDrawScale+(co(el(k,1),3)+co(el(k,2),3)+co(el(k,3),3))./3, ...
                (pw(el(k,1),1)+pw(el(k,2),1)+pw(el(k,3),1))./3+(pw(el(k,1),2:end)+pw(el(k,2),2:end)+pw(el(k,3),2:end))./3,...
                'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        cmax = max(max(pw(:,1)+pw(:,2:end)));
        if  cmax > 0 % caxis([0 0]) would throw
            caxis([0 cmax]);
        end
        hold off;
    end

    view(140,20); drawnow;
%     saveas(gcf,['_img/lightpos_' num2str(opt.lightDisplacement(1)) '.png']);
end



function [pw, dpwdx,dpwdy,dpwdz] = mytestDirect(el,co, opt)
% MYTESTDIRECT Trace a cone of light rays onto a triangle mesh and accumulate
% the received power per node. In the same pass, accumulate the derivatives
% of that power w.r.t. the light position (direct differentiation).
%
% Inputs
%   el  - triangle connectivity; row k holds three node indices into co
%   co  - node coordinates, one row [x y z] per node
%   opt - options: n_th, n_rh (ray sampling), debugPlot, optional
%         lightDisplacement (1x3, added to the base light position),
%         optional specular (uses projectSpecularReflection and, for
%         plotting, usph_el, usph_co, specularDrawScale)
%
% Outputs
%   pw                - power per node; one column, or one column per value
%                       returned by projectSpecularReflection when opt.specular is set
%   dpwdx,dpwdy,dpwdz - derivatives of pw w.r.t. light position x, y, z (same size as pw)
%
% Always plots pw on the mesh (in a new figure unless opt.debugPlot is set).
% NOTE(review): each ray deposits power on the first element (in el order)
% that it hits, not necessarily the closest one.
    isSpecular = isfield(opt, 'specular');
    dataPerNode = 1;
    if  isSpecular
        tmp = projectSpecularReflection([],[],-1,[],[],[],[],opt);
        dataPerNode = length(tmp);
    end
    pw = zeros(size(co,1),dataPerNode);
    dpwdx = pw; dpwdy = pw; dpwdz = pw; % derivatives of radiative power wrt. light position

    if  opt.debugPlot
        figure; trimesh(el,co(:,1),co(:,2),co(:,3),'FaceAlpha',0,'EdgeAlpha',0.2,'EdgeColor','k'); axis equal; hold on;
    end

    lp = [0.1,3,-0.2]; % base light position
    if  isfield(opt,'lightDisplacement')
        lp = lp + opt.lightDisplacement;
    end
    [ldx,ldy,ldz] = sph2cart(-pi/2,0,1); ld = [ldx,ldy,ldz]; clear ldx ldy ldz; % light axis, approx. [0 -1 0]
    lth= 15; % cone half-angle [deg]

    if  opt.debugPlot
        quiver3(lp(:,1),lp(:,2),lp(:,3), ld(:,1),ld(:,2),ld(:,3));
    end

    h1=simplePerpendicularVector(ld);
    h1=h1./norm(h1); % unit axis used to tilt rays away from ld

    th_vals = linspace(0,lth,opt.n_th+1); th_vals = th_vals(2:end);   % polar angles (0 excluded)
    rh_vals = linspace(0,360,opt.n_rh+1); rh_vals = rh_vals(2:end);   % azimuths around ld
    avg_th = sum(th_vals) / opt.n_th;

    for th = th_vals
        R1 = axang2rotm([h1 th/180*pi]);              % tilt by th away from the axis
        % share of the total radiative power carried by each ray of this ring
        % (proportional to th; summed over all rays it equals 1)
        ph = 1/(opt.n_th*opt.n_rh) * (th / avg_th);
        for rh = rh_vals
            R2 = axang2rotm([ld rh/180*pi]);          % rotate by rh around the axis
            rd = (R2*R1*ld')'; % ray direction

            if  opt.debugPlot
                quiver3(lp(:,1),lp(:,2),lp(:,3), rd(:,1),rd(:,2),rd(:,3), 'g');
            end

            for k = 1:size(el,1)
                [isHit, u, v, d, dudp,dvdp] = rayTriangleIntersectionWithDerivative(lp, rd, co(el(k,1),:), co(el(k,2),:), co(el(k,3),:));
                if  isHit && d>0
                    if  opt.debugPlot
                        hx = co(el(k,1),:) + u*(co(el(k,2),:)-co(el(k,1),:)) + v*(co(el(k,3),:)-co(el(k,1),:));
                        plot3(hx(1),hx(2),hx(3),'bo');
                    end

                    if  isSpecular
                        ps = projectSpecularReflection(el,co, k,u,v,rd,ph, opt);
                    else
                        ps = ph;
                    end
                    % barycentric weights of the three element nodes ...
                    w  = [1-u-v, u, v];
                    % ... and their derivatives: rows = nodes, columns = light x,y,z
                    dw = [-dudp(1)-dvdp(1), -dudp(2)-dvdp(2), -dudp(3)-dvdp(3)
                           dudp(1),          dudp(2),          dudp(3)
                           dvdp(1),          dvdp(2),          dvdp(3)];
                    for a = 1:3
                        n = el(k,a);
                        pw(n,:)    = pw(n,:)    + w(a)   *ps;
                        dpwdx(n,:) = dpwdx(n,:) + dw(a,1)*ps;
                        dpwdy(n,:) = dpwdy(n,:) + dw(a,2)*ps;
                        dpwdz(n,:) = dpwdz(n,:) + dw(a,3)*ps;
                    end

                    break; % end loop once we've found an intersection
                end
            end
        end
    end

    if ~opt.debugPlot,  figure; end
    trisurf(el,co(:,1),co(:,2),co(:,3),pw(:,1)); axis equal; shading interp;
    colormap('hot');
    cl = [min(pw(:)) max(pw(:))];
    if  cl(2) > cl(1) % caxis needs increasing limits (pw is constant e.g. when no ray hits)
        caxis(cl);
    end

    % optional specular visualisation: one small sphere per node and per element centroid
    if  isSpecular
        hold on;
        for i = 1:size(co,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+co(i,1), ...
                opt.usph_co(:,2)*opt.specularDrawScale+co(i,2), ...
                opt.usph_co(:,3)*opt.specularDrawScale+co(i,3), ...
                pw(i,1)+pw(i,2:end),'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        for k = 1:size(el,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+(co(el(k,1),1)+co(el(k,2),1)+co(el(k,3),1))./3, ...
                opt.usph_co(:,2)*opt.specularDrawScale+(co(el(k,1),2)+co(el(k,2),2)+co(el(k,3),2))./3, ...
                opt.usph_co(:,3)*opt.specularDrawScale+(co(el(k,1),3)+co(el(k,2),3)+co(el(k,3),3))./3, ...
                (pw(el(k,1),1)+pw(el(k,2),1)+pw(el(k,3),1))./3+(pw(el(k,1),2:end)+pw(el(k,2),2:end)+pw(el(k,3),2:end))./3,...
                'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        cmax = max(max(pw(:,1)+pw(:,2:end)));
        if  cmax > 0 % caxis([0 0]) would throw
            caxis([0 cmax]);
        end
        hold off;
    end

    view(140,20); drawnow;
%     saveas(gcf,['_img/lightpos_' num2str(opt.lightDisplacement(1)) '.png']);
end

function [phi,dphi] = sampleEvalFunc(el,co, opt,pw_target, p)
% SAMPLEEVALFUNC Objective value and adjoint gradient with the light displaced by p.
%   Copies opt with lightDisplacement set to p, runs mytestAdjoint and
%   returns only its objective (phi) and gradient (dphi); the nodal power is discarded.
    shiftedOpt = opt;
    shiftedOpt.lightDisplacement = p;
    [~, phi, dphi] = mytestAdjoint(el, co, shiftedOpt, pw_target);
end


