import os

import numpy as np
import onnx
import onnxruntime as rt

# Create artifacts: run calibration inference through the TIDL compilation
# provider, which is configured (via 'artifacts_folder') to write its
# artifacts into `output_dir`.
if True:  # toggle: set to False to skip artifact generation
    # NOTE(review): the original script referenced an undefined
    # `tidl_tools_path`; TI's examples take it from TIDL_TOOLS_PATH.
    tidl_tools_path = os.environ['TIDL_TOOLS_PATH']
    # Must match the artifacts_folder used by the inference section.
    output_dir = '/outputs/TDA4_TIDL/'
    onnx_model_path = 'test.onnx'
    # Shape of the random calibration input fed to the model below.
    calib_input_shape = (1, 3, 96, 896)

    # Annotate the model with inferred shapes; overwrites the file in place.
    onnx.shape_inference.infer_shapes_path(onnx_model_path, onnx_model_path)

    # model compilation options
    compile_options = {
        'tidl_tools_path': tidl_tools_path,
        'artifacts_folder': output_dir,
        'tensor_bits': 8,
        'accuracy_level': 0,
        'advanced_options:calibration_frames': 1,
        'advanced_options:calibration_iterations': 1,
        'advanced_options:quantization_scale_type': 1,
        'debug_level': 0,
    }

    # Create the output dir if not present, then empty it so artifacts from
    # a previous run cannot be mixed with the new ones. Walking bottom-up
    # (topdown=False) guarantees each subdirectory is already empty when
    # os.rmdir is called on it.
    os.makedirs(output_dir, exist_ok=True)
    for root, dirs, files in os.walk(output_dir, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))

    so = rt.SessionOptions()

    EP_list = ['TIDLCompilationProvider']  # , 'CPUExecutionProvider']
    sess = rt.InferenceSession(onnx_model_path, providers=EP_list, provider_options=[compile_options], sess_options=so)
    input_name = sess.get_inputs()[0].name
    # Feed one random float32 frame per configured calibration frame so the
    # loop count cannot drift from the compile option.
    for _ in range(compile_options['advanced_options:calibration_frames']):
        img = np.random.rand(*calib_input_shape).astype(np.float32)
        onnx_output = sess.run(None, {input_name: img})

# Use compiled model for inference: run every image of test_set.npy through
# the TIDL execution provider and compare against the reference outputs in
# output.npy.
if True:  # toggle: set to False to skip inference
    # Progress bar is optional; fall back to a plain range if tqdm is absent.
    try:
        from tqdm import trange
    except ImportError:
        trange = range

    # NOTE(review): the original script referenced an undefined
    # `tidl_tools_path`; TI's examples take it from TIDL_TOOLS_PATH.
    tidl_tools_path = os.environ['TIDL_TOOLS_PATH']
    # Must match the artifacts_folder used when compiling.
    output_dir = '/outputs/TDA4_TIDL/'
    onnx_model_path = 'test.onnx'
    # Annotate the model with inferred shapes; overwrites the file in place.
    onnx.shape_inference.infer_shapes_path(onnx_model_path, onnx_model_path)

    # Runtime (not compilation) options for the TIDL execution provider.
    runtime_options = {
        'tidl_tools_path': tidl_tools_path,
        'artifacts_folder': output_dir,
        'debug_level': 3,
    }

    so = rt.SessionOptions()
    EP_list = ['TIDLExecutionProvider']
    test_images = np.load('test_set.npy')
    depthnet_outputs = np.load('output.npy')
    # Fail fast instead of hitting an IndexError halfway through the loop.
    if len(depthnet_outputs) < len(test_images):
        raise ValueError(
            f'output.npy has {len(depthnet_outputs)} reference outputs but '
            f'test_set.npy has {len(test_images)} images'
        )
    sess = rt.InferenceSession(onnx_model_path, providers=EP_list, provider_options=[runtime_options], sess_options=so)
    input_name = sess.get_inputs()[0].name

    max_abs_diffs = []
    for num in trange(len(test_images)):
        # Add a batch axis, then move axis 3 to position 1 (presumably
        # NHWC -> NCHW to match the (1, 3, 96, 896) calibration input —
        # TODO confirm test_set.npy layout).
        img = np.moveaxis(np.expand_dims(test_images[num], axis=0), 3, 1)
        # Calibration used float32 input; cast and make contiguous so the
        # runtime receives the same dtype/layout.
        img = np.ascontiguousarray(img, dtype=np.float32)
        onnx_output = sess.run(None, {input_name: img})
        depthnet_output = depthnet_outputs[num]
        # Keep only index 0 of axis 1 of the reference (the channel axis if
        # depthnet_output is (1, C, H, W) — TODO confirm).
        diff = onnx_output[0] - depthnet_output[:, 0:1, ...]
        max_abs_diffs.append(float(np.abs(diff).max()))

    # Report the comparison instead of discarding it.
    if max_abs_diffs:
        print(
            f'max |diff| over {len(max_abs_diffs)} samples: '
            f'{max(max_abs_diffs):.6g} '
            f'(mean per-sample max: {np.mean(max_abs_diffs):.6g})'
        )

