mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-22 23:12:09 +08:00
243 lines
7.9 KiB
Python
Executable File
243 lines
7.9 KiB
Python
Executable File
"""
Copyright (c) 2021-, Haibin Wen, sunnypilot, and a number of other contributors.

This file is part of sunnypilot and is licensed under the MIT License.
See the LICENSE.md file in the root directory for more details.
"""
|
|
|
|
import os
|
|
import pickle
|
|
import sys
|
|
import hashlib
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
from datetime import datetime, UTC
|
|
|
|
# Output-slice names that validate_model_outputs requires to be present
# somewhere across the combined metadata files.
REQUIRED_OUTPUT_KEYS = frozenset((
  "plan",
  "lane_lines",
  "road_edges",
  "lead",
  "desire_state",
  "desire_pred",
  "meta",
  "lead_prob",
  "lane_lines_prob",
  "pose",
  "wide_from_device_euler",
  "road_transform",
  "hidden_state",
))

# Output-slice names that are only reported when found; their absence is not an error.
OPTIONAL_OUTPUT_KEYS = frozenset((
  "planplus",
  "sim_pose",
  "desired_curvature",
))
|
|
|
|
|
|
def validate_model_outputs(metadata_paths: list[Path]) -> None:
  """Check that the given metadata pickles jointly expose every required output key.

  Each non-empty file is unpickled and the keys of its ``output_slices``
  mapping (an empty mapping when the entry is absent) are pooled together.
  Zero-byte files are reported on stdout and skipped.

  Args:
    metadata_paths: Paths of the metadata pickle files to inspect.

  Raises:
    ValueError: If any name in REQUIRED_OUTPUT_KEYS is missing from the pooled keys.
  """
  seen_keys: set[str] = set()

  for meta_path in metadata_paths:
    if not path_has_content(meta_path):
      print(f"skipping empty metadata: {meta_path}")
      continue

    with open(meta_path, "rb") as handle:
      # NOTE: pickle can execute arbitrary code while loading; only point this
      # at metadata files produced by a trusted build step.
      loaded = pickle.load(handle)

    seen_keys.update(loaded.get("output_slices", {}).keys())

  absent = sorted(REQUIRED_OUTPUT_KEYS.difference(seen_keys))
  if absent:
    raise ValueError(f"Combined model metadata is missing required output keys: {absent}")

  found_optional = sorted(OPTIONAL_OUTPUT_KEYS.intersection(seen_keys))
  if found_optional:
    print(f"Optional output keys detected: {found_optional}")


def path_has_content(candidate: Path) -> bool:
  """Return True when the file at *candidate* is larger than zero bytes."""
  return candidate.stat().st_size != 0
|
|
|
|
|
|
def create_short_name(full_name):
  """Build an abbreviation of at most eight characters from a display name.

  Parenthesised text is dropped, the remainder is split on whitespace, and
  every non-alphanumeric character is removed from each word (words that end
  up empty are discarded). The surviving words are then combined:

  * a single word is truncated to eight characters and upper-cased;
  * two words whose second is one letter followed by digits are joined,
    truncated and upper-cased (for example "Word A1" -> "WORDA1");
  * otherwise each word contributes a fragment (see the loop below) and the
    concatenation is truncated to eight characters.

  An input without any alphanumeric word yields an empty string.
  """
  without_parens = re.sub(r'\([^)]*\)', '', full_name)

  tokens = []
  for raw in without_parens.split():
    stripped = re.sub(r'[^a-zA-Z0-9]', '', raw)
    if stripped:
      tokens.append(stripped)

  if len(tokens) == 1:
    return tokens[0][:8].upper()

  # Name + version pair: keep the whole name instead of just its initial.
  if len(tokens) == 2 and re.match(r'^[A-Za-z]\d+$', tokens[1]):
    return ''.join(tokens)[:8].upper()

  # Version or number shapes that are kept whole (upper-cased).
  version_shapes = (
    r'^\d+[a-zA-Z]+$',
    r'^\d+[vVbB]\d+$',
    r'^[vVbB]\d+$',
    r'^\d{4}$',
  )

  fragments = []
  for token in tokens:
    if any(re.match(shape, token) for shape in version_shapes):
      fragments.append(token.upper())
    elif re.match(r'^[A-Z]{2,3}$', token) or token.isdigit():
      # Short all-caps abbreviations and plain numbers pass through unchanged.
      fragments.append(token)
    elif re.match(r'^[a-zA-Z]+[0-9]+$', token):
      # Letters followed by digits (for example rev2): initial plus the digits.
      fragments.append(token[0].upper() + ''.join(re.findall(r'\d+', token)))
    else:
      # Plain words and any other mixed token contribute only their initial.
      fragments.append(token[0].upper())

  return ''.join(fragments)[:8]
|
|
|
|
|
|
def _read_pkl_bytes(pkl_path: Path) -> bytes:
  """Return the bytes of a pickle that is stored whole or as numbered chunks.

  When ``<pkl_path>.chunkmanifest`` exists, its text is parsed as the chunk
  count N and the files ``<pkl_path>.chunk01ofNN`` .. ``<pkl_path>.chunkNNofNN``
  are read in order and concatenated. Otherwise *pkl_path* itself is read.
  """
  manifest_path = Path(f"{pkl_path}.chunkmanifest")
  if not manifest_path.exists():
    return pkl_path.read_bytes()

  total = int(manifest_path.read_text().strip())
  return b''.join(
    Path(f"{pkl_path}.chunk{index:02d}of{total:02d}").read_bytes()
    for index in range(1, total + 1)
  )
|
|
|
|
|
|
def _find_driving_pkl(output_path: Path) -> Path | None:
|
|
for pattern in ('driving_tinygrad.pkl', 'driving_*_tinygrad.pkl'):
|
|
matches = sorted(output_path.glob(pattern))
|
|
if matches:
|
|
return matches[0]
|
|
for pattern in ('driving_tinygrad.pkl.chunkmanifest', 'driving_*_tinygrad.pkl.chunkmanifest'):
|
|
matches = sorted(output_path.glob(pattern))
|
|
if matches:
|
|
return Path(str(matches[0]).removesuffix('.chunkmanifest'))
|
|
return None
|
|
|
|
|
|
def _rename_pkl_with_chunks(old_pkl: Path, new_pkl: Path) -> Path:
  """Rename a pickle, or every chunk file belonging to it, to the new name.

  Without an ``<old_pkl>.chunkmanifest`` file the pickle itself is renamed.
  With one, every sibling matching ``<old name>.chunk*`` (which includes the
  manifest) has the first occurrence of the old file name replaced by the new
  file name; the files stay in the old pickle's directory.

  Returns:
    The path of the renamed pickle.
  """
  if not Path(f"{old_pkl}.chunkmanifest").exists():
    return old_pkl.rename(new_pkl)

  folder = old_pkl.parent
  # sorted() materialises the listing before any file is renamed.
  for piece in sorted(folder.glob(f"{old_pkl.name}.chunk*")):
    target_name = piece.name.replace(old_pkl.name, new_pkl.name, 1)
    piece.rename(folder / target_name)
  return new_pkl
|
|
|
|
|
|
def generate_metadata(model_path: Path, output_dir: Path, short_name: str, driving_pkl: Path):
  """Describe one model as a dict holding its artifact and metadata checksums.

  The metadata pickle is expected at ``<output_dir>/<model stem>_metadata.pkl``.
  When *short_name* is truthy the file is renamed to
  ``<model stem>_<short_name lower-cased>_metadata.pkl`` unless a file with
  that name already exists, in which case the existing one is used.

  Args:
    model_path: Path of the model file; only its stem is used.
    output_dir: Directory that holds the metadata pickle.
    short_name: Optional tag inserted into the metadata file name.
    driving_pkl: Pickle (whole or chunked) whose SHA-256 becomes the artifact hash.

  Returns:
    A dict with "type", "artifact" and "metadata" entries, or None (after a
    warning on stderr) when no metadata file is found.
  """
  stem = model_path.stem
  meta_path = output_dir / f"{stem}_metadata.pkl"

  if short_name:
    tagged_path = output_dir / f"{stem}_{short_name.lower()}_metadata.pkl"
    if tagged_path.exists():
      meta_path = tagged_path
    elif meta_path.exists():
      meta_path = meta_path.rename(tagged_path)

  if not meta_path.exists():
    print(f"Warning: Missing metadata for {stem} ({meta_path}), skipping", file=sys.stderr)
    return

  artifact_digest = hashlib.sha256(_read_pkl_bytes(driving_pkl)).hexdigest()
  metadata_digest = hashlib.sha256(meta_path.read_bytes()).hexdigest()

  if "off_policy" in stem:
    model_type = "offPolicy"
  elif "on_policy" in stem:
    model_type = "onPolicy"
  else:
    # Fall back to whatever follows the last underscore of the stem.
    model_type = stem.split("_")[-1]

  artifact_entry = {
    "file_name": driving_pkl.name,
    "download_uri": {
      "url": "https://gitlab.com/sunnypilot/public/docs.sunnypilot.ai/-/raw/main/",
      "sha256": artifact_digest
    }
  }
  metadata_entry = {
    "file_name": meta_path.name,
    "download_uri": {
      "url": "https://gitlab.com/sunnypilot/public/docs.sunnypilot.ai/-/raw/main/",
      "sha256": metadata_digest
    }
  }
  return {
    "type": model_type,
    "artifact": artifact_entry,
    "metadata": metadata_entry
  }
|
|
|
|
|
|
def create_metadata_json(models: list, output_dir: Path, custom_name=None, short_name=None, is_20hz=False, upstream_branch="unknown"):
|
|
metadata_json = {
|
|
"short_name": short_name,
|
|
"display_name": custom_name or upstream_branch,
|
|
"is_20hz": is_20hz,
|
|
"ref": upstream_branch,
|
|
"environment": "development",
|
|
"runner": "tinygrad",
|
|
"index": -1,
|
|
"minimum_selector_version": "-1",
|
|
"generation": "-1",
|
|
"build_time": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
|
"overrides": {},
|
|
"models": models,
|
|
}
|
|
|
|
# Write metadata to output_dir
|
|
with open(output_dir / "metadata.json", "w") as f:
|
|
json.dump(metadata_json, f, indent=2)
|
|
|
|
print(f"Generated metadata.json with {len(models)} models.")
|
|
|
|
|
|
if __name__ == "__main__":
  import argparse
  import glob

  cli = argparse.ArgumentParser(description="Generate metadata for model files")
  cli.add_argument("--model-dir", default="./models", help="Directory containing ONNX model files")
  cli.add_argument("--output-dir", default="./output", help="Output directory for metadata")
  cli.add_argument("--custom-name", help="Custom display name for the model")
  cli.add_argument("--is-20hz", action="store_true", help="Whether this is a 20Hz model")
  cli.add_argument("--validate-only", action="store_true")
  cli.add_argument("--upstream-branch", default="unknown", help="Upstream branch name")
  args = cli.parse_args()

  # Validation mode: only check the metadata pickles found in --model-dir, then stop.
  if args.validate_only:
    found_metadata = glob.glob(os.path.join(args.model_dir, "*_metadata.pkl"))
    if not found_metadata:
      # sys.exit with a string writes it to stderr and exits with status 1.
      sys.exit(f"No metadata files found in {args.model_dir}")
    validate_model_outputs([Path(entry) for entry in found_metadata])
    print(f"Validated {len(found_metadata)} metadata files successfully.")
    sys.exit(0)

  # Generation mode: one metadata entry per ONNX file in --model-dir.
  onnx_files = glob.glob(os.path.join(args.model_dir, "*.onnx"))
  if not onnx_files:
    sys.exit(f"No ONNX files found in {args.model_dir}")

  out_dir = Path(args.output_dir)
  out_dir.mkdir(parents=True, exist_ok=True)

  short = None
  if args.custom_name:
    short = create_short_name(args.custom_name)

  driving_pkl = _find_driving_pkl(out_dir)
  if not driving_pkl:
    sys.exit(f"No driving_tinygrad.pkl found in {out_dir}")

  if short:
    # Tag the driving pickle with the short name unless a tagged file already exists.
    tagged_pkl = out_dir / f"driving_{short.lower()}_tinygrad.pkl"
    driving_pkl = tagged_pkl if tagged_pkl.exists() else _rename_pkl_with_chunks(driving_pkl, tagged_pkl)

  # generate_metadata returns None for models without metadata; those are dropped.
  collected = [
    described
    for onnx_file in onnx_files
    if (described := generate_metadata(Path(onnx_file), out_dir, short, driving_pkl))
  ]

  if not collected:
    print("No models processed.", file=sys.stderr)
  else:
    create_metadata_json(collected, out_dir, args.custom_name, short, args.is_20hz, args.upstream_branch)
|