import os
import glob
import subprocess

PREFIX = "arm-none-eabi-"
BUILDER = "DEV"

LOCAL_OBJ_DIR = Dir("./obj")
PUBLIC_OBJ_DIR = Dir("../../obj")

# Extra flags appended to every compile/assemble/link command (see `flags`).
common_flags = []

if os.getenv("RELEASE"):
  BUILD_TYPE = "RELEASE"
  # Validate CERT *before* using it: os.path.join() raises TypeError on None, and an
  # empty value would resolve to the ../../../ directory itself -- neither case can be
  # detected by inspecting cert_fn afterwards.
  assert os.getenv("CERT"), "No certificate file specified. Please set CERT env variable"
  # Relative CERT values are resolved against ../../../ (from SCons's working
  # directory); absolute values are used unchanged by os.path.join().
  cert_fn = os.path.abspath(os.path.join("../../../", os.getenv("CERT")))
  print(cert_fn)
  assert os.path.isfile(cert_fn), "Certificate file not found. Please specify absolute path"
else:
  BUILD_TYPE = "DEBUG"
  # Debug builds use the in-repo certs/debug key (cert_fn is handed to sign.py)
  # and compile with ALLOW_DEBUG defined.
  cert_fn = File("../../certs/debug").srcnode().relpath
  common_flags += ["-DALLOW_DEBUG"]

if os.getenv("DEBUG"):
  common_flags += ["-DDEBUG"]

def objcopy(source, target, env, for_signature):
  """SCons command generator: convert the first source (an ELF) into a raw binary.

  `env` and `for_signature` are part of the SCons generator signature and are
  not consulted; the command is the same for building and for signatures.
  """
  elf_in, bin_out = source[0], target[0]
  return " ".join(["$OBJCOPY", "-O", "binary", str(elf_in), str(bin_out)])

def get_libgcc():
  """Return the path of the libgcc.a matching this target's CPU/FPU flags.

  The cross compiler is asked first. Some toolchains (such as the vendored one in
  this repo) answer with a bare "libgcc.a" instead of a full path. In that case
  the compiler's install directory is searched, preferring the copy inside the
  multilib directory that the same target flags select.

  Raises:
    AssertionError: if the install directory or any libgcc.a cannot be found.
    subprocess.CalledProcessError: if a compiler query fails.
  """
  gcc = PREFIX + "gcc"
  target_flags = ["-mcpu=cortex-m4", "-mthumb", "-mhard-float", "-mfpu=fpv4-sp-d16"]

  def query(*args):
    # Run the cross compiler with `args` and return its stdout as text.
    return subprocess.check_output([gcc, *args], encoding="utf8")

  libgcc = query(*target_flags, "-print-libgcc-file-name").strip()
  if os.path.exists(libgcc):
    return libgcc

  # The vendored toolchain in this repo can return just "libgcc.a" here.
  install_dir = None
  for line in query("-print-search-dirs").splitlines():
    if line.startswith("install:"):
      install_dir = line.partition(":")[2].strip()
      break
  assert install_dir, "Unable to determine compiler install directory for body_v1_f4"

  candidates = sorted(glob.glob(os.path.join(install_dir, "**", "libgcc.a"), recursive=True))
  assert len(candidates) > 0, "Unable to locate libgcc.a for body_v1_f4"

  # Plain sorting puts the default-multilib <install>/libgcc.a ahead of e.g.
  # <install>/thumb/v7e-m+fp/hard/libgcc.a. The default one is the wrong ABI for
  # -mhard-float, so prefer the multilib directory these flags actually select.
  multi_dir = query(*target_flags, "-print-multi-directory").strip()
  if multi_dir not in ("", "."):
    wanted_suffix = os.sep + os.path.join(os.path.normpath(multi_dir), "libgcc.a")
    for candidate in candidates:
      if candidate.endswith(wanted_suffix):
        return candidate
  return candidates[0]

def get_version(builder, build_type):
  """Return exactly 8 characters of HEAD's git hash, or "00000000" if unavailable.

  The value is embedded as `const uint8_t gitversion[9]` (8 chars + NUL). Because
  `git rev-parse --short=8` only sets a *minimum* length and may print more
  characters when 8 are ambiguous, the result is truncated to 8.

  `builder` and `build_type` are accepted for interface compatibility but unused.
  """
  try:
    git = subprocess.check_output(["git", "rev-parse", "--short=8", "HEAD"], encoding="utf8").strip()
  except (subprocess.CalledProcessError, OSError):
    # Not inside a git checkout, or git is not installed at all.
    git = "00000000"
  return git[:8]

def to_c_uint32(x):
  """Render `x` as a C initializer of 32 uint32 words, least-significant word first.

  Each word carries a `U` suffix, e.g. `{5U,1U,0U,...}` for 2**32 + 5. Bits above
  32 * 32 = 1024 are silently dropped.
  """
  words = []
  remaining = x
  for _ in range(0x20):
    remaining, low_word = divmod(remaining, 2 ** 32)
    words.append(f"{low_word}U")
  return "{" + ",".join(words) + "}"

def get_key_header(name):
  """Return C source lines defining `RSAPublicKey <name>_rsa_key` from certs/<name>.pub.

  The key must be a 1024-bit RSA public key. Besides the modulus and public
  exponent, two precomputed values are emitted:
    rr    = (2**1024)**2 mod n
    n0inv = 2**32 - (n**-1 mod 2**32), i.e. -n**-1 mod 2**32
  These are presumably the Montgomery-multiplication constants the firmware's
  RSA verifier expects -- confirm against crypto/rsa.c.
  """
  from Crypto.PublicKey import RSA

  key_path = File(f"../../certs/{name}.pub").srcnode().get_path()
  with open(key_path) as key_file:
    key = RSA.importKey(key_file.read())
  assert key.size_in_bits() == 1024

  modulus = key.n
  word_base = 2 ** 32
  r_squared = pow(2 ** 1024, 2, modulus)
  neg_inverse = word_base - pow(modulus, -1, word_base)

  header = [f"RSAPublicKey {name}_rsa_key = {{"]
  header.append("  .len = 0x20,")
  header.append(f"  .n0inv = {neg_inverse}U,")
  header.append(f"  .n = {to_c_uint32(modulus)},")
  header.append(f"  .rr = {to_c_uint32(r_squared)},")
  header.append(f"  .exponent = {key.e},")
  header.append("};")
  return header

# Header search paths (CPPPATH), resolved to absolute paths via SCons Dir nodes.
# Order is preserved because it determines header lookup precedence.
includes = [
  Dir(relative_dir).abspath
  for relative_dir in (
    "../../..",
    ".",
    "..",
    "../bldc",
    "./inc",
    "./inc/STM32F4xx_HAL_Driver/Inc",
    "../..",
  )
]

# [name, source path] pairs. The name selects the object location via obj_path();
# list order is the order the objects are handed to the linker.
# STM32 HAL modules all follow inc/STM32F4xx_HAL_Driver/Src/stm32f4xx_<name>.c.
c_sources = [
  [hal_module, f"inc/STM32F4xx_HAL_Driver/Src/stm32f4xx_{hal_module}.c"]
  for hal_module in (
    "hal_rcc", "hal_tim", "hal_tim_ex", "hal_adc_ex", "hal_cortex",
    "hal_gpio", "hal_rcc_ex", "hal", "hal_adc",
  )
] + [
  ["system", "inc/system_stm32f4xx.c"],
  ["it", "inc/stm32f4xx_it.c"],
  ["bldc", "bldc.c"],
  ["bldc_data", "../bldc/BLDC_controller_data.c"],
  ["bldc_ctrl", "../bldc/BLDC_controller.c"],
  ["util", "util.c"],
]

# The bootstub links only a subset: no timer/ADC HAL and no BLDC controller.
c_bootstub_sources = [
  [hal_module, f"inc/STM32F4xx_HAL_Driver/Src/stm32f4xx_{hal_module}.c"]
  for hal_module in ("hal_rcc", "hal_cortex", "hal_gpio", "hal_rcc_ex", "hal")
] + [
  ["system", "inc/system_stm32f4xx.c"],
  ["it", "inc/stm32f4xx_it.c"],
  ["util", "util.c"],
]

# One flag list shared by the C compiler, assembler and linker. Groups are
# concatenated in their original order, so the resulting command lines are unchanged.
flags = (
  # Diagnostics: every warning is an error.
  ["-Wall", "-Wextra", "-Wstrict-prototypes", "-Werror"]
  # Freestanding little-endian Thumb code, no standard libraries or builtins.
  + ["-mlittle-endian", "-mthumb", "-nostdlib", "-fno-builtin", "-std=gnu11"]
  # Per-function/per-data sections so the linker can garbage-collect unused code.
  + ["-fdata-sections", "-ffunction-sections", "-Wl,--gc-sections"]
  # Linker script.
  + [f"-T{File('./stm32fx_flash.ld').srcnode().abspath}"]
  # STM32F413 target: Cortex-M4 with a single-precision hardware FPU.
  + [
    "-mcpu=cortex-m4",
    "-mhard-float",
    "-DSTM32F4",
    "-DSTM32F413xx",
    "-mfpu=fpv4-sp-d16",
    "-fsingle-precision-constant",
  ]
  # Optimize for size, keep debug info.
  + ["-Os", "-g"]
  + common_flags
)

# Resolved once while reading this SConscript. It is passed as an explicit link
# input to both ELF programs below, because -nostdlib (in `flags`) keeps the
# driver from linking libgcc on its own.
libgcc = get_libgcc()

# Single build environment shared by the bootstub and the main firmware.
env = Environment(
  # Pass the caller's environment through so the bare arm-none-eabi-* tool names
  # below are found on the user's PATH.
  ENV=os.environ,
  CC=PREFIX + "gcc",
  # Assembly goes through the gcc driver too, so it receives the same flags.
  AS=PREFIX + "gcc",
  OBJCOPY=PREFIX + "objcopy",
  OBJDUMP=PREFIX + "objdump",
  ASCOM="$AS $ASFLAGS -o $TARGET -c $SOURCES",
  # The same flag list (including the -T linker script and --gc-sections) is used
  # for compiling, assembling and linking.
  CFLAGS=flags,
  ASFLAGS=flags,
  LINKFLAGS=flags,
  CPPPATH=includes,
  BUILDERS={
    # env.Objcopy(target.bin, source.elf): raw binary image via the objcopy generator.
    "Objcopy": Builder(generator=objcopy, suffix=".bin", src_suffix=".elf"),
  },
)

# Generated headers are written into obj/ next to this SConscript's *source*
# files (srcnode()), while this file is being read -- before any target builds.
obj_dir = os.path.dirname(File("./obj/gitversion.h").srcnode().abspath)
os.makedirs(obj_dir, exist_ok=True)

# gitversion.h: the short git hash as a NUL-terminated 9-byte array.
git_ver = get_version(BUILDER, BUILD_TYPE)
gitversion_line = f'const uint8_t gitversion[9] = "{git_ver}";\n'
with open(os.path.join(obj_dir, "gitversion.h"), "w") as f:
  f.write(gitversion_line)

# cert.h: both the debug and release public keys, each as an RSAPublicKey definition.
certs = [get_key_header(key_name) for key_name in ("debug", "release")]
with open(os.path.join(obj_dir, "cert.h"), "w") as f:
  for cert in certs:
    f.write("\n".join(cert) + "\n")

def obj_path(name, source_path):
  """Return the suffix-less object path under obj/ for a source entry `name`.

  Names starting with "hal" go to obj/hal/, the three BLDC controller entries to
  obj/bldc/, and everything else directly into obj/. `source_path` is accepted for
  call-site symmetry but not used.
  """
  if name.startswith("hal"):
    subdir = "hal/"
  elif name in ("bldc", "bldc_data", "bldc_ctrl"):
    subdir = "bldc/"
  else:
    subdir = ""
  return "obj/" + subdir + name

# Startup assembly object, linked first into both images.
startup = env.Object("obj/startup_stm32f413xx", "startup_stm32f413xx.s")

# --- Bootstub image ---
# Objects get a "-bootstub" suffix so they are distinct SCons targets from the
# main-firmware objects built from the same sources. List order is link order.
bootstub_objects = [startup]
for obj_name, source_path in c_bootstub_sources:
  bootstub_objects.append(env.Object(f"{obj_path(obj_name, source_path)}-bootstub", source_path))
# RSA/SHA code is linked into the bootstub only -- presumably to verify the signed
# main image before jumping to it; TODO confirm in bootstub.c.
bootstub_objects += [
  env.Object("obj/rsa-bootstub", "../../crypto/rsa.c"),
  env.Object("obj/sha-bootstub", "../../crypto/sha.c"),
  env.Object("obj/body_v1_bootstub", "bootstub.c"),
]
# libgcc is an explicit link input (see get_libgcc()).
bootstub_elf = env.Program("obj/bootstub.body_v1_f4.elf", bootstub_objects + [libgcc])
# The raw bootstub binary is published to the shared ../../obj directory.
env.Objcopy("../../obj/bootstub.body_v1_f4.bin", bootstub_elf)

# --- Main firmware image ---
main_objects = [startup, env.Object("obj/body_v1_main", "main.c")]
for obj_name, source_path in c_sources:
  main_objects.append(env.Object(f"{obj_path(obj_name, source_path)}-main", source_path))

# Same link flags as the bootstub, plus .isr_vector placed at 0x08004000.
# NOTE(review): presumably this reserves the first 16 KiB of flash for the
# bootstub -- confirm against stm32fx_flash.ld.
main_elf = env.Program(
  "obj/body_v1_f4.elf",
  main_objects + [libgcc],
  LINKFLAGS=[f"-Wl,--section-start,.isr_vector=0x08004000"] + flags,
)
main_bin = env.Objcopy("obj/body_v1_f4.bin", main_elf)
# Sign the main binary with crypto/sign.py using cert_fn (the debug or release key
# chosen at the top of this file) and write the result to ../../obj. SETLEN=1 is
# passed to sign.py as an environment variable; its meaning is defined there.
sign_py = File("../../crypto/sign.py").srcnode().relpath
env.Command("../../obj/body_v1_f4.bin.signed", main_bin, f"SETLEN=1 {sign_py} $SOURCE $TARGET {cert_fn}")
