From 586e730d32a9ea74abbde0dbb7a7a84751f75c52 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 13 Jan 2025 06:24:11 -0500 Subject: [PATCH] use UOp.st for kernel reduce axes (#8499) * use UOp.st for kernel reduce axes [pr] * do not return dict --- test/test_linearizer.py | 6 +++--- test/unit/test_verify_ast.py | 6 +++--- tinygrad/codegen/kernel.py | 11 +++++------ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index a8a26f1b92..a4d8da48ca 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -2157,9 +2157,9 @@ class TestKernelOpts(unittest.TestCase): data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() helper_linearizer_ast(sink, [data1, data2], opts=[ - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)], - [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)] + #[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)], + #[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)], + #[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)] ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index d38967a7c8..c078640aab 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -75,10 +75,10 @@ class TestVerifyAST(unittest.TestCase): def test_buffer_uops_st(self): a = Tensor.randn(4, 4)+2 - uop_sts = verify_ast(a.schedule()[-1].ast) - store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0] + verify_ast(ast:=a.schedule()[-1].ast) + store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0] self.assertEqual(store_st, ShapeTracker.from_shape((4, 4))) - const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0] + const_st = [u.st for u in ast.toposort if u.op is Ops.VALID][0] self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4))) def test_assert_swizzle(self): diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index aa221a5b33..6bca8add41 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -10,7 +10,7 @@ from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, print_uops, typ from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, ProgramSpec from tinygrad.dtype import ImageDType -from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, ContextVar +from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap, ContextVar from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX, CAPTURE_PROCESS_REPLAY from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import strides_for_shape @@ -57,7 +57,7 @@ class Kernel: if ast.op is Ops.SINK: self.ast = ast self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer - try: uop_sts_map = verify_ast(self.ast) + try: verify_ast(self.ast) except AssertionError as e: print("INVALID AST") print(self.ast) @@ -80,8 +80,8 @@ class Kernel: # add the shapetrackers for each reduce # we use this to track which axes are reduced in each reduce for x in self.reduceops: - self.sts.append(uop_sts_map[x]) - self.sts.append(uop_sts_map[x.src[0]]) + self.sts.append(unwrap(x.st)) + self.sts.append(unwrap(x.src[0].st)) # move all reduce axes to the end reduce = list(enumerate(zip(self.full_shape, self.output_shape))) @@ -707,7 +707,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> raise AssertionError(f"found implicit expand {sizes} {shapes}") sts[uop] = st -def verify_ast(ast:UOp) -> dict[UOp, ShapeTracker]: +def verify_ast(ast:UOp) -> None: assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK" assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size" sts: dict[UOp, ShapeTracker] = {} @@ -715,4 +715,3 @@ def verify_ast(ast:UOp) -> dict[UOp, ShapeTracker]: shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])] assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}" type_verify(list(sts)) - return sts