From af4f9d1aa9cde9b53eca8300df72dedcb99e9e4e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 31 Jan 2025 02:17:42 -0500 Subject: [PATCH] use matchers to verify AST shape [pr] (#8828) * use matchers to verify kernel AST [pr] * work * use swizzle_cnt * add comment * imports * modified_ast comment * brief --- test/test_schedule.py | 5 ++-- test/unit/test_verify_ast.py | 13 +++++----- tinygrad/codegen/kernel.py | 50 +++++------------------------------- tinygrad/spec.py | 17 ++++++++++-- 4 files changed, 32 insertions(+), 53 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 8f41fc7153..b434bdcfe8 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -14,12 +14,13 @@ from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views +from tinygrad.spec import type_verify, shape_spec from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp -from tinygrad.codegen.kernel import verify_ast from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis +def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec) class KernelCountException(Exception): pass def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True): if to_prerealize: @@ -1824,7 +1825,7 @@ class TestIndexing(unittest.TestCase): sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) rsink = graph_rewrite(sink, view_right) # this AST first needs to swizzle, but it doesn't have implicit movementops - with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink) + self.assertEqual(swizzle_cnt(sink), 1) verify_ast(rsink) def test_no_reshape_reduceop(self): diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index bf23847e4e..ccae3810b3 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -5,7 +5,7 @@ from tinygrad import Tensor from tinygrad.codegen.kernel import Kernel from tinygrad.helpers import DEBUG from tinygrad.ops import UOp, Ops, print_uops -from tinygrad.codegen.kernel import verify_ast +from tinygrad.spec import type_verify, shape_spec from tinygrad.shape.shapetracker import ShapeTracker from tinygrad import dtypes from tinygrad.shape.view import View @@ -15,8 +15,8 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel: sink = UOp(Ops.SINK, dtypes.void, stores) if DEBUG >= 3: for op in stores: print(op) - try: verify_ast(sink) - except AssertionError as e: raise InvalidASTException(e.args) + try: type_verify(list(sink.toposort), shape_spec) + except RuntimeError as e: raise InvalidASTException(e.args) k = Kernel(sink) k.linearize() if DEBUG >= 6: print_uops(k.uops) @@ -64,23 +64,24 @@ class TestVerifyAST(unittest.TestCase): a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,))) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r) - with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st) + with self.assertRaises(InvalidASTException): helper_test_verify_ast(st) def test_reduce_add_store(self): bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)] a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,))) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a) - with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st) + with self.assertRaises(InvalidASTException): helper_test_verify_ast(st) def test_buffer_uops_st(self): a = Tensor.randn(4, 4)+2 - verify_ast(ast:=a.schedule()[-1].ast) + helper_test_verify_ast(ast:=a.schedule()[-1].ast) store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0] self.assertEqual(store_st, ShapeTracker.from_shape((4, 4))) const_st = [u.st for u in ast.toposort if u.op is Ops.CONST][0] self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4))) + @unittest.skip("questionable if we want this") def test_assert_swizzle(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop())) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 913c6ee747..49adcaf785 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -6,7 +6,7 @@ from typing import Optional, cast, Final, Callable, Sequence from enum import Enum, auto from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops -from tinygrad.spec import type_verify +from tinygrad.spec import type_verify, shape_spec from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, ProgramSpec from tinygrad.dtype import ImageDType @@ -57,11 +57,8 @@ class Kernel: if ast.op is Ops.SINK: self.ast = ast self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer - try: verify_ast(self.ast) - except AssertionError as e: - print("INVALID AST") - print(self.ast) - raise e + # verify AST matches the spec + if __debug__: type_verify(list(self.ast.toposort), shape_spec) self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS] @@ -673,7 +670,10 @@ class Kernel: if getenv("RAWAST"): print(self.ast) print(modified_ast) print(self.applied_opts) - verify_ast(modified_ast) + # verify AST matches the spec after applying opts + if __debug__: type_verify(list(modified_ast.toposort)) + # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this + #if __debug__: type_verify(list(modified_ast.toposort), shape_spec) self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts)) if DEBUG >= 5: print_uops(self.uops) @@ -693,39 +693,3 @@ class Kernel: key=lambda x: (x.op, x.src[0].arg))) return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes, global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None) - -# the living definition of intermediate UOps - -def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None: - if uop in sts: return - # restore globals from the two stage reduce - # this is because this LOAD has an implicit movement op - if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL: - _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts) - sts[uop] = sts[local_reduce] - return - for x in uop.src: _assert_valid_uop(x, st, sts) - # only reduceuop is allowed to change shape, limited to turning n to 1 - if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg)) - # movementops are pushed to VIEW - elif uop.op is Ops.VIEW: - # NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine - assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}" - st = uop.arg - # everything else inherits shape - else: - if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None - st = src_sts[0] - if not all_same(shapes:=[x.shape for x in src_sts]): - if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}") - raise AssertionError(f"found implicit expand {sizes} {shapes}") - sts[uop] = st - -def verify_ast(ast:UOp) -> None: - assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK" - assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size" - sts: dict[UOp, ShapeTracker] = {} - for out in ast.src: _assert_valid_uop(out, out.st_arg, sts) - shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])] - assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}" - type_verify(list(sts)) diff --git a/tinygrad/spec.py b/tinygrad/spec.py index 1d432221be..32934eef60 100644 --- a/tinygrad/spec.py +++ b/tinygrad/spec.py @@ -1,7 +1,7 @@ from typing import cast from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType -from tinygrad.helpers import all_int, prod +from tinygrad.helpers import all_int, all_same, dedup, prod # *** this is the spec of a Tensor in UOp *** @@ -61,7 +61,7 @@ spec = PatternMatcher([ (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype), (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True), - (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), + (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))), # early LOAD has a (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True), @@ -121,6 +121,19 @@ spec = PatternMatcher([ (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True), ]) +# *** this is the UOp shape spec *** + +def verify_sink_dims(sink:UOp): + shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])] + return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims) + +shape_spec = PatternMatcher([ + # shapes must have either 1 or n in each dimension + (UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims), + # all parent UOps must have the same shape + (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])), +]) + # ***** uop helpers ***** def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):