import loader
loader.loadTamashii()

import pyialt as ta
from adam import Adam
import numpy as np
import os
import time
from matplotlib import pyplot as plt
import matplotlib.cm as cm

def _plotResults(loss_hist, init_image, opti_img):
    """Show the loss curve next to the initial and optimized renders.

    Args:
        loss_hist: Per-iteration loss values to plot.
        init_image: Image rendered before optimization (passed to ``imshow``).
        opti_img: Image rendered after optimization (passed to ``imshow``).
    """
    fig, axs = plt.subplots(1, 3, figsize=(17, 5))
    axs[0].plot(loss_hist)
    axs[0].set_xlabel('Iteration')
    axs[0].set_ylabel('Loss')
    axs[0].set_title('L2 Loss')

    axs[1].imshow(init_image)
    axs[1].axis('off')
    axs[1].set_title('Initial Image')

    axs[2].imshow(opti_img)
    axs[2].axis('off')
    axs[2].set_title('Optimized image')
    plt.show()


def runGradientDecent(scenePath, iterations, *, learningRate=0.25,
                      renderSamples=100, showWindow=False):
    """Optimize the scene's light parameters with Adam and plot the result.

    Renders the initial scene with the path tracer, then runs ``iterations``
    IALT forward/backward passes whose gradients are fed to an ``Adam``
    optimizer. Finally it renders the scene again with the path tracer and
    shows the loss curve next to the before/after images.

    Args:
        scenePath: Scene file passed to ``ta.openScene``.
        iterations: Number of optimization steps to run.
        learningRate: Second argument to ``Adam`` (presumably the learning
            rate; previously hard-coded to 0.25).
        renderSamples: Argument passed to ``path_tracer.render`` for the
            before/after images (previously hard-coded to 100).
        showWindow: If True, open an interactive window and draw a frame
            after every forward pass.

    Returns:
        List of the per-iteration loss values returned by ``ialt.backward``.
    """
    ta.var.log_level = "none"
    ta.var.default_camera = "Camera.001"
    ta.var.numRaysXperLight = 2000
    ta.var.numRaysYperLight = 2000
    ta.var.constRandomSeed = True
    ta.var.usePathTracing = [False, False]
    ialt = ta.impls.ialt

    s = ta.openScene(scenePath)

    # Report each emissive mesh. The commented-out lines are optional
    # experiment toggles kept for manual use.
    for mo in s.models:
        for me in mo.meshes:
            # set target for whole scene
            #me.setTarget([0,0,1])
            if np.any(np.array(me.emissiveFactor) > 0):
                print("There is light")
                #me.optimize.emissiveStrength = False
                #me.optimize.emissiveFactor = False
                #me.emissiveFactor = [0,0,1]

    # Render the unoptimized scene, then switch back to IALT for the loop.
    ta.impls.path_tracer.switch()
    init_image = ta.impls.path_tracer.render(renderSamples)
    ialt.switch()

    if showWindow:
        ta.openWindow()
    try:
        print('Start...')
        loss_hist = []
        params = s.lightsToParameterVector()
        # NOTE(review): forward() always receives this same `params` object,
        # so the loop relies on Adam.step() updating it in place — confirm
        # in adam.py.
        adam = Adam(params, learningRate)
        # perf_counter is monotonic; time.time() can jump if the clock changes.
        start_time = time.perf_counter()
        for i in range(iterations):
            ialt.forward(params)
            if showWindow:
                ta.frame()
            phi, grads = ialt.backward()

            adam.step(grads)

            loss_hist.append(phi)
            print('Iteration: {}, Loss: {}'.format(i, phi))

        elapsed_time = time.perf_counter() - start_time
        print('Done! Elapsed time: {}s'.format(elapsed_time))
    finally:
        # Close the window even if an iteration raises.
        if showWindow:
            ta.closeWindow()

    ta.impls.path_tracer.switch()
    opti_img = ta.impls.path_tracer.render(renderSamples)

    _plotResults(loss_hist, init_image, opti_img)
    return loss_hist


def defaultScenePath():
    """Return the absolute path of the bundled emissive test scene.

    The path is resolved relative to the directory containing this file.
    ``abspath`` is applied first because ``__file__`` may be a bare filename.
    In that case ``dirname`` returns "", and plain string concatenation would
    produce "/emissive_test/..." at the filesystem root.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "emissive_test", "emissive_test.gltf")


def main():
    """Run 100 optimization iterations on the bundled emissive test scene."""
    runGradientDecent(defaultScenePath(), 100)


if __name__ == '__main__':
    main()
