import os
import loader
# Capture the working directory at import time; main() resolves "rabbit.gltf"
# relative to it. Taken before the loader runs, presumably in case the loader
# changes the cwd — confirm in loader.py.
launchDir = os.getcwd()
# Must run before `import pyialt` below — presumably this makes the
# Tamashii/pyialt extension importable (path / shared-library setup); TODO confirm.
loader.loadTamashii()

import pyialt as ta
from adam import Adam
import numpy as np
import time
from matplotlib import pyplot as plt
import matplotlib.cm as cm

def showGradImage(light):
    """Plot the first channel of the light's gradient image for ``Param.X``.

    The image comes from ``light.getGradImage(ta.impls.ialt.Param.X, 100)``.
    ``Param.X`` is presumably the light's X-position parameter. The image is
    drawn with the reversed ``coolwarm`` colormap. The color limits mirror the
    largest absolute gradient around zero, so opposite signs are shown with
    equal intensity. The used range is printed before the figure is shown
    with ``plt.show()``.
    """
    channel0 = light.getGradImage(ta.impls.ialt.Param.X, 100)[:, :, 0]
    limit = np.max(np.abs(channel0))
    print(f'Remapping colors within range: [{-limit:.2f}, {limit:.2f}]')

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.imshow(channel0, cmap=cm.coolwarm_r, vmin=-limit, vmax=limit)
    ax.set_title('Full gradients')
    ax.axis('off')
    fig.tight_layout()
    plt.show()

def finitDiffGrad(ialt, params, h=0.1):
    """Estimate the gradient of the ``ialt`` loss by central finite differences.

    For each entry ``params[j]``, the loss is evaluated at ``params[j] - h``
    and at ``params[j] + h``, with all other entries unchanged. The slope
    ``(phi(+h) - phi(-h)) / (2 * h)`` is stored in ``grads[j]``.

    Args:
        ialt: Object exposing ``forward(params)`` and ``backward()``.
            ``backward()`` returns a 2-tuple whose first element is the loss.
        params: Mutable, indexable parameter sequence. It is perturbed in
            place during the sweep and restored to its exact original values.
        h: Finite-difference step size.

    Returns:
        ``(phi, grads)``. ``phi`` is the loss at the unperturbed ``params``.
        ``grads`` is a float array with one derivative estimate per
        parameter. On return, ``ialt`` has last been run forward on the
        unperturbed ``params``.
    """
    def lossAt():
        # One forward/backward evaluation at the current contents of params.
        ialt.forward(params)
        phi, _ = ialt.backward()
        return phi

    grads = np.zeros(len(params))
    for j in range(len(params)):
        # Save and restore the exact value. Subtracting h, adding 2h and
        # subtracting h again can leave floating-point drift in params[j].
        original = params[j]
        params[j] = original - h
        phim = lossAt()
        params[j] = original + h
        phip = lossAt()
        params[j] = original
        grads[j] = (phip - phim) / (2 * h)

    # Evaluate the baseline once, not once per parameter. This also leaves
    # ialt in the unperturbed state and makes an empty params valid.
    return lossAt(), grads

def _plotResults(loss_hist, init_image, opti_img, target_image):
    """Show the loss curve next to the initial, optimized and target images."""
    _, axs = plt.subplots(2, 2, figsize=(10, 10))
    axs[0][0].plot(loss_hist)
    axs[0][0].set_xlabel('iteration')
    axs[0][0].set_ylabel('Loss')
    axs[0][0].set_title('Loss history')

    panels = (
        (axs[0][1], init_image, 'Initial Image'),
        (axs[1][0], opti_img, 'Optimized image'),
        (axs[1][1], target_image, 'Target Image'),
    )
    for ax, image, title in panels:
        ax.imshow(image)
        ax.axis('off')
        ax.set_title(title)
    plt.show()


def runGradientDecent(scenePath, iterations, init=None, target=None, learningRate=0.25):
    """Recover a light position by gradient-based optimization and plot the result.

    The function runs these steps:

    1. Configure the renderer (``ta.var``) and open ``scenePath``.
    2. Enable optimization of the X/Y/Z position of the scene's first light.
    3. Render the scene with that light at ``target`` and register the
       radiance as the optimization target.
    4. Move the light to ``init`` and capture the starting frame.
    5. Run ``iterations`` Adam steps on the scene's light parameter vector,
       using the gradients from ``ialt.backward()``.

    Each iteration prints the Euclidean distance between the light's position
    and ``target``, together with the loss. Finally, the loss history and the
    initial, optimized and target images are shown in a 2x2 figure.

    Args:
        scenePath: Scene file passed to ``ta.openScene``.
        iterations: Number of optimization steps.
        init: Starting light position; defaults to ``[2, 2, -2]``.
        target: Light position used to render the target image; defaults
            to ``[-2, 2, 2]``.
        learningRate: Second argument to ``Adam``, presumably its step size
            (confirm in adam.py); defaults to ``0.25``.
    """
    init = [2, 2, -2] if init is None else list(init)
    target = [-2, 2, 2] if target is None else list(target)

    ta.var.log_level = "none"
    ta.var.default_camera = "Camera.001"
    ta.var.numRaysXperLight = 2000
    ta.var.numRaysYperLight = 2000
    ta.var.constRandomSeed = True
    ta.var.usePathTracing = [True, False]
    ialt = ta.impls.ialt

    s = ta.openScene(scenePath)
    light = s.lights[0]
    light.optimize.posX = light.optimize.posY = light.optimize.posZ = True

    # Reference render: light at the target position, stored as the target.
    light.position = target
    ta.frame()
    ialt.currentRadianceAsTarget(True)
    target_image = ialt.getTargetImage()

    # Starting render: light at the initial position, kept for the plot.
    light.position = init
    ta.frame()
    init_image = ta.captureCurrentFrame()

    print('Start...')
    loss_hist = []
    params = s.lightsToParameterVector()
    adam = Adam(params, learningRate)
    start_time = time.perf_counter()
    for i in range(iterations):
        ialt.forward(params)
        # For gradient checking, replace the next line with:
        #   phi, grads = finitDiffGrad(ialt, params, h=0.15)
        phi, grads = ialt.backward()
        adam.step(grads)

        distance = np.linalg.norm(np.subtract(s.lights[0].position, target))
        loss_hist.append(phi)
        print('Iteration: {}, Distance: {}, Loss: {}'.format(i, distance, phi))

    elapsed_time = time.perf_counter() - start_time
    print('Done! Elapsed time: {}s'.format(elapsed_time))

    # NOTE(review): adam.step() runs after the last forward() call. Unless it
    # writes into the scene itself, this frame shows the parameters of the
    # final forward pass, not the final Adam update — confirm.
    ta.frame()
    opti_img = ta.captureCurrentFrame()

    _plotResults(loss_hist, init_image, opti_img, target_image)


def main():
    """Optimize the light in ``rabbit.gltf`` (launch directory) for 100 iterations."""
    iterations = 100
    scene_file = os.path.join(launchDir, "rabbit.gltf")
    print(scene_file)
    runGradientDecent(scene_file, iterations)


if __name__ == '__main__':
    main()
