mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
log optimized kernels and a script to compare with non-optimized ones (#3829)
* search: add BEAM_VERIFY option to validate search results refactor fuzz_linearizer comparison to allow it to be used in for BEAM_VERIFY in device.py * search: fix to verify the beam_search result and not the fastest * search: fix typing and clean up * device: remove imports from test and add LOGKERN options LOGKERN output can be used with test/external/verify_kernel.py to validate correctness * fix example in verify_kernel.py * cleanup fixes * fix to use f-strings
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# stuff needed to unpack a kernel
|
||||
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
@@ -10,6 +11,12 @@ inf, nan = float('inf'), float('nan')
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
|
||||
def ast_str_to_lin(ast_str:str, opts=None): return Linearizer(ast_str_to_ast(ast_str), opts=opts)
|
||||
def kern_str_to_lin(kern_str:str, opts=None):
|
||||
(ast, applied_opts,) = eval(kern_str)
|
||||
k = Linearizer(*ast, opts=opts)
|
||||
for opt in applied_opts:
|
||||
k.apply_opt(opt)
|
||||
return k
|
||||
|
||||
# load worlds, a dataset of about 12k kernels
|
||||
import gzip
|
||||
|
||||
85
test/external/fuzz_linearizer.py
vendored
85
test/external/fuzz_linearizer.py
vendored
@@ -1,15 +1,17 @@
|
||||
import random, traceback, ctypes
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, DefaultDict
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp
|
||||
from tinygrad.codegen.kernel import Opt
|
||||
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import UOp
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import LazyOp
|
||||
|
||||
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
|
||||
|
||||
@@ -24,7 +26,7 @@ def get_fuzz_rawbufs(lin):
|
||||
with Context(DEBUG=0):
|
||||
for rawbuf in rawbufs[1:]:
|
||||
t = Tensor.uniform((rawbuf.size,), dtype=rawbuf.dtype)
|
||||
rawbuf.copyin(t.realize().lazydata.realized.as_buffer())
|
||||
if isinstance(ld:=t.realize().lazydata, LazyBuffer) and ld.realized: rawbuf.copyin(ld.realized.as_buffer())
|
||||
return rawbufs
|
||||
|
||||
def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
|
||||
@@ -38,32 +40,43 @@ def get_fuzz_rawbuf_like(rawbuf, zero=False, size=None):
|
||||
|
||||
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
|
||||
if rawbufs is None: rawbufs = bufs_from_lin(lin)
|
||||
if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()}
|
||||
if var_vals is None: var_vals = {v: v.min for v in lin.ast[0].vars()}
|
||||
|
||||
# TODO: images needs required_optimization
|
||||
try:
|
||||
if isinstance(device, Compiled):
|
||||
prg = device.to_program(lin)
|
||||
else:
|
||||
prg = device.get_runner(lin.ast)
|
||||
prg = device.to_program(lin)
|
||||
except Exception:
|
||||
print(lin.ast)
|
||||
print(lin.applied_opts)
|
||||
traceback.print_exc()
|
||||
print("COMPILE FAILED!!")
|
||||
return "COMPILE_ERROR"
|
||||
|
||||
try:
|
||||
prg.exec(rawbufs, var_vals)
|
||||
prg(rawbufs, var_vals, wait=True, do_update_stats=False)
|
||||
except Exception:
|
||||
print(lin.ast)
|
||||
print(lin.applied_opts)
|
||||
traceback.print_exc()
|
||||
print("EXEC FAILED!!")
|
||||
return "EXEC_ERROR"
|
||||
|
||||
return "PASS"
|
||||
|
||||
def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_truth=None, rtol=1e-2, atol=1e-2):
|
||||
try:
|
||||
if rawbufs is None:
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
else:
|
||||
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True) # get a new output buffer
|
||||
except BaseException:
|
||||
return ("RAWBUFS_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
if var_vals is None: var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
|
||||
if ground_truth is None:
|
||||
unoptimized = Linearizer(*lin.ast)
|
||||
unoptimized.required_optimizations()
|
||||
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
|
||||
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
|
||||
|
||||
if (run_msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
|
||||
return (run_msg, rawbufs, var_vals, ground_truth,)
|
||||
result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
|
||||
return ("PASS" if np.allclose(result, ground_truth, rtol=rtol, atol=atol) else "COMPARE_ERROR", rawbufs, var_vals, ground_truth,)
|
||||
|
||||
def fuzz_linearizer(lin: Linearizer):
|
||||
SEED = getenv("SEED", 42)
|
||||
@@ -73,7 +86,8 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
print(lin.colored_shape())
|
||||
seen_uops = {}
|
||||
last_lins = [lin]
|
||||
failures = defaultdict(list)
|
||||
failures:DefaultDict[str, List[Tuple[Tuple[LazyOp,...],List[Opt]]]] = defaultdict(list)
|
||||
rawbufs, var_vals, ground_truth = None, None, None
|
||||
|
||||
FUZZ_BEAM = getenv("FUZZ_BEAM", 0)
|
||||
FUZZ_MAX_SIZE = getenv("FUZZ_MAX_SIZE", 0)
|
||||
@@ -81,23 +95,6 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
print("skipping large kernel")
|
||||
return failures
|
||||
|
||||
# get baseline unoptimized output
|
||||
unoptimized = Linearizer(*lin.ast)
|
||||
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast[0].vars()}
|
||||
|
||||
try:
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print("RAWBUFS FAILED!!")
|
||||
failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
|
||||
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
|
||||
failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
ground_truth = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np).copy()
|
||||
|
||||
for depth in range(getenv("DEPTH", 1 if FUZZ_BEAM else 10)):
|
||||
next_lins = []
|
||||
for lin in last_lins:
|
||||
@@ -118,23 +115,15 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
seen_uops[tuops] = tuple(test_lin.applied_opts)
|
||||
|
||||
if not FUZZ_BEAM: print(test_lin.colored_shape())
|
||||
# get a new output buffer
|
||||
rawbufs[0] = get_fuzz_rawbuf_like(rawbufs[0], zero=True)
|
||||
if (msg := run_linearizer(test_lin, rawbufs, var_vals)) != "PASS":
|
||||
|
||||
(msg, rawbufs, var_vals, ground_truth) = compare_linearizer(test_lin, rawbufs, var_vals, ground_truth)
|
||||
if msg != "PASS":
|
||||
print(test_lin.ast)
|
||||
print(test_lin.applied_opts)
|
||||
print(msg)
|
||||
failures[msg].append((test_lin.ast, test_lin.applied_opts))
|
||||
continue
|
||||
|
||||
result = np.frombuffer(rawbufs[0].as_buffer(), rawbufs[0].dtype.np)
|
||||
try:
|
||||
# compare memoryviews directly
|
||||
np.testing.assert_allclose(result, ground_truth, rtol=1e-2, atol=1e-2)
|
||||
except AssertionError:
|
||||
print(test_lin.ast)
|
||||
print(test_lin.applied_opts)
|
||||
traceback.print_exc()
|
||||
print("COMPARE FAILED!!")
|
||||
failures["COMPARE_ERROR"].append((test_lin.ast, test_lin.applied_opts))
|
||||
continue
|
||||
next_lins.append(test_lin)
|
||||
|
||||
last_lins = next_lins
|
||||
|
||||
57
test/external/verify_kernel.py
vendored
Normal file
57
test/external/verify_kernel.py
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from extra.optimization.helpers import kern_str_to_lin
|
||||
from test.external.fuzz_linearizer import compare_linearizer
|
||||
from tinygrad.helpers import colored
|
||||
from tinygrad.features.graph import print_tree
|
||||
|
||||
# Use this with the LOGKERN options to verify that all executed kernels are valid and evaluate to the same ground truth results
|
||||
|
||||
# Example for GPT2:
|
||||
# 1) Run the model to log all kernels: `PYTHONPATH=. LOGKERN=/tmp/gpt2_kerns.txt JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing` # noqa: E501
|
||||
# 2) Validate the kernel correctness: `PYTHONPATH=. python3 ./test/external/verify_kernel.py --file /tmp/gpt2_kerns.txt`
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Verify the correctness of one or more kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter) # noqa: E501
|
||||
parser.add_argument("--kernel", type=str, default=None, help="a string of a tuple of (ast, applied_opts,)")
|
||||
parser.add_argument("--file", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
|
||||
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
|
||||
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.kernel is not None:
|
||||
print("loading kernel from args")
|
||||
kern_strs = [args.kernel]
|
||||
elif args.file is not None:
|
||||
print(f"loading kernel from file '{args.file}'")
|
||||
with open(args.file, 'r') as file:
|
||||
kern_strs = file.readlines()
|
||||
else:
|
||||
raise RuntimeError("no kernel specified; use --kernel or --file options")
|
||||
|
||||
print(f"verifying {len(kern_strs)} kernels")
|
||||
|
||||
failed_ids = []
|
||||
failures = defaultdict(list)
|
||||
for i, kern_str in enumerate(kern_strs):
|
||||
print(f"testing kernel {i}")
|
||||
test_lin = kern_str_to_lin(kern_str)
|
||||
for op in test_lin.ast: print_tree(op)
|
||||
print(test_lin.colored_shape())
|
||||
if (msg:=compare_linearizer(test_lin, None, None, None, rtol=args.rtol, atol=args.atol)[0]) != "PASS":
|
||||
failed_ids.append(i)
|
||||
failures[msg].append((test_lin.ast, test_lin.applied_opts))
|
||||
|
||||
for msg, errors in failures.items():
|
||||
for i, (ast, opts) in enumerate(errors):
|
||||
print(f"{msg} {i} AST: {ast}")
|
||||
print(f"{msg} {i} OPTS: {opts}\n")
|
||||
|
||||
print(f"tested {len(kern_strs)} kernels")
|
||||
if failures:
|
||||
print(f"{failed_ids=}")
|
||||
for msg, errors in failures.items():
|
||||
print(f"{msg}: {len(errors)}")
|
||||
raise RuntimeError(f"failed on {len(failed_ids)} kernels")
|
||||
else:
|
||||
print(colored("all passed", "green"))
|
||||
@@ -239,6 +239,7 @@ class MultiDeviceJITGraph(JITRunner):
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
raise NotImplementedError("override this")
|
||||
|
||||
logkern, logkern_level = open(getenv("LOGKERN", ""), "a") if getenv("LOGKERN", "") else None, getenv("LOGKERN_LEVEL", 1)
|
||||
class Compiled:
|
||||
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
|
||||
self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
|
||||
@@ -278,6 +279,9 @@ class Compiled:
|
||||
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
|
||||
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
|
||||
k = timed[0][1]
|
||||
if logkern is not None and logkern_level > 1: logkern.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
|
||||
# TODO: check the correctness inline once compare_linearizer is in core
|
||||
if logkern is not None: logkern.writelines([f"{(k.ast, k.applied_opts)}\n"])
|
||||
return k
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
|
||||
Reference in New Issue
Block a user