import os
import glob

from openpilot.common.transformations.camera import _ar_ox_fisheye, _os_fisheye
from openpilot.common.transformations.model import MEDMODEL_INPUT_SIZE
from openpilot.system.hardware import HARDWARE, PC

Import('env', 'arch', 'release')
lenv = env.Clone()

# Every path under tinygrad_repo (skipping anything containing 'pycache'),
# rewritten as "#"-prefixed top-relative paths; used as build sources so the
# model compile commands depend on the tinygrad checkout.
_tinygrad_glob = env.Dir("#tinygrad_repo").relpath + "/**"
_top_dir = env.Dir("#").abspath
tinygrad_files = []
for _rel_path in glob.glob(_tinygrad_glob, recursive=True, root_dir=_top_dir):
  if 'pycache' in _rel_path:
    continue
  tinygrad_files.append("#" + _rel_path)


def get_camera_configs():
  """Return the (width, height) camera resolutions to compile models for.

  Release, PC and CI builds compile for every resolution in the table below;
  any other build compiles only for the resolution of the current device type.

  Returns:
    A list of unique (width, height) tuples. The order follows the table below,
    so the generated ``--camera-resolutions`` argument is stable.

  Raises:
    KeyError: if the current device type has no entry in the table.
  """
  DEVICE_RESOLUTIONS = {
    "tici": (_ar_ox_fisheye.width, _ar_ox_fisheye.height),
    "tizi": (_ar_ox_fisheye.width, _ar_ox_fisheye.height),
    "mici": (_os_fisheye.width, _os_fisheye.height),
  }
  if release or PC or 'CI' in os.environ:
    # dict.fromkeys de-duplicates (tici and tizi map to the same resolution)
    # while keeping declaration order. A set would give an arbitrary order and
    # a different return type from the single-device branch below.
    return list(dict.fromkeys(DEVICE_RESOLUTIONS.values()))
  device_type = HARDWARE.get_device_type()
  if device_type not in DEVICE_RESOLUTIONS:
    raise KeyError(f"no camera resolution configured for device type {device_type!r}")
  return [DEVICE_RESOLUTIONS[device_type]]

CAMERA_CONFIGS = get_camera_configs()

# tinygrad environment flags, selected by target architecture.
if arch == 'larch64':
  tg_flags = 'DEV=QCOM FLOAT16=1 NOLOCALS=1 JIT_BATCH_SIZE=0'
elif arch == 'Darwin':
  tg_flags = f'DEV=CPU HOME={os.path.expanduser("~")}'
else:
  tg_flags = 'DEV=CPU:LLVM'

image_flag = 'IMAGE=2' if arch == 'larch64' else 'IMAGE=0'

model_w, model_h = MEDMODEL_INPUT_SIZE
from openpilot.selfdrive.modeld.constants import ModelConstants
# Integer ratio of the model run rate to the context rate; passed as --frame-skip.
frame_skip = ModelConstants.MODEL_RUN_FREQ // ModelConstants.MODEL_CONTEXT_FREQ
# Space-separated "WxH" entries for --camera-resolutions.
camera_res_args = ' '.join(f'{width}x{height}' for (width, height) in CAMERA_CONFIGS)

# Shell prefix that appends tinygrad and the repo root to PYTHONPATH.
_tinygrad_dir = env.Dir("#tinygrad_repo").abspath
_repo_root = env.Dir("#").abspath
pythonpath_string = f'PYTHONPATH="${{PYTHONPATH}}:{_tinygrad_dir}:{_repo_root}"'

# The local compile script is the one executed; the upstream script is tracked
# as a dependency as well, so changes to either re-run the compile commands.
_local_compile_node = File("compile_modeld.py")
compile_modeld_script = _local_compile_node.abspath
upstream_compile_script = File(Dir("#selfdrive/modeld").File("compile_modeld.py").abspath)
script_deps = [_local_compile_node, upstream_compile_script]

def compile_combined(model_type, onnx_args, output_name):
  """Register an SCons command that compiles ONNX model(s) into one tinygrad pickle.

  Args:
    model_type: forwarded verbatim as ``--model-type`` to the compile script.
    onnx_args: pre-formatted ``--*-onnx <path>`` flags inserted into the command
      as-is. Whitespace-separated tokens ending in ``.onnx`` that exist on disk
      are also registered as build sources.
    output_name: file name of the pickle written under ``models/``.

  Returns:
    Whatever ``lenv.Command`` returns (callers in this file ignore it).
  """
  output_pkl = File(f"models/{output_name}").abspath
  cmd = ' '.join([
    pythonpath_string, tg_flags, image_flag,
    'python3', compile_modeld_script,
    '--model-type', f'{model_type}',
    '--model-size', f'{model_w}x{model_h}',
    '--camera-resolutions', camera_res_args,
    onnx_args,
    '--frame-skip', f'{frame_skip}',
    '--output', output_pkl,
  ])
  # Tokens are found by splitting on whitespace, so ONNX paths with spaces
  # would not be recognised here.
  onnx_sources = [File(token) for token in onnx_args.split()
                  if token.endswith('.onnx') and os.path.isfile(token)]
  sources = tinygrad_files + script_deps + onnx_sources
  return lenv.Command(output_pkl, sources, cmd)

# Candidate ONNX exports. Each combination below is compiled only when every
# export it needs exists on disk.
vision_onnx = File("models/driving_vision.onnx").abspath
policy_onnx = File("models/driving_policy.onnx").abspath
off_policy_onnx = File("models/driving_off_policy.onnx").abspath
on_policy_onnx = File("models/driving_on_policy.onnx").abspath
supercombo_onnx = File("models/supercombo.onnx").abspath


def _onnx_present(*paths):
  """Return True when every given path is an existing regular file."""
  return all(os.path.isfile(path) for path in paths)


# Vision + Policy (stock default model)
if _onnx_present(vision_onnx, policy_onnx):
  compile_combined('vision_policy',
                   f'--vision-onnx {vision_onnx} --policy-onnx {policy_onnx}',
                   'driving_combined_tinygrad.pkl')

# Vision + Off-Policy, also including the policy model when it is present
if _onnx_present(vision_onnx, off_policy_onnx):
  policy_arg = f'--policy-onnx {policy_onnx}' if _onnx_present(policy_onnx) else ''
  compile_combined('vision_multi_policy',
                   f'--vision-onnx {vision_onnx} {policy_arg} --off-policy-onnx {off_policy_onnx}',
                   'driving_combined_multi_tinygrad.pkl')

# Vision + On-Policy + Off-Policy
if _onnx_present(vision_onnx, on_policy_onnx, off_policy_onnx):
  compile_combined('vision_multi_policy',
                   f'--vision-onnx {vision_onnx} --off-policy-onnx {off_policy_onnx} --on-policy-onnx {on_policy_onnx}',
                   'driving_combined_tri_tinygrad.pkl')

# Supercombo
if _onnx_present(supercombo_onnx):
  compile_combined('supercombo',
                   f'--supercombo-onnx {supercombo_onnx}',
                   'driving_combined_supercombo_tinygrad.pkl')

if PC:
  # One install_models_pc.py invocation writes *_metadata.pkl for every
  # per-model ONNX export found in models/.
  _installer_dir = Dir("#sunnypilot/modeld_v2")
  model_dir = Dir("models").abspath
  cmd = f'python3 {_installer_dir.abspath}/install_models_pc.py {model_dir}'
  inputs = tinygrad_files + [File(_installer_dir.File("install_models_pc.py").abspath)]
  outputs = []

  for model_name in ('supercombo', 'driving_vision', 'driving_off_policy', 'driving_on_policy', 'driving_policy'):
    _onnx_node = File(f"models/{model_name}.onnx")
    if not _onnx_node.exists():
      continue
    # NOTE(review): *_tinygrad.pkl is listed as a source, so SCons expects it to
    # already exist or be built elsewhere; nothing in this file produces a
    # per-model *_tinygrad.pkl — confirm this is intended.
    inputs += [_onnx_node, File(f"models/{model_name}_tinygrad.pkl")]
    outputs.append(File(f"models/{model_name}_metadata.pkl"))

  if outputs:
    lenv.Command(outputs, inputs, cmd)

