From 4a8bf07a87ef298c92281e807c8ea08deff7b425 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 4 Jun 2026 11:29:46 -0400 Subject: [PATCH] remove CONST(DEVICE) (#16506) --- test/backend/test_const_folding.py | 2 +- test/null/test_viz.py | 4 ++-- test/unit/test_tensor_data.py | 4 ---- tinygrad/callify.py | 2 -- tinygrad/mixin/__init__.py | 2 +- tinygrad/tensor.py | 4 ++-- tinygrad/uop/ops.py | 11 ++++++----- tinygrad/uop/render.py | 7 ------- tinygrad/uop/spec.py | 3 +-- 9 files changed, 13 insertions(+), 26 deletions(-) diff --git a/test/backend/test_const_folding.py b/test/backend/test_const_folding.py index 9ad21e30c8..3731195bf2 100644 --- a/test/backend/test_const_folding.py +++ b/test/backend/test_const_folding.py @@ -167,7 +167,7 @@ class TestMultiConstFolding(unittest.TestCase): class TestThreefryConstFolding(unittest.TestCase): def test_threefry(self): - x = UOp.const(dtypes.uint64, 5, Device.DEFAULT, ()).threefry(UOp.const(dtypes.uint64, 10, Device.DEFAULT, ())) + x = UOp.const(dtypes.uint64, 5).threefry(UOp.const(dtypes.uint64, 10)) self.assertIs(x.simplify().op, Ops.CONST) class TestTautologicalCompare(unittest.TestCase): diff --git a/test/null/test_viz.py b/test/null/test_viz.py index d3727a09d0..120c20515b 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -233,7 +233,7 @@ class TestViz(unittest.TestCase): def test_const_reshape_expand_folded(self): # CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes - c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain + c = UOp.const(dtypes.float, 1.0, shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0)) alu = a + c with save_viz() as viz: @@ -244,7 +244,7 @@ class TestViz(unittest.TestCase): self.assertIn("STACK", excluded_nodes) self.assertIn("RESHAPE", excluded_nodes) self.assertIn("EXPAND", excluded_nodes) - self.assertIn("CONST1 1 Ops.DEVICE", graph[id(alu)]["label"]) + self.assertIn("CONST1 1", graph[id(alu)]["label"]) def test_stack_movement_not_folded_unless_all_const(self): a = UOp.variable("a", 0, 10, dtype=dtypes.int) diff --git a/test/unit/test_tensor_data.py b/test/unit/test_tensor_data.py index 8c8f4eec3e..0d69e69429 100644 --- a/test/unit/test_tensor_data.py +++ b/test/unit/test_tensor_data.py @@ -83,9 +83,5 @@ class TestTensorData(unittest.TestCase): assert dat.shape == (2,2) # NOTE: python can't deref float16 - def test_data_uop_device(self): - uop = UOp.const(dtypes.float, 1.0, "DEVICE") - self.assertEqual(Tensor(uop).device, "DEVICE") - if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/callify.py b/tinygrad/callify.py index 02f866974b..5ec1c7ed62 100644 --- a/tinygrad/callify.py +++ b/tinygrad/callify.py @@ -182,8 +182,6 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp): pm_finalize_call = PatternMatcher([ (UPat(Ops.AFTER, name="x"), finalize_after), (UPat(Ops.COPY, name="x"), lambda ctx,x: ctx.assigns.append(x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None), - # remove unique from const. TODO: this is copied in function.py - (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))), ]) pm_replace_buf = PatternMatcher([ diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index e5fad7db23..25d8ce9c56 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -19,7 +19,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): @staticmethod def unique_const(fill_value:ConstType, **kwargs): raise NotImplementedError("creation helpers are only supported on Tensor and UOp") @staticmethod - def const(dtype, b, device=None): raise NotImplementedError("creation helpers are only supported on Tensor and UOp") + def const(dtype, b): raise NotImplementedError("creation helpers are only supported on Tensor and UOp") @classmethod def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None, diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f06ad2c90d..84a576005c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -159,8 +159,8 @@ class Tensor(OpMixin): def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src) def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b)) @staticmethod - def const(dtype:DType, b:ConstType|UOp, device:str|tuple[str, ...]|None=None) -> Tensor: - return Tensor(UOp.const(dtype, b, device)) + def const(dtype:DType, b:ConstType|UOp) -> Tensor: + return Tensor(UOp.const(dtype, b)) @staticmethod def unique_const(fill_value:ConstType|UOp, **kwargs) -> Tensor: if isinstance(fill_value, UOp): return Tensor(fill_value, **kwargs) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index b7ddca653c..573e9cf888 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -537,24 +537,25 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(op, out_dtype, all_srcs, **kwargs) @staticmethod - def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None): + def const(dtype:DType, b:ConstLike, shape:tuple[sint, ...]|None=None): if isinstance(b, UOp): return b.cast(dtype) if isinstance(b, tuple) and all_same(b): assert len(b) > 0, "can't create const from empty tuple" b = b[0] # doesn't have to be a STACK if they are all the same if isinstance(b, tuple): - stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ()) for c in b] + stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=()) for c in b] ret = UOp.vectorize(*stk) else: - ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ()) + ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=()) return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret @staticmethod def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override] shape:tuple[sint, ...]|None=None, unique=True): # NOTE: fill_value is ConstType, not ConstLike, so UOps and tuples aren't allowed assert not isinstance(fill_value, (UOp, tuple)), "unique const only works on numbers" - ret = UOp.const(to_dtype(dtype) if dtype is not None else dtypes.from_py(fill_value), fill_value, canonicalize_device(device)) - ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src) + dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(fill_value) + ret = UOp(Ops.CONST, dt, arg=dt.const(fill_value), + src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device)))) return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret @staticmethod def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs): diff --git a/tinygrad/uop/render.py b/tinygrad/uop/render.py index 8b87605299..8c8634e3f9 100644 --- a/tinygrad/uop/render.py +++ b/tinygrad/uop/render.py @@ -81,7 +81,6 @@ sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX pm_pyrender_extra = PatternMatcher([ (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d: f"UOp.unique_const({x.arg}, dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"), - (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"), (UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"), (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"), @@ -106,12 +105,6 @@ pm_pyrender_extra = PatternMatcher([ # explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render CDIV/CMOD via .alu() (UPat(Ops.CDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.CDIV, {ctx[x.src[1]]})"), (UPat(Ops.CMOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.CMOD, {ctx[x.src[1]]})"), - # NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering - (UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.CDIV, Ops.CMOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), - name="x"), lambda ctx,x,y,z: strip_binary_parens(x, str(y.arg), ctx[z], lambda a,b: f"({a}{syms[x.op]}{b})") if y.device==z.device else None), - # NOTE: sub doesn't work cause it's written as add/mul - (UPat(set(syms.keys())-{Ops.SUB, Ops.CDIV, Ops.CMOD}, src=(UPat(name="y"), UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="z")), name="x"), - lambda ctx,x,y,z: strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})") if y.device==z.device else None), (UPat(set(syms.keys())-{Ops.SUB, Ops.CDIV, Ops.CMOD}, name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")), (UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join(([f'arg={repr(x.arg)}'] if x.arg is not None else []))+")"), diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 580a2e0828..aaa7d0614d 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -124,8 +124,7 @@ spec_tensor = PatternMatcher([ (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True), (UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True), - # CONST with a UNIQUE or DEVICE - (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True), + # CONST with a UNIQUE and DEVICE (UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid), # BUFFER