Files
onepilot/tinygrad_repo/test/external/verify_kernel.py
T
carrot 77a8919349 TR16 Model, fix radar routine (#211)
* UV+DTR model

* DTR model.. again.

* fix naviGPS

* fix radar...

* fix..

* test

* fix..

* carrot serv

* fix..

* fix.. fleet

* fix.. radar

* fix atc

* Steam Powered model..

* fix.. radarLatFactor range.. 200->500

* fix.. dbc..

* side

* SP v2

* brake light

* fix brakelight

* fix..

* add datetime...

* fix..

* fix..

* fix..

* fix..

* blind spot

* fix tz

* fix..

* ff

* radarLatFactor

* fix.. bsd

* Revert "fix.. bsd"

This reverts commit 1d0d1434470e1b92c65eaffaeb8dd7cd779f85ee.

* fix.. bsd side..

* test

* fix.. e2e conditions

* Revert "test"

This reverts commit 0ce791dbd66c17260366ed1a4df2626c602dbb7d.

* TR16

* fix cut-in detect threshold  3.4 -> 2.6

* fix.. jerk_l limit 5->10

* fix..

* fix.. gm

* fix.. OPTIMA_H mass

* fix.. radar..

* fix radar..

* fix..

* Radar...

* fix..

* fix..

* fix..

* fix.. radartrack 3

* fix..

* fix..

* fix..

* merge..

* fix.. canfd

* fix..

* fix..

* fix..

* fix.. radard

* new cut_in

* Revert "new cut_in"

This reverts commit b9b6e9b33318fe1ce7d626468139b17848efcdcd.

* fix..

* new cut_in detect...

* fix.. disp..

* fix..

* fix..

* fix.. center radar..

* fix.. radar y_sane..

* fix..

* fix..

* hkg jerk 10 -> 5

* fix..

* fix..

* fix.. radar dbc..

* fix..

* fix.. jLead filter..

* test new radar interface..

* fix..

* fix..

* test time...

* Revert "test time..."

This reverts commit 63e9187736985c4dc4b4f3736674ba7cda6adc3f.

* fix radar..

* fix..

* FireHose model..

* tinygrad

* Update interface.py

* fix..

* fix.. nff toyota corolla_tss2

* fix..

* fix..

* fix.. radar

* fix..

* fix.. radar, y_gate

* fix.. radar..

* fix.. for clone..

* scc radar enable at low speed..

* fix.. settings..

* fix.

* fix..

* fix.. radarTimeStep.

* TR16 model again..

* RELEASE.md

* fix cut-in detection...

* fix.. registration timeout 15sec..

* fix..

* fix.. radar processing.

* fix..

* fix..

* fix..

* fix..

* fix..

* fix..
2025-09-05 15:43:10 +09:00

79 lines
3.7 KiB
Python

import argparse
from collections import defaultdict
from extra.optimization.helpers import kern_str_to_lin, time_linearizer
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
from tinygrad.codegen.opt.kernel import Kernel
# Use this with the LOGKERNS option to verify that every executed kernel is valid and evaluates to the same ground-truth result
# Example for GPT2:
# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERNS=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
def read_kernel_strings(path):
  """Return the serialized kernels stored in `path`, one per line.

  Non-blank lines are returned unmodified (trailing newline included), matching what
  `readlines()` produced before. Blank or whitespace-only lines (e.g. an empty last line)
  are dropped because they do not encode a kernel and would otherwise be handed to
  kern_str_to_lin.
  """
  with open(path, 'r') as f:
    return [line for line in f if line.strip()]

def check_failure_count(num_failed, expected_failures):
  """Raise RuntimeError unless exactly `expected_failures` kernels failed.

  This also fires when fewer kernels fail than expected (including zero), so a stale
  --expected-failures value is reported instead of silently passing.
  """
  if num_failed != expected_failures:
    raise RuntimeError(f"failed on {num_failed} kernels, expected {expected_failures}")

def load_test_lins(args):
  """Build the list of optimized kernels to verify from the parsed command-line `args`.

  Exactly one source is used, checked in priority order: --kernel, --file, --pkl.
  Raises RuntimeError if none of them was given.
  """
  if args.kernel is not None:
    print("loading kernel from args")
    return [kern_str_to_lin(args.kernel)]
  if args.file is not None:
    print(f"loading kernel from file '{args.file}'")
    return [kern_str_to_lin(kern_str) for kern_str in read_kernel_strings(args.file)]
  if args.pkl is not None:
    # report the path actually being loaded (previously printed args.file, i.e. 'None')
    print(f"loading kernel from pickle file '{args.pkl}'")
    import pickle
    # NOTE(security): pickle.load can execute arbitrary code; only load pickle files you produced yourself.
    with open(args.pkl, 'rb') as f:
      ast, applied_opts = pickle.load(f)
    lin = Kernel(ast)
    lin.apply_opts(applied_opts)
    return [lin]
  raise RuntimeError("no kernel specified; use --kernel, --file, or --pkl options")

def main():
  """Parse CLI args, compare each requested kernel against ground truth, and report failures.

  Raises RuntimeError when the number of failing kernels differs from --expected-failures.
  """
  parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
  parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
  parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
  parser.add_argument("--pkl", type=str, default=None, help="a pickle file containing a single tuple of ast and applied_opts")
  parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
  parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
  parser.add_argument("--timing", action='store_true', help="show final timing for the kernel")
  parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
  args = parser.parse_args()

  test_lins = load_test_lins(args)
  print(f"verifying {len(test_lins)} kernels")
  failed_ids = []
  failures = defaultdict(list)  # comparison message -> [(ast, applied_opts), ...]
  for i, test_lin in enumerate(test_lins):
    print(f"testing kernel {i}")
    print(test_lin.ast)
    print(test_lin.applied_opts)
    # an un-optimized Kernel of the same AST, built only to show the shape before/after opts
    unoptimized_lin = Kernel(test_lin.ast)
    print(f"{unoptimized_lin.colored_shape()} -> {test_lin.colored_shape()}")
    msg, rb, _, _ = compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)
    if msg != "PASS":
      failed_ids.append(i)
      failures[msg].append((test_lin.ast, test_lin.applied_opts))
    if args.timing:
      tm = time_linearizer(test_lin, rb, allow_test_size=False, cnt=10)
      print(f"final time {tm*1e6:9.0f} us")

  for msg, errors in failures.items():
    for i, (ast, opts) in enumerate(errors):
      print(f"{msg} {i} AST: {ast}")
      print(f"{msg} {i} OPTS: {opts}\n")

  print(f"tested {len(test_lins)} kernels")
  if failures:
    print(f"{failed_ids=}")
    for msg, errors in failures.items():
      print(f"{msg}: {len(errors)}")
  check_failure_count(len(failed_ids), args.expected_failures)
  # only claim "all passed" when nothing failed; matching a non-zero expected count is reported separately
  if failures:
    print(colored(f"{len(failed_ids)} failed as expected", "yellow"))
  else:
    print(colored("all passed", "green"))

if __name__ == "__main__":
  main()