import numpy as np
import subprocess
import itertools
import scipy
import scipy.spatial
import os
import re

# Reference parameter set: a (4, 4) float matrix, one row per optimized entity
# (presumably one per light, given the "-numRays*perLight" options used below;
# TODO confirm). The last column is stored at a quarter of the listed value.
ground_truth_params = np.column_stack((
    np.array([[0.469423, 2.35922, 2.00195],
              [-1.4456, 6.0325, -1.19891],
              [-2.15449, 6.39304, 4.48958],
              [-3.93203, 1.70284, 7.0013]]),
    np.array([16.8788, 8.16044, 20.6988, 12.943]) / 4,
))
# Every ordering of the reference row indices (24 x 4). groundTruthLoss uses it
# to match estimated rows to reference rows independent of their order.
permut = np.array(list(itertools.permutations(range(len(ground_truth_params)))))
def groundTruthLoss(params):
    """Return the order-independent distance from ``params`` to the reference set.

    Rows 0..3 of ``params`` are each paired with a distinct row of
    ``ground_truth_params``. The Euclidean distances of the four pairs are
    summed, and the smallest sum over every pairing listed in ``permut`` is
    returned.
    """
    # pairwise[i, j]: Euclidean distance between params[i] and ground_truth_params[j].
    pairwise = scipy.spatial.distance.cdist(params, ground_truth_params)
    # Broadcasting (4,) against permut gives one row per permutation p, holding
    # pairwise[k, p[k]] for k = 0..3.
    rows = np.arange(4)
    matched = pairwise[rows, permut]
    return matched.sum(axis=1).min()

# Results directory, relative to the directory this script is launched from;
# created up front so every result file can be opened for writing.
out_dir = "_workdir_largeOffice_cmp/data/"
os.makedirs(out_dir, exist_ok=True)

# Renderer binary and scene file. These are relative paths; the renderer is
# launched with cwd="../../", so they are presumably meant to be resolved from
# there rather than from this script's directory.
exe_path = "_install/bin/inter_adj_light_trace"
scene_path = (
    "../tamashii-resources/ialt_paper_scenes/pcon/large_office/gltf/"
    "large_office_tgt_cams.glb"
)

# Compiled once for the whole sweep. The first pattern pulls every "params" log
# record out of the renderer's stdout. The second splits one record into its
# timestamp and the bracketed list of parameter values.
_PARAMS_RECORD_RE = re.compile(r"[\d:.]+\s+info\s+params[^\]]+.")
_PARAMS_FIELDS_RE = re.compile(r"([\d:.]+)\s+info\s+params = \[([^\]]+).")
# Seconds represented by each field of an "H:M:S.ms" log timestamp.
_CLOCK_WEIGHTS = np.array([60 * 60, 60, 1, 1e-3])
_SECONDS_PER_DAY = 24 * 60 * 60


def _clock_seconds(stamp):
    """Convert an "H:M:S.ms" log timestamp to seconds; return None if malformed."""
    try:
        fields = np.array([float(n) for n in re.split(r"[:.]", stamp)])
    except ValueError:  # e.g. an empty field produced by a stray ':' or '.'
        return None
    if fields.shape != _CLOCK_WEIGHTS.shape:
        return None
    return sum(fields * _CLOCK_WEIGHTS)


for step_size in [0.3]:
    for num_rays in 2 ** (8 + np.arange(4)):  # 256, 512, 1024, 2048
        cmd = [exe_path,
            "load_scene", scene_path,
            "-runPredefinedTest", "large-office-cmp-%.2f" % step_size,
            "-headless", "1",
            "-numRaysXperLight", str(num_rays),
            "-numRaysYperLight", str(num_rays),
            ]
        out_path = os.path.join(
            out_dir, "tamashii_largeOffice_%drays_ADAM%.2f.txt" % (num_rays, step_size))
        print(out_path)

        # On POSIX a relative exe_path is resolved against cwd; scene_path is
        # presumably also read by the child relative to this cwd.
        output = subprocess.run(cmd, cwd="../../", env=os.environ, capture_output=True)
        if output.returncode != 0:
            # Keep whatever was logged, but make the failure visible instead of
            # silently writing a truncated (possibly empty) result file.
            print("WARNING: %s exited with code %d" % (exe_path, output.returncode))
            for err_line in output.stderr.decode(errors="replace").splitlines()[-5:]:
                print("    " + err_line)

        losses = []
        times = []
        start_time = None
        prev_clock = None
        day_offset = 0.0
        for record in _PARAMS_RECORD_RE.findall(output.stdout.decode(errors="replace")):
            fields = _PARAMS_FIELDS_RE.match(record)
            if fields is None:
                continue  # "params" record without the "params = [...]" layout
            clock = _clock_seconds(fields[1])
            if clock is None:
                continue
            # Timestamps look like time of day without a date (TODO confirm).
            # A backwards jump of more than half a day means the run crossed
            # midnight, so advance by one day to keep elapsed times positive.
            if prev_clock is not None and clock < prev_clock - _SECONDS_PER_DAY / 2:
                day_offset += _SECONDS_PER_DAY
            prev_clock = clock
            timestamp = clock + day_offset
            # Elapsed time is measured from the first params record, including
            # one that is skipped below for having the wrong size.
            if start_time is None:
                start_time = timestamp
            values = np.array([float(n) for n in fields[2].split()])
            if values.size != ground_truth_params.size:
                continue  # not a full parameter set; reshaping it would fail or mislead
            params = values.reshape(ground_truth_params.shape)
            # Map the last column onto the reference scale: 0.5 * v**2 / (4*pi).
            # 4.0 * np.pi equals the original literal 12.566370614359172 exactly.
            params[:, 3] = (0.5 * params[:, 3] ** 2) / (4.0 * np.pi)
            times.append(timestamp - start_time)
            losses.append(groundTruthLoss(params))

        if not losses:
            print("WARNING: no parameter updates parsed for " + out_path)

        # Two comma-separated lines: the ground-truth loss of each update, then
        # the seconds elapsed since the first params record.
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(",".join(map(str, losses)) + "\n")
            f.write(",".join(map(str, times)) + "\n")
