% Driver script: computes a target power distribution on a triangle mesh for a
% displaced spot light (via mytestDirect), then runs a plain gradient descent on
% the light displacement p, starting from p = [0 0 0], to reduce the objective
% returned by mytestObjective. A finite-difference gradient check is run in the
% first iteration only.
clc; clear variables;
% presumably read_wobj loads a Wavefront OBJ file into a struct -- helper not visible here, TODO confirm
obj = read_wobj('tilted_plane.obj');
co = obj.vertices; % node coordinates (indexed as co(:,1..3) below)
% connectivity: vertex indices of the object entry whose type is 'f'
el = obj.objects( [obj.objects(:).type]=='f' ).data.vertices;
% [co,el] = refineRGB(co,el,findBoundary(el),1:size(el,1)); % refine once to have same number of DOFs as P2 mesh
% trimesh(el,co(:,1),co(:,2),co(:,3),'FaceAlpha',0,'EdgeAlpha',0.2,'EdgeColor','k');

% DO NOT USE: specular version not implemented here!!!
% (mytestDirect does not accumulate derivatives in its specular branch)
% opt.specular=0.6; % ratio of specular reflection (diffuse will be 1-s)
% opt.specularDrawScale=0.06; % size of specular balls in plots
% [opt.usph_el, opt.usph_co] = readIcoSphereFromMesh();

opt.debugPlot = 0; % nonzero: mytestDirect/mytestAdjoint draw mesh, light direction and rays
opt.smoothing = 0;%.005; % add some Laplacian smoothing to the result
opt.n_th = 15; % number of polar ray samples (only read by mytestAdjoint)
opt.n_rh = 36; % number of azimuthal ray samples (only read by mytestAdjoint)
opt.lth = 15; % spot light cone angle (deg)

% target distribution: light displaced by [-0.5 0 0.5] from the base position hard-coded in mytestDirect
opt.lightDisplacement = [-0.5 0 0.5];
pw_target = mytestDirect(el,co, opt);
title('target'); drawnow; % labels the figure that mytestDirect has just created

% NOTE: no trailing semicolon, so phi and dphidpw are echoed to the command window.
% Both variables are overwritten inside the optimization loop below.
% presumably a sanity check (distance of the target to itself) -- convWstnDistance is not visible here
[phi, dphidpw] = convWstnDistance(el,co, pw_target ,pw_target)

% % no longer the case .. falloff term around spot light cone edge reduces total output power -- integrate pw over mesh == should equal 1
% pw = pw_target;
% int_pw = 0;
% for k=1:size(el,1)
%     % linear fcn within el --> avg. of corners * area
%     avg = ( pw(el(k,1),:) + pw(el(k,2),:) +pw(el(k,3),:) ) / 3;
%     ar = 0.5*( norm(cross( co(el(k,2),:)-co(el(k,1),:) , co(el(k,3),:)-co(el(k,1),:) )));
%     int_pw = int_pw + avg*ar;
% end
% sum(int_pw) % ToDo: also use L2-projection (or spherical harmonic projection) in specular case - and compare by integrating specular power over unit sphere

%%
% close all;
% pw_target = 0*pw_target + 0.5;

%%
p = [0 0 0];     % optimization variable: light displacement (x,y,z)
phi_iters = [];  % objective value of every iteration (plotted at the end)
%%
for j=1:150
    opt.lightDisplacement = p;
    
    % forward evaluation with derivatives wrt. the light position, then objective
    [pw, dpwdx,dpwdy,dpwdz] = mytestDirect(el,co, opt);
    [phi, dphidpw] = mytestObjective(el,co, pw, pw_target);
    % chain rule: dphi/dp = (dphi/dpw)' * [dpw/dx, dpw/dy, dpw/dz]  --> 1x3 row
    dphidp = (dphidpw(:)') * [dpwdx(:), dpwdy(:), dpwdz(:)];
    
    % j runs over 1:150, so this condition is never true: the adjoint comparison is disabled
    if  j==-1 % adjoint not updated yet
        [pw_adj, phi_adj, dphidp_adj] = mytestAdjoint(el,co,opt, pw_target);

        % compare direct and adjoint calculations -- works
        disp('(pw, phi, dphidp) direct-adjoint abs err:');
        disp([max(abs(pw(:)-pw_adj(:)))  max(abs(phi(:)-phi_adj(:)))  max(abs(dphidp(:)-dphidp_adj(:)))]);
        disp('grad_adj:');
        disp(dphidp_adj);
    end
    
    % run FD check
    fd_h=1e-7;           % forward-difference step size
    dphidp_fd = [0 0 0]; % finite-difference gradient, one entry per component of p
%     dpwdx_fd = 0*pw; dpwdy_fd = 0*pw; dpwdz_fd = 0*pw;
    if  j==1 % works
        disp('FD-check ...');
        for l=1:length(dphidp_fd)
            % perturb only the l-th component of the current p
            opt.lightDisplacement = p;
            opt.lightDisplacement(l) = opt.lightDisplacement(l) + fd_h;
            pw_fd  = mytestDirect(el,co, opt); % note: every call opens another figure
            phi_fd = mytestObjective(el,co, pw_fd, pw_target);
            dphidp_fd(l) = (phi_fd - phi) / fd_h;
%             if l==1, dpwdx_fd = (pw_fd - pw) / fd_h; end
%             if l==2, dpwdy_fd = (pw_fd - pw) / fd_h; end
%             if l==3, dpwdz_fd = (pw_fd - pw) / fd_h; end
        end
        disp('abs err (grad-fd):');
        disp(abs(dphidp - dphidp_fd));
        disp('grad, fd:');
        disp(dphidp);
        disp(dphidp_fd);
        disp('... done');
        %dphidp = dphidp_fd;
    end

    phi_iters = [phi_iters; phi]; %#ok
    
    % title goes to the current figure, i.e. the one opened by the most recent mytestDirect call
    % (for j==1 that is the figure of the last FD perturbation)
    title(['step ' num2str(length(phi_iters),'%3u') ', p = [' num2str(p,'%.3f, ') ']']);  drawnow;
%     saveas(gcf,['_img/grad_desc_' num2str(j,'%03u') '.png']);
    
    p = p - 0.1*dphidp; % basic gradient descent step
end

% convergence history of the objective
figure; plot(phi_iters);

%% functions ...

function [phi, dphidpw] = mytestObjective(el,co, pw, pw_target)
% MYTESTOBJECTIVE  Mass-weighted L2 misfit between pw and pw_target.
%
%   [phi, dphidpw] = mytestObjective(el,co, pw, pw_target)
%
%   With M = consistentMass(el,co) and d = pw - pw_target:
%       phi     = 0.5 * sum_j  d(:,j)' * M * d(:,j)     (scalar)
%       dphidpw = M * d                                 (same size as pw)
%
%   Inputs
%     el, co     mesh connectivity and node coordinates, passed on to consistentMass
%     pw         nodal values, one row per node; may have several columns
%     pw_target  target values, same size as pw
%
%   Outputs
%     phi        scalar objective value; for a single-column pw this equals
%                0.5*(pw-pw_target)'*M*(pw-pw_target)
%     dphidpw    M*(pw-pw_target); this is the gradient of phi wrt. pw if M is
%                symmetric (assumed here, not checked)
%
%   For multi-column pw the columns are summed, so phi stays a scalar that is
%   consistent with the returned gradient (a plain d'*M*d would be a k-by-k matrix).

    % standard L2 distance ...
    M = consistentMass(el,co);
    misfit  = pw - pw_target;
    dphidpw = M*misfit;                        % reused for the objective value below
    phi     = 0.5*sum(sum(misfit.*dphidpw));   % sum over nodes and over columns
    
%     % convolutional 2-Wasserstein distance
%     [phi, dphidpw] = convWstnDistance(el,co, pw,pw_target);
end


function [pw, phi, dphidp] = mytestAdjoint(el,co, opt, pw_target)
% MYTESTADJOINT  Ray-casting forward pass plus adjoint gradient wrt. the light position.
%
%   [pw, phi, dphidp] = mytestAdjoint(el,co, opt, pw_target)
%
%   NOTE: flagged as an old version by its author (see the message displayed on entry).
%
%   Forward step: opt.n_th x opt.n_rh rays are cast from the light position
%   lp = [0.1,3,-0.2] (+ opt.lightDisplacement, if that field exists) inside a
%   cone of opt.lth degrees around the light direction. Each ray deposits its
%   power share on the three nodes of the FIRST triangle (in element order, not
%   nearest-first) for which the intersection test reports a hit with d>0. The
%   nodal sums are then solved against (M + opt.smoothing*L).
%
%   Adjoint step (only if pw_target is non-empty): evaluates mytestObjective,
%   solves its pw-gradient against the same matrix and accumulates the
%   derivative wrt. the light position over the same ray/triangle hits.
%
%   Inputs
%     el, co     triangle connectivity (rows index into co) and 3-D node coordinates
%     opt        struct; fields read: debugPlot, smoothing, lth, n_th, n_rh,
%                optional lightDisplacement, optional specular (then also
%                usph_el, usph_co, specularDrawScale for plotting)
%     pw_target  target nodal values, or [] to stop after the forward step
%
%   Outputs
%     pw         projected nodal values, size(co,1) rows
%     phi        objective value (-1 if pw_target is empty)
%     dphidp     1x3 derivative of phi wrt. the light position (0 if pw_target is empty)
%
%   Side effects: prints a message and draws into a figure (a new one unless
%   opt.debugPlot already opened one).

    disp('OLD VERSION DO NOT USE HERE');
    M = consistentMass(el,co);
    L = laplacianP1tri(el,co);

    dataPerNode = 1;
    if  isfield(opt, 'specular')
        % probing call with element index -1; only the length of the result is used
        tmp = projectSpecularReflection([],[],-1,[],[],[],[],opt);
        dataPerNode = length(tmp);
    end
    pw = zeros(size(co,1),dataPerNode);

    if  opt.debugPlot
        figure; trimesh(el,co(:,1),co(:,2),co(:,3),'FaceAlpha',0,'EdgeAlpha',0.2,'EdgeColor','k'); axis equal; hold on;
    end

    lp = [0.1,3,-0.2]; % base light position
    if  isfield(opt,'lightDisplacement')
        lp = lp + opt.lightDisplacement;
    end
    [ldx,ldy,ldz] = sph2cart(-pi/2,0,1); ld = [ldx,ldy,ldz]; clear ldx ldy ldz; % unit light direction
    lth = opt.lth; %15; % deg

    if  opt.debugPlot
        quiver3(lp(:,1),lp(:,2),lp(:,3), ld(:,1),ld(:,2),ld(:,3));
    end

    % normalized helper axis used to tilt rays away from ld
    h1=simplePerpendicularVector(ld);
    h1=h1./norm(h1);

    % sample angles in degrees; the zero entries are dropped
    th_vals = linspace(0,lth,opt.n_th+1); th_vals = th_vals(2:end);
    rh_vals = linspace(0,360,opt.n_rh+1); rh_vals = rh_vals(2:end);
    avg_th = sum(th_vals)/opt.n_th;

    for th = th_vals
        for rh = rh_vals
            R1 = axang2rotm([h1 th/180*pi]); % tilt by th about h1
            R2 = axang2rotm([ld rh/180*pi]); % spin by rh about the light direction
            rd = (R2*R1*ld')'; % ray direction
            
            if  opt.debugPlot
                quiver3(lp(:,1),lp(:,2),lp(:,3), rd(:,1),rd(:,2),rd(:,3), 'g');
            end

            for k = 1:size(el,1)
                [isHit, u, v, d] = rayTriangleIntersection(lp, rd, co(el(k,1),:), co(el(k,2),:), co(el(k,3),:));
                if  isHit && d>0
                    if  opt.debugPlot
                        hx = co(el(k,1),:) + u*(co(el(k,2),:)-co(el(k,1),:)) + v*(co(el(k,3),:)-co(el(k,1),:));
                        plot3(hx(1),hx(2),hx(3),'bo');
                    end

                    % distribute 1/(opt.n_th*opt.n_rh) of the total radiative power to the three nodes
%                     ph = 1/(opt.n_th*opt.n_rh);
                    ph = 1/(opt.n_th*opt.n_rh) * (th / avg_th); % weighted by th relative to the mean sample angle
                    if  isfield(opt, 'specular')
                        ps = projectSpecularReflection(el,co, k,u,v,rd,ph, opt);
                        pw(el(k,1),:) = pw(el(k,1),:) + (1-u-v)*ps;
                        pw(el(k,2),:) = pw(el(k,2),:) +    u   *ps;
                        pw(el(k,3),:) = pw(el(k,3),:) +      v *ps;
                    else
                        pw(el(k,1)) = pw(el(k,1)) + (1-u-v)*ph;
                        pw(el(k,2)) = pw(el(k,2)) +    u   *ph;
                        pw(el(k,3)) = pw(el(k,3)) +      v *ph;
                    end
                    break; % end loop once we've found an intersection
                end
            end
        end
    end
    
    %pw = M\pw; % L2-projection
    pw = (M+opt.smoothing*L)\pw; % L2-projection with some smoothing

    
    % forward step done, now adjoint step ...
    if  isempty(pw_target)
        phi=-1;dphidp=0; % no target - stop after forward step
    else
        % evaluate objective
        [phi, dphidpw] = mytestObjective(el,co, pw,pw_target);
%         dphidpw=M\dphidpw;  % L2-projection of partials
        dphidpw = (M+opt.smoothing*L)\dphidpw; % L2-projection with some smoothing
        
        % adjoint derivative
        dphidp = [0 0 0]; % 3 parameters (light position x,y,z)
        for th = th_vals
            for rh = rh_vals
                R1 = axang2rotm([h1 th/180*pi]);
                R2 = axang2rotm([ld rh/180*pi]);
                rd = (R2*R1*ld')'; % ray direction

                for k = 1:size(el,1)
                    [isHit, u, v, d, dudp,dvdp] = rayTriangleIntersectionWithDerivative(lp, rd, co(el(k,1),:), co(el(k,2),:), co(el(k,3),:));
                    if  isHit && d>0
%                         ph = 1/(opt.n_th*opt.n_rh);
                        ph = 1/(opt.n_th*opt.n_rh) * (th / avg_th);
                        if  isfield(opt, 'specular')
                            ps = projectSpecularReflection(el,co, k,u,v,rd,ph, opt);
                            %pw(el(k,1),:) = pw(el(k,1),:) + (1-u-v)*ps;
                            dphidp(1) = dphidp(1) + dphidpw(el(k,1),:)*(-dudp(1)-dvdp(1))*ps';
                            dphidp(2) = dphidp(2) + dphidpw(el(k,1),:)*(-dudp(2)-dvdp(2))*ps';
                            dphidp(3) = dphidp(3) + dphidpw(el(k,1),:)*(-dudp(3)-dvdp(3))*ps';
                            %pw(el(k,2),:) = pw(el(k,2),:) +    u   *ps;
                            dphidp(1) = dphidp(1) + dphidpw(el(k,2),:)*( dudp(1)        )*ps';
                            dphidp(2) = dphidp(2) + dphidpw(el(k,2),:)*( dudp(2)        )*ps';
                            dphidp(3) = dphidp(3) + dphidpw(el(k,2),:)*( dudp(3)        )*ps';
                            %pw(el(k,3),:) = pw(el(k,3),:) +      v *ps;
                            dphidp(1) = dphidp(1) + dphidpw(el(k,3),:)*(         dvdp(1))*ps';
                            dphidp(2) = dphidp(2) + dphidpw(el(k,3),:)*(         dvdp(2))*ps';
                            dphidp(3) = dphidp(3) + dphidpw(el(k,3),:)*(         dvdp(3))*ps';
                        else
                            dphidp = dphidp + dphidpw(el(k,1))*(-dudp-dvdp)*ph;
                            dphidp = dphidp + dphidpw(el(k,2))*( dudp     )*ph;
                            dphidp = dphidp + dphidpw(el(k,3))*(      dvdp)*ph;
                        end
                        % the forward pass deposits each ray on the first hit only, so the
                        % derivative must stop at the same triangle to stay consistent
                        break;
                    end
                end
            end
        end
    end
    
    if ~opt.debugPlot,  figure; end
    trisurf(el,co(:,1),co(:,2),co(:,3),pw(:,1)); axis equal; shading interp;
    colormap('hot');
    % caxis needs strictly increasing limits; skip it for a constant field (e.g. no ray hit the mesh)
    cLimits = [min(pw(:)) max(pw(:))];
    if  cLimits(2) > cLimits(1),  caxis(cLimits);  end
    %%
    if  isfield(opt, 'specular')
        hold on;
        % one small sphere per node ...
        for i = 1:size(co,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+co(i,1), ...
                opt.usph_co(:,2)*opt.specularDrawScale+co(i,2), ...
                opt.usph_co(:,3)*opt.specularDrawScale+co(i,3), ...
                pw(i,1)+pw(i,2:end),'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        % ... and one per element centroid, coloured by the average of the corner data
        for k = 1:size(el,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+(co(el(k,1),1)+co(el(k,2),1)+co(el(k,3),1))./3, ...
                opt.usph_co(:,2)*opt.specularDrawScale+(co(el(k,1),2)+co(el(k,2),2)+co(el(k,3),2))./3, ...
                opt.usph_co(:,3)*opt.specularDrawScale+(co(el(k,1),3)+co(el(k,2),3)+co(el(k,3),3))./3, ...
                (pw(el(k,1),1)+pw(el(k,2),1)+pw(el(k,3),1))./3+(pw(el(k,1),2:end)+pw(el(k,2),2:end)+pw(el(k,3),2:end))./3,...
                'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        cMax = max(max(pw(:,1)+pw(:,2:end)));
        if  cMax > 0,  caxis([0 cMax]);  end % same guard as above
        hold off;
    end
    
    view(140,20); drawnow;
%     saveas(gcf,['_img/lightpos_' num2str(opt.lightDisplacement(1)) '.png']);
end



function [pw, dpwdx,dpwdy,dpwdz] = mytestDirect(el,co, opt)
% MYTESTDIRECT  Spot-light power on a triangle mesh and its derivatives wrt. the light position.
%
%   [pw, dpwdx,dpwdy,dpwdz] = mytestDirect(el,co, opt)
%
%   The light sits at lp = [0.1,3,-0.2] (+ opt.lightDisplacement, if that field
%   exists) and points along the direction given by sph2cart(-pi/2,0,1). For every
%   element and every quadrature point inside the cone of opt.lth degrees, the term
%       1/lcap * area * weight * blend * dot(n,-rd)/rl^2
%   is distributed to the three element nodes using the barycentric coordinates
%   of the quadrature point. blend falls linearly from 1 to 0 over the outer 5
%   degrees of the cone. The nodal sums are finally solved against
%   (M + opt.smoothing*L).
%
%   Inputs
%     el, co   triangle connectivity (rows index into co) and 3-D node coordinates
%     opt      struct; fields read: debugPlot, smoothing, lth, optional
%              lightDisplacement, optional specular (then also usph_el, usph_co,
%              specularDrawScale for plotting)
%
%   Outputs
%     pw                 projected nodal values, size(co,1) rows
%     dpwdx,dpwdy,dpwdz  derivatives of pw wrt. the x/y/z light position, same size
%                        as pw. In the specular branch no derivative is accumulated,
%                        so they stay zero there.
%
%   Side effects: opens a new figure on every call and plots pw.

    M = consistentMass(el,co);
    L = laplacianP1tri(el,co);
    [qp,qw] = gaussQuadratureTri(); % quadrature rule in barycentric coords
    % from live script ... derivative wrt. light displacement of this term: ( dot(n,-rd)/(rl*rl) )
    geometricDerivative = @(lp1,lp2,lp3,n1,n2,n3,qp1,qp2,qp3)[n1.*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(3.0./2.0)-(lp1.*2.0-qp1.*2.0).*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(5.0./2.0).*(lp1.*n1+lp2.*n2+lp3.*n3-n1.*qp1-n2.*qp2-n3.*qp3).*(3.0./2.0);n2.*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(3.0./2.0)-(lp2.*2.0-qp2.*2.0).*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(5.0./2.0).*(lp1.*n1+lp2.*n2+lp3.*n3-n1.*qp1-n2.*qp2-n3.*qp3).*(3.0./2.0);n3.*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(3.0./2.0)-(lp3.*2.0-qp3.*2.0).*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(5.0./2.0).*(lp1.*n1+lp2.*n2+lp3.*n3-n1.*qp1-n2.*qp2-n3.*qp3).*(3.0./2.0)];
    % derivative wrt. light position of blend = (lth-rth)/5 with rth = acos(dot(rd,ld)) in degrees;
    % the factor -3.6e1/pi is -(180/pi)/5. Defined once here instead of at every quadrature point.
    dblendfcn = @(ld1,ld2,ld3,lp1,lp2,lp3,qp1,qp2,qp3)[(1.0./sqrt(-(ld1.*lp1+ld2.*lp2+ld3.*lp3-ld1.*qp1-ld2.*qp2-ld3.*qp3).^2./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2)+1.0).*(ld1.*1.0./sqrt((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2)-(lp1.*2.0-qp1.*2.0).*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(3.0./2.0).*(ld1.*lp1+ld2.*lp2+ld3.*lp3-ld1.*qp1-ld2.*qp2-ld3.*qp3).*(1.0./2.0)).*-3.6e1)./pi;(1.0./sqrt(-(ld1.*lp1+ld2.*lp2+ld3.*lp3-ld1.*qp1-ld2.*qp2-ld3.*qp3).^2./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2)+1.0).*(ld2.*1.0./sqrt((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2)-(lp2.*2.0-qp2.*2.0).*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(3.0./2.0).*(ld1.*lp1+ld2.*lp2+ld3.*lp3-ld1.*qp1-ld2.*qp2-ld3.*qp3).*(1.0./2.0)).*-3.6e1)./pi;(1.0./sqrt(-(ld1.*lp1+ld2.*lp2+ld3.*lp3-ld1.*qp1-ld2.*qp2-ld3.*qp3).^2./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2)+1.0).*(ld3.*1.0./sqrt((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2)-(lp3.*2.0-qp3.*2.0).*1.0./((lp1-qp1).^2+(lp2-qp2).^2+(lp3-qp3).^2).^(3.0./2.0).*(ld1.*lp1+ld2.*lp2+ld3.*lp3-ld1.*qp1-ld2.*qp2-ld3.*qp3).*(1.0./2.0)).*-3.6e1)./pi];

    isSpecular = isfield(opt, 'specular'); % loop-invariant, evaluated once
    dataPerNode = 1;
    if  isSpecular
        % probing call with element index -1; only the length of the result is used
        tmp = projectSpecularReflection([],[],-1,[],[],[],[],opt);
        dataPerNode = length(tmp);
    end
    pw = zeros(size(co,1),dataPerNode);
    dpwdx = pw; dpwdy = pw; dpwdz = pw; % derivatives of radiative power wrt. light position

    if  opt.debugPlot
        figure; trimesh(el,co(:,1),co(:,2),co(:,3),'FaceAlpha',0,'EdgeAlpha',0.2,'EdgeColor','k'); axis equal; hold on;
    end

    lp = [0.1,3,-0.2]; % base light position
    if  isfield(opt,'lightDisplacement')
        lp = lp + opt.lightDisplacement;
    end
    [ldx,ldy,ldz] = sph2cart(-pi/2,0 ,1); ld = [ldx,ldy,ldz]; clear ldx ldy ldz; % unit light direction
    lth = opt.lth;
    cosCone = cos(lth/180*pi); % rays with dot(rd,ld) above this value are inside the cone
    lcap = 2*pi*(1-cosCone); % area of spherical cap on unit sphere (solid angle) for cone of lth opening angle

    if  opt.debugPlot
        quiver3(lp(:,1),lp(:,2),lp(:,3), ld(:,1),ld(:,2),ld(:,3)); hold on;
    end

%     h1=simplePerpendicularVector(ld);
%     h1=h1./norm(h1);

    for k = 1:size(el,1) % for each element
        n = cross( co(el(k,2),:)-co(el(k,1),:) , co(el(k,3),:)-co(el(k,1),:) );
        ar = 0.5*norm(n); % triangle area
        n = n./norm(n);   % unit normal

        for i = 1:length(qw) % for each quadrature point
            qpx = qp(i,:) * co(el(k,:),:); % quadrature point in world coordinates
            rd = qpx-lp;
            rl = norm(rd);  % distance light -> quadrature point
            rd = rd./rl;    % unit ray direction
            
            if  dot(rd,ld)>cosCone % cos( light angle ) > spot cone
                rth = acos( dot(rd,ld) )*180/pi; % angle to the light axis (deg)
                if  rth>(lth-5)
                    % linear falloff over the outer 5 degrees of the cone
                    blend = (lth-rth)/5; 
                    dblenddp = dblendfcn(ld(1),ld(2),ld(3),lp(1),lp(2),lp(3),qpx(1),qpx(2),qpx(3));
                else
                    blend = 1; dblenddp = 0;
                end

                if  opt.debugPlot
                    quiver3(qpx(:,1),qpx(:,2),qpx(:,3), -rl*rd(:,1),-rl*rd(:,2),-rl*rd(:,3),0, 'g');
                    quiver3(qpx(:,1),qpx(:,2),qpx(:,3), n(1),n(2),n(3),0, 'r');
                end

                % distribute the total radiative power to the three nodes
                %  power/solid angle     surface cosine and distance falloff
                ph = 1.0/lcap * ar*qw(i)*blend*dot(n,-rd)/(rl*rl); % ar*qw(i) due to quadrature over triangle (shape function = barycentric coord = qp(i,:) multiplied below)
                if  isSpecular
                    ps = projectSpecularReflection(el,co, k,qp(i,2),qp(i,3),rd,ph, opt); % qp(i,2),qp(i,3) are (u,v) coords -- qp(i,1)==(1-u-v)
%                     pw(el(k,1),:) = pw(el(k,1),:) + qp(i,1)*ps;
%                     pw(el(k,2),:) = pw(el(k,2),:) + qp(i,2)*ps;
%                     pw(el(k,3),:) = pw(el(k,3),:) + qp(i,3)*ps;
                    pw(el(k,:),:) = pw(el(k,:),:) + qp(i,:)'*ps;
                    % no derivative is accumulated in this branch (dpwdx/dpwdy/dpwdz stay zero)
                else
                    pw(el(k,:)) = pw(el(k,:)) + qp(i,:)'*ph;
                    % gradient evaluation wrt. light position
                    % from live script ... d/dp(dot(n,-rd)/(rl*rl))
                    % product rule: blend * d(geometric term)/dp + geometric term * d(blend)/dp
                    dphdp = 1.0/lcap * ar*qw(i)*blend* geometricDerivative(lp(1),lp(2),lp(3),n(1),n(2),n(3),qpx(1),qpx(2),qpx(3)) ...
                          + 1.0/lcap * ar*qw(i)*dot(n,-rd)/(rl*rl)*dblenddp;
                    
                    dpwdx(el(k,:)) = dpwdx(el(k,:)) + qp(i,:)'*dphdp(1);
                    dpwdy(el(k,:)) = dpwdy(el(k,:)) + qp(i,:)'*dphdp(2);
                    dpwdz(el(k,:)) = dpwdz(el(k,:)) + qp(i,:)'*dphdp(3);
                end
            end
        end
    end

    % L2-projection with some smoothing: all four right-hand sides share the same
    % system matrix, so they are solved in a single call
    sol = (M+opt.smoothing*L)\[pw, dpwdx, dpwdy, dpwdz];
    cols  = 1:dataPerNode;
    pw    = sol(:,                 cols);
    dpwdx = sol(:,   dataPerNode + cols);
    dpwdy = sol(:, 2*dataPerNode + cols);
    dpwdz = sol(:, 3*dataPerNode + cols);
    
    figure; trisurf(el,co(:,1),co(:,2),co(:,3),pw(:,1)); axis equal; shading interp;
    colormap('hot');
    % caxis needs strictly increasing limits; skip it for a constant field
    % (e.g. pw is all zero because the light cone misses the mesh)
    cLimits = [min(pw(:)) max(pw(:))];
    if  cLimits(2) > cLimits(1),  caxis(cLimits);  end
    %%
    if  isSpecular
        hold on;
        % one small sphere per node ...
        for i = 1:size(co,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+co(i,1), ...
                opt.usph_co(:,2)*opt.specularDrawScale+co(i,2), ...
                opt.usph_co(:,3)*opt.specularDrawScale+co(i,3), ...
                pw(i,1)+pw(i,2:end),'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        % ... and one per element centroid, coloured by the average of the corner data
        for k = 1:size(el,1)
            trisurf(opt.usph_el, ...
                opt.usph_co(:,1)*opt.specularDrawScale+(co(el(k,1),1)+co(el(k,2),1)+co(el(k,3),1))./3, ...
                opt.usph_co(:,2)*opt.specularDrawScale+(co(el(k,1),2)+co(el(k,2),2)+co(el(k,3),2))./3, ...
                opt.usph_co(:,3)*opt.specularDrawScale+(co(el(k,1),3)+co(el(k,2),3)+co(el(k,3),3))./3, ...
                (pw(el(k,1),1)+pw(el(k,2),1)+pw(el(k,3),1))./3+(pw(el(k,1),2:end)+pw(el(k,2),2:end)+pw(el(k,3),2:end))./3,...
                'FaceAlpha',1); axis equal; shading interp;
            colormap('hot');
        end
        cMax = max(max(pw(:,1)+pw(:,2:end)));
        if  cMax > 0,  caxis([0 cMax]);  end % same guard as above
        hold off;
    end
    
    view(140,20); drawnow;
%     saveas(gcf,['_img/lightpos_' num2str(opt.lightDisplacement(1)) '.png']);
end


