clc; clear;

% Make the gradient-enhanced Kriging code (presumably the source of
% gekPart1/gekPart2 used below) available if it is not on the path yet.
% Gradient-enhanced Kriging, 2016, JHS de Baar, UNSW Canberra, from
% https://www.mathworks.com/matlabcentral/fileexchange/60230-gradient-enhanced-kriging
if ~contains(path, 'gek_2016')
    addpath('D:\m\gek_2016');
    disp(path);
end
%%

d = 3;      % number of optimisation parameters
N = 2^d;    % number of samples in the initial Latin-hypercube design

options.debug = 0; % gives output for debugging, set to 0 for no output
options.estvar = 'yes';
options.predicttype = 'predictplus'; % also estimate gradients
% Maximum-likelihood settings. The hyperparameter vector holds two error
% factors (function value, gradient) followed by one entry per input
% dimension (presumably correlation length scales -- TODO confirm against the
% gek_2016 docs). Sizing it from d keeps it consistent when d changes.
% Alternative settings tried earlier:
% options.hyperinit = [1 1 .1];
% options.hyperspace = [1 1 1];
% options.hyperest = 'brute';
% options.brutesize = 1e3;
% options.hyperspace = [1 1 ones(1,d)];   % also re-optimise the error factors
options.hyperinit  = [1 1 4*ones(1,d)];
options.hyperspace = [0 0 ones(1,d)];     % 0: keep error factors fixed, 1: optimise
% Optimiser settings for the hyperparameter fit (presumably consumed by gekPart1).
options.fminopts = optimset('Display','notify','TolFun',1e-12,'TolX',1e-4,'MaxIter',1e4,'MaxFunEvals',1e4);

% Error magnitudes assigned to every function value (errf) and every
% gradient component (errdf) when the GEK model is fitted.
errf = 1e-4; errdf = 1e-2;


%%
% Initial design: N Latin-hypercube samples, mapped from [0,1]^d to [-1.5,1.5]^d.
xi = lhsdesign(N,d)*3-1.5;
fx = zeros(N,1);   % objective value of each sample
df = zeros(N,d);   % objective gradient of each sample, one row per sample

% Note: expected optimal solution (no SH version) is around  [ -0.479571    2.00059 -0.0689892] with objective 0.0563
runSetup();   % prepares _workdir and makes it the current folder

for sample = 1:N
    [fx(sample), df(sample,:)] = runSim(xi(sample,:));
end
%%
% Surrogate-based optimisation loop: fit GEK to all samples, minimise the
% acquisition function to choose the next point, simulate it, plot, repeat.
stopcount = 0;      % consecutive steps with a (numerically) zero acquisition value
nRestarts = 9;      % fmincon starts per acquisition search
fminopts = optimoptions(@fmincon,'Display','notify');  % loop-invariant, built once
warning('off','MATLAB:contour:ConstantData');          % silence contour() on constant slices
for iter = 1:40

    % Box for the next evaluation point, sized from d. It is reset every
    % iteration so the commented-out restriction below only affects one step.
    x_min = -2*ones(d,1); x_max = 2*ones(d,1);
%     if length(fx) > N+5 % only keep limited memory of past function evals
%         [~,idx] = sort(fx); idx = idx(1:(N+5)); %keep the best evaluations 
%         idx = union(idx, length(fx)); %also keep latest
%         xi = xi(idx,:); fx = fx(idx,:); df = df(idx,:);
%         x_min = min(xi); x_max = max(xi); % restrict search to box covered by values
%         disp(['restriction: ' num2str([x_min x_max])]);
%     end
%%
    % Fit the GEK surrogate to all function values and gradients.
    gekmodel = gekPart1(xi,fx,errf*ones(size(fx)),df,errdf*ones(size(df)),options);
%     options.hyperinit = gekmodel.hypers; % warm start next iteration
%     options.hyperspace = [0 0 ones(1,d)]; % don't re-optimize error factors

    % In-sample check: how closely the surrogate reproduces the training data.
    [test_fx,~,test_grad] = gekPart2(gekmodel, xi);
    fprintf('in-sample GEK errors [f, g]: %g %g\n', ...
        max(abs(test_fx - fx)), max(max(abs(test_grad - df))));

    [bestF, bestI] = min(fx);
    ei = @(x) expectedImprovement(x, bestF, gekmodel);
%     if( mod(iter,5)==0 ) % sometimes optimize on the surrogate alone
%         ei = @(x) gekPart2(gekmodel, x');
%         disp('Direct opt');
%     end

    % Multistart minimisation of the acquisition function: the first start is
    % the best sample so far, the others are uniformly random in the box.
    x_start = xi(bestI,:)';
    x = []; expImp = inf;   % reset so a failed search cannot reuse last iteration's x
    for ri = 1:nRestarts
        [x_,expImp_] = fmincon(ei,x_start,[],[],[],[],x_min,x_max,[],fminopts);
        x_start = x_min + rand(d,1).*(x_max - x_min);
        if expImp_ < expImp
            expImp = expImp_;
            x = x_;
        end
    end
    if isempty(x)
        error('Iteration %d: no fmincon restart returned an acquisition value below Inf.', iter);
    end
%%
    [fxi,dfi] = runSim(x');

    xi = [xi ; x']; %#ok
    fx = [fx ; fxi]; %#ok
    df = [df ; dfi]; %#ok

    % Progress: best so far, new value, their difference, negated acquisition value, iteration.
    fprintf('%.6g ',[bestF, fxi, (fxi-bestF), -expImp]); disp(iter);
%%
    if d == 3   % the slice plot below is written for exactly three parameters
        % Horizontal axes show x(1) and x(3), the vertical axis x(2). The
        % surrogate prediction is drawn as slices at every 7th grid value of x(2).
        [X,Z] = meshgrid(-4:.1:4);
        cla; hold on;   % clear once, then overlay everything; with hold off the
                        % first surf() would erase the first contour slice
        for layer = 1:7:size(X,2)
            x2 = X(1,layer);
            xout = [X(:), x2*ones(numel(X),1), Z(:)];
            F = reshape(gekPart2(gekmodel,xout), size(X));
            [~,h] = contour(X,Z,F); set(h,'ContourZLevel',x2);
            surf(X,Z,0*X+x2,F,'EdgeAlpha',0,'FaceAlpha',0.1);
        end
        plot3(x(1),x(3),x(2), 'kd');                          % newly evaluated point
        h = scatter3(xi(:,1),xi(:,3),xi(:,2),20,fx);          % samples coloured by objective
        set(h, 'MarkerFaceColor','flat');
        plot3(xi(bestI,1),xi(bestI,3),xi(bestI,2), 'go');     % best sample before this step
        hold off; view(-80,5); caxis([0 4]); colorbar; drawnow; pause(0.2);
        saveas(gcf,['gek_iter_' num2str(iter) '.png']);
    end
%%
    % Stop after three consecutive steps with a (numerically) zero acquisition value.
    if abs(expImp)<1e-10, stopcount = stopcount+1; else, stopcount=0; end
    if stopcount >= 3, break; end
end
%%
% Convergence history: best objective value found up to each evaluation.
% A new figure keeps the final surrogate plot from being overwritten.
figure;
plot(bestSoFar(fx));
xlabel('evaluation'); ylabel('best objective so far');
%%
save('results.mat', 'xi', 'fx', 'df');   % written to the current folder (set by runSetup)
%%
function runSetup()
% runSetup  Prepare the simulator working folder and make it current.
%   Runs prepareWorkdir.bat, creates a directory junction _workdir\gltf_v3
%   to the simple_office scene assets, then changes MATLAB's current folder
%   to _workdir. All paths are relative to the folder current at call time.
    system('prepareWorkdir.bat');
    system('mklink /J _workdir\gltf_v3 ..\..\..\tamashii-resources\ialt_paper_scenes\pcon\simple_office\gltf_v3');
    cd('_workdir');
end

function [fxi,dfi] = runSim(xi)
% runSim  Evaluate the external simulator at one parameter point.
%   [fxi,dfi] = runSim(xi) writes xi (one value per line) to params.txt in
%   the current folder, runs inter_adj_light_trace.exe headless, and parses
%   the objective and its parameter derivatives from the console output.
%
%   xi  - parameter vector (a row vector at every call site in this script)
%   fxi - scalar objective, parsed from a line "obj = <number>;"
%   dfi - row vector of derivatives, parsed from "dp  = [<numbers>];"
%
%   Errors with identifier runSim:parse if either value is missing.
    disp('sim'); disp(xi);

    % Write full double precision: dlmwrite's default keeps only a few
    % significant digits, so the simulator would evaluate a rounded point that
    % differs from the one stored in xi and used for the GEK fit.
    dlmwrite('params.txt', xi', 'precision', '%.17g');

    [status, cout] = system('inter_adj_light_trace.exe -headless 1 -runPredefinedTest single-eval-pos-params -load_scene "gltf_v3\simple_office_v3.gltf" -numRaysXperLight 2048 -numRaysYperLight 2048 -constRandSeed 0 -useSHdiffOnlyCoeffObjective 1');

    % Parse the numbers directly instead of eval()-ing text from the
    % simulator output, and fail with a readable message when it is missing.
    objTok = regexp(cout, 'obj = ([ e\-\+0-9\.]*);', 'tokens', 'once');
    dpTok  = regexp(cout, 'dp  = \[([ e\-\+0-9\.]*)\];', 'tokens', 'once');
    if isempty(objTok) || isempty(dpTok)
        error('runSim:parse', ...
            'Could not find "obj" and "dp" in simulator output (exit status %d):\n%s', status, cout);
    end
    fxi = str2double(objTok{1});
    if isnan(fxi)
        error('runSim:parse', 'Could not parse objective value "%s".', objTok{1});
    end
    dfi = sscanf(dpTok{1}, '%f').';   % space-separated numbers -> row vector

    disp([fxi dfi]);
end

function ei = expectedImprovement(x , bestF, gekmodel)
% expectedImprovement  Acquisition function minimised by fmincon.
%   Despite its name, this returns a lower-confidence-bound criterion shifted
%   by the incumbent:  ei = (mu(x) - 1.0*sigma(x)) - bestF, where mu and
%   sigma^2 are the GEK predicted value and variance at x. Negative values
%   mean the optimistic model estimate beats the best observed value.
%
%   x        - candidate point as a column vector (transposed for gekPart2)
%   bestF    - best objective value observed so far
%   gekmodel - model returned by gekPart1
    [fx,varx] = gekPart2(gekmodel, x');
    % The predicted variance can come out slightly negative from round-off;
    % clamp it so sqrt never yields a complex value, which fmincon rejects.
    sig = sqrt(max(varx,0)+1e-10);
    ei = (fx - 1.0*sig) - bestF; % "lower confidence bound" approach (higher weight on std-dev increases exploration)
%     ei = -(  (bestF - fx) * normcdf( (bestF-fx)/sig ) + sig * normpdf( (bestF-fx)/sig )  );
end

%%
function x = bestSoFar(x)
% bestSoFar  Running minimum: element k becomes min(x(1:k)).
%   y = bestSoFar(x) returns an array of the same size as x, taking the
%   running minimum over x's elements in linear (column-major) order.
%   Empty input returns empty. NaN entries are skipped, so they take the
%   best value seen before them.
    x(:) = cummin(x(:));
end