import loader
loader.loadTamashii()

import pyialt as ta
from adam import Adam
import numpy as np
import os
import time
from matplotlib import pyplot as plt
import matplotlib.cm as cm

def showGradImage(light):
    """Show the first channel of `light`'s gradient image for ialt Param.X.

    The colour limits are made symmetric around zero (+/- the largest absolute
    gradient), so the diverging ``coolwarm_r`` colormap puts zero at its centre.

    Args:
        light: a scene light exposing ``getGradImage(param, n)``; the second
            argument (100) is passed through unchanged -- its meaning is
            defined by pyialt.
    """
    grad_param = ta.impls.ialt.Param.X
    channel = light.getGradImage(grad_param, 100)[:, :, 0]

    # Largest magnitude bounds both sides of the colour range.
    limit = np.max(np.abs(channel))
    print(f'Remapping colors within range: [{-limit:.2f}, {limit:.2f}]')

    figure, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(channel, cmap=cm.coolwarm_r, vmin=-limit, vmax=limit)
    axes.set_title('Full gradients')
    axes.axis('off')
    figure.tight_layout()
    plt.show()

def _configure_renderer():
    """Set the global pyialt options used by the optimization run."""
    ta.var.log_level = "none"
    ta.var.default_camera = "Camera"
    ta.var.numRaysXperLight = 2000
    ta.var.numRaysYperLight = 2000
    ta.var.constRandomSeed = True
    ta.var.usePathTracing = [False, False]


def _run_optimization(scene, iterations, learning_rate, light_pos_hist):
    """Run `iterations` forward/backward/Adam steps on the scene's light parameters.

    After every step the first light's position is appended to
    `light_pos_hist` (mutated in place). Returns the list of per-iteration
    loss values reported by ``ialt.backward()``.
    """
    ialt = ta.impls.ialt
    loss_hist = []
    params = scene.lightsToParameterVector()
    adam = Adam(params, learning_rate)

    start_time = time.perf_counter()  # monotonic clock for elapsed time
    for i in range(iterations):
        ialt.forward(params)
        phi, grads = ialt.backward()
        # NOTE(review): `params` is never reassigned, so adam.step is assumed to
        # update it in place before the next forward() -- confirm in adam.py.
        adam.step(grads)

        light_pos_hist.append(scene.lights[0].position)
        loss_hist.append(phi)
        print(f'Iteration: {i}, Loss: {phi}')

    elapsed_time = time.perf_counter() - start_time
    print(f'Done! Elapsed time: {elapsed_time}s')
    return loss_hist


def _plot_results(loss_hist, init_image, opti_img):
    """Show the loss curve and the initial/optimized renders in a 2x2 grid."""
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    axs[0][0].plot(loss_hist)
    axs[0][0].set_xlabel('iteration')
    axs[0][0].set_ylabel('Loss')
    axs[0][0].set_title('Parameter error plot')

    axs[0][1].imshow(init_image)
    axs[0][1].axis('off')
    axs[0][1].set_title('Initial Image')

    axs[1][0].imshow(opti_img)
    axs[1][0].axis('off')
    axs[1][0].set_title('Optimized image')

    # No target image is rendered yet; the panel is kept as a placeholder.
    axs[1][1].axis('off')
    axs[1][1].set_title('Target Image')
    plt.show()


def _write_light_positions(light_pos_hist, path):
    """Write one "x y z" line per recorded position to `path`, echoing each to stdout."""
    with open(path, "w", encoding="utf-8") as txt_file:
        for pos in light_pos_hist:
            txt_file.write(f'{pos[0]} {pos[1]} {pos[2]}\n')
            print(pos)


def runGradientDecent(scenePath, iterations, learning_rate=0.02,
                      light_intensity=100, output_path=None):
    """Optimize a scene's light parameters with Adam and report the results.

    Opens the scene, sets the first light's intensity, renders an initial
    frame, and runs `iterations` optimization steps. It then renders the
    final frame, plots loss and images, and writes the first light's
    position history to a text file.

    Args:
        scenePath: path of the scene passed to ``ta.openScene``.
        iterations: number of optimization steps.
        learning_rate: Adam step size (default 0.02, the original value).
        light_intensity: intensity assigned to the first light before
            optimizing (default 100, the original value).
        output_path: where to write the position history. Defaults to
            ``output.txt`` next to this script.
    """
    _configure_renderer()

    scene = ta.openScene(scenePath)
    light = scene.lights[0]
    light.intensity = light_intensity
    light_pos_hist = [light.position]
    # Call showGradImage(light) here to inspect the initial gradient image.

    # Render the starting configuration for the comparison plot.
    ta.frame()
    init_image = ta.captureCurrentFrame()

    print('Start...')
    loss_hist = _run_optimization(scene, iterations, learning_rate, light_pos_hist)

    # NOTE(review): the parameters from the final adam.step() are not passed
    # through ialt.forward() before this frame; whether the image reflects them
    # depends on pyialt -- confirm.
    ta.frame()
    opti_img = ta.captureCurrentFrame()

    _plot_results(loss_hist, init_image, opti_img)

    if output_path is None:
        # abspath() prevents dirname() from returning '' for a relative
        # __file__, which previously produced the root path "/output.txt".
        script_dir = os.path.dirname(os.path.abspath(__file__))
        output_path = os.path.join(script_dir, "output.txt")
    _write_light_positions(light_pos_hist, output_path)


# Scene and step count used when no command-line arguments are supplied.
DEFAULT_SCENE = "C:/Users/llipp/Nextcloud/tamashii-resources/ialt_paper_scenes/pcon/simple_office/gltf_v3/simple_office_v3_withCamera2.gltf"
DEFAULT_ITERATIONS = 100


def main(argv=None):
    """Parse optional command-line arguments and run the light optimization.

    Args:
        argv: argument list without the program name. ``None`` (the default)
            means "no arguments", so a bare ``main()`` keeps the original
            behaviour: optimize DEFAULT_SCENE for DEFAULT_ITERATIONS steps.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Optimize scene light parameters with Adam via pyialt.')
    parser.add_argument('scene', nargs='?', default=DEFAULT_SCENE,
                        help='path to the glTF scene (default: %(default)s)')
    parser.add_argument('-n', '--iterations', type=int, default=DEFAULT_ITERATIONS,
                        help='number of optimization steps (default: %(default)s)')
    # Parse an explicit list so that importing and calling main() never
    # picks up an unrelated host program's sys.argv.
    args = parser.parse_args([] if argv is None else argv)
    runGradientDecent(args.scene, args.iterations)


if __name__ == '__main__':
    import sys
    main(sys.argv[1:])
