From ee9ef936177976c5f5beaaeecf5707e7bedea592 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 30 Oct 2024 18:45:04 +0700 Subject: [PATCH] delete old rules [pr] (#7400) --- test/test_uops_stats.py | 8 ++++---- tinygrad/ops.py | 13 +------------ 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 8198c6c165..806e5db328 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -122,8 +122,8 @@ class TestUOpsStats(unittest.TestCase): globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2) - u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1)) - u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2)) + u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),)) + u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),)) u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3) u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL) u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) @@ -132,8 +132,8 @@ class TestUOpsStats(unittest.TestCase): globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2) - u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1)) - u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2)) + u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),)) + u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),)) u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3) u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC) uops_fma = linearize_uop(u4.sink()) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7c2943ecbc..11dbebcf12 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -302,7 +302,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}" return ret def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) - def index(self, idx:UOp): return UOp(UOps.INDEX, self.dtype, (self,idx)) + def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def view(self, st:ShapeTracker): return UOp(UOps.VIEW, self.dtype, (self,), st) def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b) def broadcast(self, count:int): @@ -763,17 +763,6 @@ spec = PatternMatcher([ # early STORE has a (UPat(UOps.STORE, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat())), lambda: True), - # LOAD takes a - (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True), - (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat((UOps.IF, UOps.BARRIER)))), lambda: True), - (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), - lambda ld,alt: ld.dtype == alt.dtype), - - # STORE takes a - (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True), - (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True), - (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(UOps.IF))), lambda: True), - # **** new style load/store **** # INDEX is used in new style load/store