diff --git a/extra/optimization/get_action_space.py b/extra/optimization/get_action_space.py index 81ac990e70..b0e6721919 100644 --- a/extra/optimization/get_action_space.py +++ b/extra/optimization/get_action_space.py @@ -6,7 +6,7 @@ from tinygrad.helpers import tqdm tactions = set() def test_rebuild(lin): - linr = Linearizer(*lin.ast) + linr = Linearizer(lin.ast) for o in lin.applied_opts: assert o in actions, f"{o} is not in actions" tactions.add(o) diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 35f4c36cec..007ec8cf67 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -87,10 +87,10 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut if var_vals is None: # TODO: handle symbolic max case - var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()} + var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()} if ground_truth is None and not has_bf16: - unoptimized = Linearizer(*lin.ast) + unoptimized = Linearizer(lin.ast) unoptimized.required_optimizations() if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS": return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,) @@ -121,7 +121,7 @@ def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2): SEED = getenv("SEED", 42) random.seed(SEED) np.random.seed(SEED) - for op in lin.ast: print_tree(op) + print_tree(lin.ast) print(lin.colored_shape()) seen_uops = {} last_lins = [lin] @@ -178,8 +178,8 @@ def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2): return failures def _is_simple(lin: Linearizer) -> bool: - if len(lin.ast) > 1: return False - ast:LazyOp = lin.ast[0] + if len(lin.ast.src) > 1: return False + ast:LazyOp = lin.ast.src[0] if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True return False diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index c7ef5d6349..952ce74165 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -17,7 +17,7 @@ for offset in tqdm(range(0, row_count, page_size)): with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache}): # try linearize try: - k = Linearizer(*ast, opts=opts) + k = Linearizer(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) good_src = k.opts.render(name, k.linearize().uops) except Exception as e: diff --git a/test/test_verify_lazyop.py b/test/test_verify_lazyop.py index ed132d096f..7926d41eac 100644 --- a/test/test_verify_lazyop.py +++ b/test/test_verify_lazyop.py @@ -4,18 +4,19 @@ from tinygrad.codegen.linearizer import Linearizer #from tinygrad.codegen.lowerer import Lowerer from tinygrad.engine.graph import print_tree from tinygrad.helpers import DEBUG -from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, verify_lazyop +from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, MetaOps, verify_lazyop from tinygrad.shape.shapetracker import ShapeTracker from tinygrad import dtypes from tinygrad.shape.view import View class InvalidLazyOpException(Exception): pass def lower(*ast:LazyOp): + sink_ast = LazyOp(MetaOps.SINK, ast) if DEBUG >= 3: for op in ast: print_tree(op) - try: verify_lazyop(*ast) + try: verify_lazyop(sink_ast) except AssertionError: raise InvalidLazyOpException() - k = Linearizer(*ast) + k = Linearizer(sink_ast) k.linearize() if DEBUG >= 6: k.uops.print() if DEBUG >= 4: print(k.to_program().src) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 24f0574b2a..fb283b4363 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -8,7 +8,7 @@ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, Cons from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore from tinygrad.dtype import dtypes, ImageDType -from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, get_contraction, to_function_name # noqa: E501 +from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, get_contraction, to_function_name # noqa: E501 from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import sint from tinygrad.shape.view import strides_for_shape @@ -61,22 +61,28 @@ class TensorCoreOptions: class Kernel: def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None): + if len(ast) > 1 or ast[0].op is BufferOps.STORE: + assert all(x.op is BufferOps.STORE for x in ast) + self.ast = LazyOp(MetaOps.SINK, ast) + else: + assert len(ast) == 1 and ast[0].op is MetaOps.SINK + self.ast = ast[0] + self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer - try: lazyop_sts_map = verify_lazyop(*ast) + try: lazyop_sts_map = verify_lazyop(self.ast) except AssertionError as e: print("INVALID AST") for op in ast: print_tree(op) raise e - self.ast = ast - self.lazyops = flatten([op.lazyops for op in self.ast]) + self.lazyops = self.ast.lazyops cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {} def ordered_lazyops(op): if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op]) return cached_ordered_lazyops[op] - self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps]) + self.reduceops = dedup([x for x in ordered_lazyops(self.ast) if x.op in ReduceOps]) - self.vars = flatten([x.vars() for x in self.ast]) + self.vars = self.ast.vars() self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.lazyops if x.op in BufferOps]) # get earlybufs, before any reduceops @@ -645,7 +651,7 @@ class Kernel: def name(self) -> str: # kernel name (before late upcast) name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \ - (f"{len(self.ast)}_" if len(self.ast) > 1 else "_") + \ + (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \ colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) # name the function something unique @@ -720,5 +726,4 @@ class Kernel: else: arg = op.arg return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg) - return fixup_ast(LazyOp(MetaOps.SINK, src=self.ast)), \ - KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted) + return fixup_ast(self.ast), KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index a4f9c5f137..81ca7cb649 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -88,7 +88,7 @@ class Lowerer(Kernel): buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size)) else: buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), - (x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast))) + (x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast.src))) if x.op is BufferOps.LOAD: barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else () return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier) @@ -168,7 +168,7 @@ class Lowerer(Kernel): if getenv("RUN_PROCESS_REPLAY"): table_name = f"process_replay_{getenv('GITHUB_SHA', 'HEAD')}" diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()})) - info = get_lazyop_info(self.ast[0]) + info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed ops, mem = flops_mem(self.uops.uops) run_count = prod((self.global_size or []) + (self.local_size or [])) return Program(self.name, src, self.opts.device, self.global_size, self.local_size, diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 139c25decb..7efd1b13d2 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -117,7 +117,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG") def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer: global beam_pool - key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix} + key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix} if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None: ret = lin.copy() for o in val[len(lin.applied_opts):]: ret.apply_opt(o) @@ -136,7 +136,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T try: rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()} + var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()} exiting, st = False, time.perf_counter() dev = Device[lin.opts.device] while not exiting: @@ -182,7 +182,7 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff return ret[1] def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501 - key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, + key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) @@ -190,7 +190,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, assert dev.compiler is not None rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()} + var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()} p = lin.to_program() tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 59dbbaee93..1c0e7cb08f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -155,7 +155,8 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool, def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands)) # the living definition of LazyOps -def verify_lazyop(*ast:LazyOp) -> Dict[LazyOp, ShapeTracker]: +def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]: + assert ast.op is MetaOps.SINK, "must be SINK" sts: Dict[LazyOp, ShapeTracker] = {} def dfs(op:LazyOp, st:ShapeTracker): if op in sts: return @@ -170,9 +171,9 @@ def verify_lazyop(*ast:LazyOp) -> Dict[LazyOp, ShapeTracker]: else: st = sts[op.src[0]] for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}" sts[op] = st - for i, out in enumerate(ast): + for i, out in enumerate(ast.src): assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}" assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}" - assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}" + assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}" dfs(out, out.arg.st) return sts