mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-15 03:54:49 +08:00
67 lines
2.7 KiB
Python
67 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
|
|
from pathlib import Path
|
|
|
|
if __package__ in (None, ""):
|
|
import sys
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
|
from common import DEFAULT_WORKSPACE, resolve_workspace # type: ignore
|
|
else:
|
|
from .common import DEFAULT_WORKSPACE, resolve_workspace
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the classifier-training CLI parser and parse arguments.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which matches
            the original zero-argument behaviour used by ``main``.

    Returns:
        The parsed namespace. ``data`` and ``project`` are ``None`` when not
        supplied; the caller is expected to derive them from ``workspace``.
    """
    parser = argparse.ArgumentParser(description="Train the speed-limit value classifier using Ultralytics YOLO classification.")
    parser.add_argument("--workspace", type=Path, default=DEFAULT_WORKSPACE, help="Training workspace root.")
    # No default here: main() falls back to <workspace>/classifier when omitted.
    parser.add_argument("--data", type=Path, help="Classifier dataset root. Defaults to <workspace>/classifier.")
    parser.add_argument("--model", default="yolo11n-cls.pt", help="Ultralytics classification checkpoint to fine-tune.")
    parser.add_argument("--epochs", type=int, default=60, help="Training epochs.")
    parser.add_argument("--imgsz", type=int, default=128, help="Training image size.")
    # Not range-checked on purpose: Ultralytics documents batch=-1 as AutoBatch.
    parser.add_argument("--batch", type=int, default=32, help="Batch size.")
    parser.add_argument("--workers", type=int, default=8, help="Data loader workers.")
    parser.add_argument("--device", default="cpu", help="Ultralytics device string, for example cpu, mps, 0, or 0,1.")
    # No default here: main() falls back to <workspace>/runs/classifier when omitted.
    parser.add_argument("--project", type=Path, help="Training output directory. Defaults to <workspace>/runs/classifier.")
    parser.add_argument("--name", default="yolo11n-cls-speed-limit-us", help="Run name under --project.")
    parser.add_argument("--patience", type=int, default=15, help="Early stopping patience.")
    parser.add_argument("--cache", action="store_true", help="Cache images in RAM if supported.")
    parser.add_argument("--exist-ok", action="store_true", help="Allow overwriting an existing run directory.")
    return parser.parse_args(argv)
|
|
|
|
|
|
def main() -> int:
    """Fine-tune an Ultralytics classification checkpoint on the speed-limit dataset.

    Resolves the workspace, dataset and output directories from the CLI
    arguments. Ultralytics is imported lazily inside this function, so
    argument parsing (e.g. ``--help``) works without it installed. Then it
    trains and reports the run directory.

    Returns:
        Process exit status: ``0`` on success.

    Raises:
        SystemExit: If the dataset path does not exist, or if Ultralytics
            cannot be imported.
    """
    args = parse_args()
    workspace = resolve_workspace(args.workspace)
    data_path = args.data.resolve() if args.data else (workspace / "classifier")
    project_path = args.project.resolve() if args.project else (workspace / "runs" / "classifier")

    # Fail fast with an explicit message rather than handing a missing path
    # to Ultralytics and surfacing a less direct error from inside it.
    if not data_path.exists():
        raise SystemExit(f"Classifier dataset not found: {data_path}")

    try:
        from ultralytics import YOLO
    except ImportError as exc:
        # Only a genuine import failure means "not installed"; other errors
        # raised while importing (e.g. a broken native dependency) propagate
        # with their own traceback instead of being mislabelled.
        raise SystemExit(
            "Ultralytics is not installed. Run `uv sync --extra speedvision` in the repo root before training."
        ) from exc

    model = YOLO(args.model)
    model.train(
        data=str(data_path),
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,
        workers=args.workers,
        device=args.device,
        project=str(project_path),
        name=args.name,
        patience=args.patience,
        cache=args.cache,
        exist_ok=args.exist_ok,
    )

    # Without exist_ok, Ultralytics can pick an incremented run directory
    # (name2, name3, ...) when <project>/<name> already exists, so prefer the
    # trainer's actual save_dir and fall back to the requested path.
    trainer = getattr(model, "trainer", None)
    save_dir = getattr(trainer, "save_dir", None) or (project_path / args.name)
    print(f"Classifier training complete under {save_dir}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI and hand its integer status back to the interpreter.
    # `raise SystemExit` is used instead of `sys.exit` because `sys` is only
    # imported on the script-mode branch of the import block.
    exit_status = main()
    raise SystemExit(exit_status)
|