Mirror of https://github.com/firestar5683/StarPilot.git
Synced 2026-06-28 01:52:06 +08:00
388 lines, 13 KiB, Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import codecs
|
|
import json
|
|
import os
|
|
import pickle
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Make project modules importable when the script is run directly; only add the
# entry once so re-imports do not keep growing sys.path.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Default locations of staged ONNX inputs and compiled artifacts.
DEFAULT_INPUT_ROOT = Path("/data/openpilot/uncompiledmodels")
DEFAULT_OUTPUT_ROOT = Path("/data/openpilot/compiledmodels")

# Helper compile scripts; they are run as subprocesses with REPO_ROOT as cwd.
COMPILE_SCRIPT = REPO_ROOT.joinpath("tinygrad_repo/examples/openpilot/compile3.py")
DRIVING_COMPILE_SCRIPT = REPO_ROOT.joinpath("selfdrive/modeld/compile_modeld.py")
DM_WARP_COMPILE_SCRIPT = REPO_ROOT.joinpath("selfdrive/modeld/compile_dm_warp.py")

# Optional JSON object mapping model ID -> behavior version; consulted only when
# no explicit --version is passed.
MODEL_VERSIONS_CACHE = Path("/data/models/.model_versions.json")
|
|
|
|
# --- Driver-monitoring (DM) model identity ---------------------------------
DM_MODEL_KEY = "dm"
DM_MODEL_NAME = "dmonitoring_model"
# A requested model ID (lower-cased, stripped) in this set selects the DM build.
DM_TARGET_ALIASES = {DM_MODEL_KEY, DM_MODEL_NAME, "dmonitoring"}
# File names tried, in this order, in the input root and then in each sub-directory.
DM_INPUT_CANDIDATES = ("dmonitoring_model.onnx", "dmonitoring.onnx", "dm.onnx")

# --- Driving component detection -------------------------------------------
# detect_component() returns the FIRST entry whose aliases occur as a substring of
# a file stem, so insertion order matters: the on/off policy variants must come
# before "driving_policy", whose bare "policy" alias would match them as well.
COMPONENT_ALIASES = dict(
    driving_supercombo=("driving_supercombo", "supercombo"),
    driving_off_policy=("driving_off_policy", "off_policy", "offpolicy"),
    driving_on_policy=("driving_on_policy", "on_policy", "onpolicy"),
    driving_policy=("driving_policy", "policy"),
    driving_vision=("driving_vision", "vision"),
)

# --- Geometry and timing ---------------------------------------------------
# (width, height) pairs; one warp / camera-resolution entry is compiled per pair.
DEFAULT_CAMERA_RESOLUTIONS = ((1928, 1208), (1344, 760))
# (width, height) passed to the driving compiler as --model-size.
MEDMODEL_INPUT_SIZE = (512, 256)
# (width, height) that the DM camera warps target.
DM_INPUT_SIZE = (1440, 960)
# Their integer ratio is passed as --frame-skip (presumably rates in Hz — confirm).
MODEL_RUN_FREQ = 20
MODEL_CONTEXT_FREQ = 5
|
|
|
|
|
|
def build_compile_env() -> dict[str, str]:
    """Return a copy of ``os.environ`` prepared for the compile subprocesses.

    ``REPO_ROOT`` is prepended to ``PYTHONPATH`` so the compile scripts can import
    project modules. Each flag in the table below keeps its current value when that
    value parses as an integer via ``int(value, 0)`` (so ``"0x1"`` is accepted).
    Otherwise, including when the variable is unset, it is replaced with the
    default given here.
    """
    env = os.environ.copy()

    pythonpath = env.get("PYTHONPATH", "")
    # Use os.pathsep rather than a hard-coded ":" so the joined value is a valid
    # search path on every platform.
    env["PYTHONPATH"] = os.pathsep.join((str(REPO_ROOT), pythonpath)) if pythonpath else str(REPO_ROOT)

    integer_flag_defaults = {
        "DEBUG": "0",
        "FLOAT16": "1",
        "IMAGE": "2",
        "JIT_BATCH_SIZE": "0",
        "NOLOCALS": "1",
        "OPENPILOT_HACKS": "1",
    }
    for key, default in integer_flag_defaults.items():
        value = env.get(key)
        if value is None:
            env[key] = default
            continue
        try:
            int(value, 0)
        except ValueError:
            # Non-integer values are overridden rather than passed through.
            env[key] = default
    return env
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line, including the ``--<model-id>`` shorthand.

    Besides the declared options, a single unknown ``--<id>`` flag is accepted as
    shorthand for ``--model <id>``. If the resulting model ID (stripped and
    lower-cased) is one of ``DM_TARGET_ALIASES``, the DM build is selected instead
    and ``model`` is cleared.

    Calls ``parser.error`` (which exits) on stray non-flag arguments, on more than
    one dynamic flag, on a ``--model`` value that disagrees with the dynamic flag,
    or on ``--dm`` combined with a driving model ID.
    """
    parser = argparse.ArgumentParser(
        description="Compile staged ONNX models into StarPilot's unified tinygrad artifact format.",
    )
    parser.add_argument("--model", help="Output model ID, for example sc2.")
    parser.add_argument("--dm", action="store_true", help="Build DM model, metadata, and both camera warps.")
    parser.add_argument("--input-dir", type=Path, default=DEFAULT_INPUT_ROOT)
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_ROOT)
    parser.add_argument(
        "--input-format",
        choices=("auto", "supercombo", "split"),
        default="auto",
        help="Source ONNX layout. Auto prefers supercombo when present.",
    )
    parser.add_argument(
        "--version",
        help="Behavioral model version stored in the artifact. It does not control artifact layout.",
    )
    parser.add_argument("--list", action="store_true", help="List staged models and exit.")
    parser.add_argument("--force", action="store_true", help="Accepted for compatibility; selected outputs are always replaced.")

    args, extras = parser.parse_known_args()

    # Split leftovers into "--<id>" model shorthands and anything else.
    dynamic_flags: list[str] = []
    stray: list[str] = []
    for token in extras:
        if token.startswith("--"):
            dynamic_flags.append(token[2:])
        else:
            stray.append(token)

    if stray:
        parser.error(f"Unexpected arguments: {' '.join(stray)}")
    if len(dynamic_flags) > 1:
        parser.error("Pass only one dynamic model flag, for example ./models --sc2")

    shorthand = dynamic_flags[0] if dynamic_flags else None
    if args.model and shorthand is not None and args.model != shorthand:
        parser.error("Use either --model sc2 or --sc2, not both.")
    args.model = args.model or shorthand

    # A DM alias given as the model ID is routed to the DM build.
    if args.model and args.model.strip().lower() in DM_TARGET_ALIASES:
        args.dm, args.model = True, None
    if args.dm and args.model:
        parser.error("Use either --dm or a driving model ID.")
    return args
|
|
|
|
|
|
def detect_component(path: Path) -> str | None:
    """Return the driving component an ONNX file belongs to, or None.

    The lower-cased file stem is substring-matched against each component's
    aliases in ``COMPONENT_ALIASES`` insertion order, and the first component
    with any matching alias wins.
    """
    stem = path.stem.lower()
    matches = (
        component
        for component, aliases in COMPONENT_ALIASES.items()
        if any(alias in stem for alias in aliases)
    )
    return next(matches, None)
|
|
|
|
|
|
def _model_key_from_flat_file(path: Path, component: str) -> str | None:
    """Extract the model-ID prefix from a flat ``<key>_<alias>.onnx`` file name.

    Aliases of ``component`` are tried in order. Returns None when the stem is
    exactly an alias (an unprefixed file) or when no alias appears as a
    ``_<alias>`` suffix. The prefix keeps the file's original case. An empty prefix
    or the bare word "driving" (e.g. ``driving_offpolicy``) is not a model ID, so
    None is returned for those too.
    """
    stem = path.stem
    lowered = stem.lower()
    for alias in COMPONENT_ALIASES[component]:
        if lowered == alias:
            return None
        suffix = f"_{alias}"
        if lowered.endswith(suffix):
            key = stem[:-len(suffix)]
            # Compare case-insensitively, matching how the suffix itself was
            # matched; otherwise "Driving_offpolicy" would yield key "Driving".
            return None if key.lower() in ("", "driving") else key
    return None
|
|
|
|
|
|
def find_staged_models(input_root: Path) -> dict[str, dict[str, Path]]:
    """Discover staged driving ONNX files under ``input_root``.

    Returns ``{model_key: {component: path}}``. Two layouts are recognised:

    * one sub-directory per model ID, holding component ONNX files;
    * flat files in ``input_root`` named ``<model_key>_<alias>.onnx``.

    Flat files without a model prefix are collected under the ``"_root"`` key.
    Files are visited in sorted order, and a later file for the same component
    replaces an earlier one. Flat prefixed files are scanned after the
    sub-directories, so they override a same-named sub-directory's entry.
    Returns an empty dict when ``input_root`` is not a directory.
    """
    staged: dict[str, dict[str, Path]] = {}
    if not input_root.is_dir():
        return staged

    # Layout 1: one sub-directory per model ID.
    for subdir in sorted(input_root.iterdir()):
        if not subdir.is_dir():
            continue
        components: dict[str, Path] = {}
        for onnx_file in sorted(subdir.glob("*.onnx")):
            component = detect_component(onnx_file)
            if component is not None:
                components[component] = onnx_file
        if components:
            staged[subdir.name] = components

    # Layout 2: flat files directly in input_root.
    unprefixed: dict[str, Path] = {}
    for onnx_file in sorted(input_root.glob("*.onnx")):
        component = detect_component(onnx_file)
        if component is None:
            continue
        owner = _model_key_from_flat_file(onnx_file, component)
        target = staged.setdefault(owner, {}) if owner else unprefixed
        target[component] = onnx_file

    if unprefixed:
        staged["_root"] = unprefixed
    return staged
|
|
|
|
|
|
def resolve_model_files(input_root: Path, model_key: str) -> dict[str, Path]:
    """Return the ``{component: path}`` ONNX files to compile for ``model_key``.

    Resolution order:

    1. an exact staged model (sub-directory or ``<key>_<alias>.onnx`` flat files);
    2. the unprefixed root files, when they are the only thing staged;
    3. root-level ``.onnx`` files whose name starts with ``"<model_key>_"`` but
       whose suffix is not a recognised alias (e.g. ``sc2_driving_vision_v2.onnx``).

    In step 3, files that parse as belonging to a *different* model ID are
    skipped, so ``sc2_lite_driving_vision.onnx`` is never used for ``sc2``. The
    prefix is compared literally, so glob metacharacters in the ID are not expanded.
    """
    staged = find_staged_models(input_root)
    if model_key in staged:
        return staged[model_key]

    root_files = staged.get("_root")
    if root_files and set(staged) == {"_root"}:
        return root_files

    prefix = f"{model_key}_"
    resolved: dict[str, Path] = {}
    for path in sorted(input_root.glob("*.onnx")):
        if not path.name.startswith(prefix):
            continue
        component = detect_component(path)
        if component is None:
            continue
        # A file with a parseable model prefix belongs to that model. If the prefix
        # were model_key itself, step 1 would already have returned it.
        if _model_key_from_flat_file(path, component) is not None:
            continue
        resolved[component] = path
    return resolved
|
|
|
|
|
|
def find_staged_dm(input_root: Path) -> Path | None:
|
|
if not input_root.is_dir():
|
|
return None
|
|
for candidate in DM_INPUT_CANDIDATES:
|
|
path = input_root / candidate
|
|
if path.is_file():
|
|
return path
|
|
for child in sorted(input_root.iterdir()):
|
|
if child.is_dir():
|
|
for candidate in DM_INPUT_CANDIDATES:
|
|
path = child / candidate
|
|
if path.is_file():
|
|
return path
|
|
return None
|
|
|
|
|
|
def get_metadata_value_by_name(model, name: str):
    """Return the value of the first ``model.metadata_props`` entry whose key is ``name``.

    Returns None when no entry has that key. ``model`` is expected to be a loaded
    ONNX ModelProto (or anything exposing ``metadata_props`` with ``.key`` and
    ``.value`` attributes).
    """
    return next((prop.value for prop in model.metadata_props if prop.key == name), None)
|
|
|
|
|
|
def write_metadata(onnx_path: Path, output_path: Path) -> dict:
    """Extract model metadata from an ONNX file and pickle it to ``output_path``.

    The pickled dict has the keys ``model_checkpoint`` (raw metadata string or
    None), ``output_slices`` (decoded from the base64-pickled ``output_slices``
    metadata entry), and ``input_shapes`` / ``output_shapes`` (graph I/O name ->
    tuple of dimension sizes). The same dict is returned.

    Raises:
        ValueError: if the model has no ``output_slices`` metadata entry.
    """
    import onnx  # imported lazily so other commands do not require onnx

    model = onnx.load(str(onnx_path))
    encoded_slices = get_metadata_value_by_name(model, "output_slices")
    if encoded_slices is None:
        raise ValueError(f"output_slices not found in metadata for {onnx_path.name}")

    def shapes_of(value_infos) -> dict[str, tuple[int, ...]]:
        # NOTE(review): symbolic dimensions (dim_param) presumably read back as 0
        # here, since only dim_value is used — confirm this is acceptable downstream.
        return {
            info.name: tuple(int(dim.dim_value) for dim in info.type.tensor_type.shape.dim)
            for info in value_infos
        }

    # SECURITY NOTE: pickle.loads can execute arbitrary code embedded in the ONNX
    # metadata; only compile model files from trusted sources.
    metadata = {
        "model_checkpoint": get_metadata_value_by_name(model, "model_checkpoint"),
        "output_slices": pickle.loads(codecs.decode(encoded_slices.encode(), "base64")),
        "input_shapes": shapes_of(model.graph.input),
        "output_shapes": shapes_of(model.graph.output),
    }
    with open(output_path, "wb") as handle:
        pickle.dump(metadata, handle)
    return metadata
|
|
|
|
|
|
def infer_model_version(model_key: str, explicit_version: str | None) -> str:
|
|
if explicit_version:
|
|
return explicit_version.strip()
|
|
if MODEL_VERSIONS_CACHE.is_file():
|
|
try:
|
|
version = json.loads(MODEL_VERSIONS_CACHE.read_text()).get(model_key)
|
|
if isinstance(version, str):
|
|
return version.strip()
|
|
except Exception:
|
|
pass
|
|
return ""
|
|
|
|
|
|
def select_input_format(requested: str, files: dict[str, Path]) -> str:
    """Decide whether to compile from a single supercombo ONNX or split ONNX files.

    ``"split"`` is returned as-is. ``"supercombo"`` is returned only if a
    ``driving_supercombo`` file is staged; otherwise the process exits with an
    error message. Any other value (``"auto"``) picks supercombo when it is
    available and split otherwise.
    """
    has_supercombo = "driving_supercombo" in files
    if requested == "split":
        return requested
    if requested == "supercombo":
        if not has_supercombo:
            raise SystemExit("--input-format supercombo requires driving_supercombo.onnx")
        return requested
    return "supercombo" if has_supercombo else "split"
|
|
|
|
|
|
def driving_compile_args(files: dict[str, Path], input_format: str) -> tuple[str, list[str]]:
    """Map staged component files to ``(model_type, source_args)`` for the driving compiler.

    * supercombo: ``("supercombo", ["--supercombo-onnx", ...])``.
    * split with no off-policy file: ``("vision_policy", [--vision-onnx, --policy-onnx])``.
    * split with an off-policy file: ``("vision_multi_policy", [--vision-onnx,
      --on-policy-onnx, --off-policy-onnx])``.

    In split mode the primary policy is ``driving_on_policy`` if staged, else
    ``driving_policy``. Exits with a message listing what is missing when the
    vision file or a primary policy is absent.
    """
    if input_format == "supercombo":
        return "supercombo", ["--supercombo-onnx", str(files["driving_supercombo"])]

    vision = files.get("driving_vision")
    primary = files.get("driving_on_policy") or files.get("driving_policy")
    off_policy = files.get("driving_off_policy")

    missing = []
    if vision is None:
        missing.append("driving_vision")
    if primary is None:
        missing.append("driving_policy or driving_on_policy")
    if missing:
        raise SystemExit(f"Missing required split ONNX files: {', '.join(missing)}")

    vision_args = ["--vision-onnx", str(vision)]
    if off_policy is None:
        return "vision_policy", vision_args + ["--policy-onnx", str(primary)]
    return "vision_multi_policy", vision_args + [
        "--on-policy-onnx", str(primary),
        "--off-policy-onnx", str(off_policy),
    ]
|
|
|
|
|
|
def remove_paths(paths: list[Path]) -> int:
    """Delete every regular file or symlink in ``paths`` and return how many were removed.

    Missing paths and directories are skipped silently.
    """
    removed = 0
    for target in paths:
        if not (target.is_file() or target.is_symlink()):
            continue
        target.unlink()
        removed += 1
    return removed
|
|
|
|
|
|
def compile_driving(model_key: str, files: dict[str, Path], input_format: str, version: str, output_dir: Path) -> Path:
    """Compile a driving model into ``<output_dir>/<model_key>_driving_tinygrad.pkl``.

    Existing artifacts for ``model_key`` (the unified artifact plus any
    ``<model_key>_driving_*_tinygrad.pkl`` / ``*_metadata.pkl`` files) are deleted
    first. ``DRIVING_COMPILE_SCRIPT`` is then run as a subprocess.
    ``--behavior-version`` is passed only when ``version`` is non-empty.

    Returns the output artifact path. Raises ``subprocess.CalledProcessError``
    if the compile script fails, in which case the old artifacts are already gone.
    """
    model_type, source_args = driving_compile_args(files, input_format)
    output_path = output_dir / f"{model_key}_driving_tinygrad.pkl"

    stale = [output_path]
    stale.extend(output_dir.glob(f"{model_key}_driving_*_tinygrad.pkl"))
    stale.extend(output_dir.glob(f"{model_key}_driving_*_metadata.pkl"))
    cleared = remove_paths(stale)
    if cleared:
        print(f" cleared {cleared} existing output entries for {model_key}")

    camera_specs = [f"{width}x{height}" for width, height in DEFAULT_CAMERA_RESOLUTIONS]
    model_size = f"{MEDMODEL_INPUT_SIZE[0]}x{MEDMODEL_INPUT_SIZE[1]}"
    frame_skip = str(MODEL_RUN_FREQ // MODEL_CONTEXT_FREQ)

    command = [sys.executable, str(DRIVING_COMPILE_SCRIPT)]
    command.extend(("--model-type", model_type))
    command.extend(("--model-size", model_size))
    command.append("--camera-resolutions")
    command.extend(camera_specs)
    command.extend(("--output", str(output_path)))
    command.extend(("--frame-skip", frame_skip))
    command.extend(source_args)
    if version:
        command.extend(("--behavior-version", version))

    subprocess.run(command, cwd=REPO_ROOT, env=build_compile_env(), check=True)
    return output_path
|
|
|
|
|
|
def compile_dm(onnx_path: Path, output_dir: Path) -> list[Path]:
    """Build every driver-monitoring artifact from ``onnx_path`` into ``output_dir``.

    The artifacts are, in order: the compiled model
    (``<DM_MODEL_NAME>_tinygrad.pkl``), its metadata pickle
    (``<DM_MODEL_NAME>_metadata.pkl``), and one camera warp
    (``dm_warp_<W>x<H>_tinygrad.pkl``) per entry of ``DEFAULT_CAMERA_RESOLUTIONS``.
    Existing copies are deleted first. Returns the artifact paths in that order.
    Raises ``subprocess.CalledProcessError`` if a compile subprocess fails.
    """
    model_out = output_dir / f"{DM_MODEL_NAME}_tinygrad.pkl"
    metadata_out = output_dir / f"{DM_MODEL_NAME}_metadata.pkl"
    warp_outs = [
        output_dir / f"dm_warp_{width}x{height}_tinygrad.pkl"
        for width, height in DEFAULT_CAMERA_RESOLUTIONS
    ]
    outputs = [model_out, metadata_out, *warp_outs]

    cleared = remove_paths(outputs)
    if cleared:
        print(f" cleared {cleared} existing DM output entries")

    def run(command: list[str]) -> None:
        # A fresh compile environment for every subprocess.
        subprocess.run(command, cwd=REPO_ROOT, env=build_compile_env(), check=True)

    run([sys.executable, str(COMPILE_SCRIPT), str(onnx_path), str(model_out)])
    write_metadata(onnx_path, metadata_out)

    dm_w, dm_h = DM_INPUT_SIZE
    for (cam_w, cam_h), warp_out in zip(DEFAULT_CAMERA_RESOLUTIONS, warp_outs, strict=True):
        run([
            sys.executable,
            str(DM_WARP_COMPILE_SCRIPT),
            "--camera-resolution", f"{cam_w}x{cam_h}",
            "--warp-to", f"{dm_w}x{dm_h}",
            "--output", str(warp_out),
        ])
    return outputs
|
|
|
|
|
|
def list_models(staged: dict[str, dict[str, Path]], input_root: Path) -> int:
    """Print staged driving models (sorted by ID and component) and the DM model.

    Prints a "No staged models found" line when there is nothing to list.
    Always returns 0, for use as the process exit status.
    """
    for model_key in sorted(staged):
        print(model_key)
        components = staged[model_key]
        for component in sorted(components):
            print(f" {component}: {components[component]}")

    dm_path = find_staged_dm(input_root)
    if dm_path is not None:
        print(DM_MODEL_KEY)
        print(f" {DM_MODEL_NAME}: {dm_path}")

    if not staged and dm_path is None:
        print(f"No staged models found in {input_root}")
    return 0
|
|
|
|
|
|
def main() -> int:
    """CLI entry point: list staged models, build DM artifacts, or build one driving model.

    Returns 0 on success. User-facing failures (no DM ONNX, no model selected,
    no files for the model) exit via ``SystemExit`` with a message. The output
    directory is created only once there is something to compile.
    """
    args = parse_args()

    if args.list:
        return list_models(find_staged_models(args.input_dir), args.input_dir)

    if args.dm:
        onnx_path = find_staged_dm(args.input_dir)
        if onnx_path is None:
            raise SystemExit(f"No staged DM ONNX found in {args.input_dir}")
        args.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Compiling DM artifacts from {onnx_path} -> {args.output_dir}")
        for output in compile_dm(onnx_path, args.output_dir):
            print(f" saved {output.name}")
        print("Done.")
        return 0

    # Strip before validating so a whitespace-only ID is rejected instead of
    # being compiled as the empty model "".
    model_key = (args.model or "").strip()
    if not model_key:
        staged = find_staged_models(args.input_dir)
        available = ", ".join(sorted(key for key in staged if key != "_root"))
        raise SystemExit(f"Choose a model ID, for example ./models --sc2. Available: {available or 'none'}")

    files = resolve_model_files(args.input_dir, model_key)
    if not files:
        raise SystemExit(f"No staged ONNX files found for {model_key} in {args.input_dir}")

    input_format = select_input_format(args.input_format, files)
    version = infer_model_version(model_key, args.version)
    version_label = version or "unspecified behavior"

    args.output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Compiling {model_key} ({input_format}, {version_label}) from {args.input_dir} -> {args.output_dir}")
    output = compile_driving(model_key, files, input_format, version, args.output_dir)
    print(f" saved {output.name}")
    print("Done.")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())
|