delete old rules [pr] (#7400)

This commit is contained in:
George Hotz
2024-10-30 18:45:04 +07:00
committed by GitHub
parent 573a848229
commit ee9ef93617
2 changed files with 5 additions and 16 deletions

View File

@@ -122,8 +122,8 @@ class TestUOpsStats(unittest.TestCase):
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
@@ -132,8 +132,8 @@ class TestUOpsStats(unittest.TestCase):
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = linearize_uop(u4.sink())

View File

@@ -302,7 +302,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
return ret
def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp): return UOp(UOps.INDEX, self.dtype, (self,idx))
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st:ShapeTracker): return UOp(UOps.VIEW, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b)
def broadcast(self, count:int):
@@ -763,17 +763,6 @@ spec = PatternMatcher([
# early STORE has a <buf, shapetracker, val>
(UPat(UOps.STORE, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat())), lambda: True),
# LOAD takes a <buf, idx, alt?, gate?, barrier?>
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"),
lambda ld,alt: ld.dtype == alt.dtype),
# STORE takes a <buf, idx, val, gate?>
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True),
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(UOps.IF))), lambda: True),
# **** new style load/store ****
# INDEX is used in new style load/store