# NOTE(review): this call runs before `import pyialt` below; presumably it
# prepares the environment so that import can succeed — confirm against
# `loader` and keep the ordering.
import loader
loader.loadTamashii()

import pyialt as ta
import numpy as np
import os
import time
import smt
import smt.surrogate_models
import smt.sampling_methods
import smt.applications
import scipy.optimize
import scipy.stats
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import warnings

def runMultipleSims(xi):
    """Evaluate loss and gradient for every parameter vector in ``xi``.

    Each row of ``xi`` is passed to ``ialt.forward``; the ``(f, g)`` pair
    returned by the subsequent ``ialt.backward()`` call is stored as that
    row's loss and gradient.

    Args:
        xi: 2-D array-like with one parameter vector per row.

    Returns:
        Tuple ``(fi, df)``. ``fi`` has shape ``(len(xi), 1)`` and holds the
        losses; ``df`` has the same shape as ``xi`` and holds the gradients.
        Both are float arrays regardless of the dtype of ``xi``.
    """
    xi = np.asarray(xi)
    n_samples = xi.shape[0]
    print(f'MultiSim: {n_samples}')

    # Allocate explicit float buffers. Deriving them via ``0 * xi`` would
    # inherit the dtype of ``xi`` (integer inputs would silently truncate the
    # stored losses/gradients) and would turn NaN/inf entries into NaN.
    fi = np.zeros((n_samples, 1), dtype=float)
    df = np.zeros(xi.shape, dtype=float)

    # Nothing to evaluate: return the empty buffers without touching the renderer.
    if n_samples == 0:
        return (fi, df)

    ialt = ta.impls.ialt
    for idx, x in enumerate(xi):
        ialt.forward(x)
        f, g = ialt.backward()

        fi[idx] = f
        df[idx, :] = g

    return (fi, df)

def runEGOmanually(scenePath, iterations= 20, savePlots= True):
    """Fit a light position to a target image with a hand-written EGO-style loop.

    The first light of the scene is placed at a fixed target position to
    record the target image, then moved to a fixed initial position. Eight
    Latin-hypercube samples in ``[-2, 2]^3`` are evaluated with
    ``runMultipleSims``. Each iteration trains a GEKPLS surrogate on all
    losses and gradients collected so far, plots its predictions, minimises
    ``mean - sqrt(variance)`` of the surrogate with SLSQP starting from the
    best sample, and evaluates the renderer at the resulting point. Finally
    the best sample is re-rendered and a summary figure is shown.

    Args:
        scenePath: Path passed to ``ta.openScene``.
        iterations: Number of surrogate-fit / new-sample iterations.
        savePlots: If true, save the model plot of iteration ``i`` as
            ``gekpls_{i}.png`` in the current working directory.
    """
    # Note by relf (https://bytemeta.vip/repo/SMTorg/smt/issues/314):
    # If you take n_comp=nx, where nx is the dimension of the given problem input, then GEKPLS boils down to GEK.
    N = 8
    dim = 3
    pls_dim = 3
    xlimits = np.array([[-2,2],[-2,2],[-2,2]])
    sampling = smt.sampling_methods.LHS(xlimits=xlimits, criterion='ese') #, random_state=1)
    xi = sampling(N)

    plt_fig = plt.figure(figsize=(8, 6))
    plt_ax = plt_fig.add_subplot(projection='3d')
    plt_fig.show()
    plt.pause(1)

    # Plot axes are swapped (data Y on the plot's z axis) so that Y points up.
    sc = plt_ax.scatter(xi[:,0], xi[:,2], xi[:,1], s=16)
    cbar = plt_fig.colorbar(sc)
    # No explicit ``cbar.draw_all()``: it is deprecated/removed in recent
    # Matplotlib releases and the ``canvas.draw()`` below renders the colorbar.
    plt_ax.set_xlabel('X')
    plt_ax.set_zlabel('Y')
    plt_ax.set_ylabel('Z') # Y-up
    plt_fig.canvas.draw()
    plt_fig.canvas.flush_events()

    ta.var.log_level = "none"
    ta.var.default_camera = "Camera.001"
    ta.var.numRaysXperLight = 2000
    ta.var.numRaysYperLight = 2000
    ta.var.constRandomSeed = True
    ta.var.usePathTracing = [True, False]
    ialt = ta.impls.ialt

    init = [2,2,-2]
    target = [-2,2,2]

    s = ta.openScene(scenePath)
    light = s.lights[0]

    # Prepare target
    light.position = target
    ta.frame()
    ialt.currentRadianceAsTarget(True)
    target_image = ialt.getTargetImage()

    # Prepare init
    light.position = init
    ta.frame()
    init_image = ta.captureCurrentFrame()

    # Optimize
    print('Start...')
    start_time = time.time()

    s.lightsToParameterVector()
    fi, df = runMultipleSims(xi)

    # Regular grid (10 points per axis, spanning xlimits) on which the
    # surrogate is visualised; it does not change between iterations.
    grid_axes = [np.linspace(lo, hi, 10) for lo, hi in xlimits]
    X, Y, Z = np.meshgrid(*grid_axes)
    XYZ = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))

    for i in range(iterations):
        gpMdl = smt.surrogate_models.GEKPLS(
            theta0=[1e-2] * pls_dim,
            xlimits= xlimits,
            extra_points= dim,
            print_prediction= False,
            n_comp= pls_dim,
            delta_x=1e-4
            )

        gpMdl.set_training_values(xi, fi)
        for idx in range(dim):
            gpMdl.set_training_derivatives(xi, df[:,idx], idx)

        gpMdl.train()

        # Report how well the surrogate reproduces the training data.
        df_mdl = np.zeros_like(df)
        for idx in range(dim):
            df_mdl[:,idx,np.newaxis] = gpMdl.predict_derivatives(xi, idx)

        maxGradErr = np.max(np.abs(df - df_mdl))
        print(f'gradient abs. err. = {maxGradErr}')
        maxInterpErr = np.max(np.abs(gpMdl.predict_values(xi) - fi))
        print(f'interpolation abs. err. = {maxInterpErr}')

        # Draw model predictions
        f_mdl = np.reshape(gpMdl.predict_values(XYZ), X.shape)

        cbar.remove()
        plt_ax.clear()
        plt_ax.scatter(X, Z, Y, c=f_mdl, s=8)

        # Draw lamp positions
        sc = plt_ax.scatter(xi[:,0], xi[:,2], xi[:,1], c=fi, s=32)
        cbar = plt_fig.colorbar(sc)

        best_idx = int(np.argmin(fi))
        best_x = xi[best_idx]

        plt_ax.set_title(f"Best so far {fi[best_idx][0]:.3g} at [{best_x[0]:.3g} {best_x[1]:.3g} {best_x[2]:.3g}] (idx {best_idx})" )
        plt_fig.canvas.draw()
        plt_fig.canvas.flush_events()
        if savePlots:
            # Save this figure explicitly instead of whichever figure is "current".
            plt_fig.savefig(f"gekpls_{i}.png")

        # Optimize on the model: minimise predicted mean minus one predicted
        # standard deviation, starting from the best sample so far.
        def lcb(x, model=gpMdl):
            x_row = np.array([x])
            mean = model.predict_values(x_row).flat[0]
            variance = model.predict_variances(x_row).flat[0]
            return float(mean - 1.0 * np.sqrt(variance))

        opt_res = scipy.optimize.minimize(
            lcb,
            best_x,
            method="SLSQP",
            bounds=xlimits,
            options={"maxiter": 500})
        print(f'LCB opt status: {opt_res.status} evals: {opt_res.nfev} value: {opt_res.fun}')

        x_new = opt_res.x
        print(f'next x: {x_new}')

        # Render
        ialt.forward(x_new)
        f, g = ialt.backward()

        # Append new position, loss and gradients
        print(f'new (f,g): {f} {g}')
        xi = np.concatenate((xi, np.array([x_new])))
        fi = np.concatenate((fi, np.array([[f]])))
        df = np.concatenate((df, np.array([g])))

    elapsed_time = time.time() - start_time
    print('Done! Elapsed time: {}s'.format(elapsed_time))

    # Pick the best sample over ALL evaluations. The in-loop value is computed
    # before the newest sample is appended, so it would ignore the final
    # evaluation (and would be undefined when iterations == 0).
    best_idx = int(np.argmin(fi))

    # Rerender the scene with the best params found and take a screenshot
    ialt.forward(xi[best_idx])
    ta.frame()
    opti_img = ta.captureCurrentFrame()

    # Create summary plot with screenshots and loss history
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    axs[0][0].plot(fi, markevery= [best_idx], marker= 'X', markerfacecolor= 'red', markersize=10, markeredgewidth= 0)
    axs[0][0].set_xlabel('iteration')
    axs[0][0].set_ylabel('Loss')
    axs[0][0].set_ylim(0, 0.4)
    axs[0][0].set_title('Parameter error plot')

    axs[0][1].imshow(init_image)
    axs[0][1].axis('off')
    axs[0][1].set_title('Initial Image')

    axs[1][0].imshow(opti_img)
    axs[1][0].axis('off')
    axs[1][0].set_title(f'Optimized image (idx {best_idx})')

    axs[1][1].imshow(target_image)
    axs[1][1].axis('off')
    axs[1][1].set_title('Target Image')
    plt.show()


def main():
    """Run the EGO demo on the ``rabbit.gltf`` scene located next to this script.

    All Python warnings are silenced before the run starts.
    """
    # Resolve the scene against the absolute directory of this file. With a
    # bare relative ``__file__`` (e.g. "script.py"), ``dirname`` is "" and plain
    # string concatenation would yield "/rabbit.gltf" in the filesystem root.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    testRabbitScenePath = os.path.join(script_dir, "rabbit.gltf")
    warnings.filterwarnings("ignore")
    runEGOmanually(testRabbitScenePath)

# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
