remove CONST(DEVICE) (#16506)

This commit is contained in:
chenyu
2026-06-04 11:29:46 -04:00
committed by GitHub
parent 3838c8df1b
commit 4a8bf07a87
9 changed files with 13 additions and 26 deletions

View File

@@ -167,7 +167,7 @@ class TestMultiConstFolding(unittest.TestCase):
class TestThreefryConstFolding(unittest.TestCase):
def test_threefry(self):
x = UOp.const(dtypes.uint64, 5, Device.DEFAULT, ()).threefry(UOp.const(dtypes.uint64, 10, Device.DEFAULT, ()))
x = UOp.const(dtypes.uint64, 5).threefry(UOp.const(dtypes.uint64, 10))
self.assertIs(x.simplify().op, Ops.CONST)
class TestTautologicalCompare(unittest.TestCase):

View File

@@ -233,7 +233,7 @@ class TestViz(unittest.TestCase):
def test_const_reshape_expand_folded(self):
# CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
c = UOp.const(dtypes.float, 1.0, shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
alu = a + c
with save_viz() as viz:
@@ -244,7 +244,7 @@ class TestViz(unittest.TestCase):
self.assertIn("STACK", excluded_nodes)
self.assertIn("RESHAPE", excluded_nodes)
self.assertIn("EXPAND", excluded_nodes)
self.assertIn("CONST1 1 Ops.DEVICE", graph[id(alu)]["label"])
self.assertIn("CONST1 1", graph[id(alu)]["label"])
def test_stack_movement_not_folded_unless_all_const(self):
a = UOp.variable("a", 0, 10, dtype=dtypes.int)

View File

@@ -83,9 +83,5 @@ class TestTensorData(unittest.TestCase):
assert dat.shape == (2,2)
# NOTE: python can't deref float16
def test_data_uop_device(self):
uop = UOp.const(dtypes.float, 1.0, "DEVICE")
self.assertEqual(Tensor(uop).device, "DEVICE")
if __name__ == '__main__':
unittest.main()

View File

@@ -182,8 +182,6 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
pm_finalize_call = PatternMatcher([
(UPat(Ops.AFTER, name="x"), finalize_after),
(UPat(Ops.COPY, name="x"), lambda ctx,x: ctx.assigns.append(x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
# remove unique from const. TODO: this is copied in function.py
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
])
pm_replace_buf = PatternMatcher([

View File

@@ -19,7 +19,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def unique_const(fill_value:ConstType, **kwargs): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
@staticmethod
def const(dtype, b, device=None): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
def const(dtype, b): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
@classmethod
def full(cls, shape:tuple[sint, ...], fill_value:ConstType|UOp, dtype:DTypeLike|None=None,

View File

@@ -159,8 +159,8 @@ class Tensor(OpMixin):
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
@staticmethod
def const(dtype:DType, b:ConstType|UOp, device:str|tuple[str, ...]|None=None) -> Tensor:
return Tensor(UOp.const(dtype, b, device))
def const(dtype:DType, b:ConstType|UOp) -> Tensor:
return Tensor(UOp.const(dtype, b))
@staticmethod
def unique_const(fill_value:ConstType|UOp, **kwargs) -> Tensor:
if isinstance(fill_value, UOp): return Tensor(fill_value, **kwargs)

View File

@@ -537,24 +537,25 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(op, out_dtype, all_srcs, **kwargs)
@staticmethod
def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
def const(dtype:DType, b:ConstLike, shape:tuple[sint, ...]|None=None):
if isinstance(b, UOp): return b.cast(dtype)
if isinstance(b, tuple) and all_same(b):
assert len(b) > 0, "can't create const from empty tuple"
b = b[0] # doesn't have to be a STACK if they are all the same
if isinstance(b, tuple):
stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ()) for c in b]
stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=()) for c in b]
ret = UOp.vectorize(*stk)
else:
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
shape:tuple[sint, ...]|None=None, unique=True):
# NOTE: fill_value is ConstType, not ConstLike, so UOps and tuples aren't allowed
assert not isinstance(fill_value, (UOp, tuple)), "unique const only works on numbers"
ret = UOp.const(to_dtype(dtype) if dtype is not None else dtypes.from_py(fill_value), fill_value, canonicalize_device(device))
ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(fill_value)
ret = UOp(Ops.CONST, dt, arg=dt.const(fill_value),
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):

View File

@@ -81,7 +81,6 @@ sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
lambda x,u,d: f"UOp.unique_const({x.arg}, dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
@@ -106,12 +105,6 @@ pm_pyrender_extra = PatternMatcher([
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render CDIV/CMOD via .alu()
(UPat(Ops.CDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.CDIV, {ctx[x.src[1]]})"),
(UPat(Ops.CMOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.CMOD, {ctx[x.src[1]]})"),
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.CDIV, Ops.CMOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")),
name="x"), lambda ctx,x,y,z: strip_binary_parens(x, str(y.arg), ctx[z], lambda a,b: f"({a}{syms[x.op]}{b})") if y.device==z.device else None),
# NOTE: sub doesn't work cause it's written as add/mul
(UPat(set(syms.keys())-{Ops.SUB, Ops.CDIV, Ops.CMOD}, src=(UPat(name="y"), UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="z")), name="x"),
lambda ctx,x,y,z: strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})") if y.device==z.device else None),
(UPat(set(syms.keys())-{Ops.SUB, Ops.CDIV, Ops.CMOD}, name="x"), lambda ctx,x:
strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join(([f'arg={repr(x.arg)}'] if x.arg is not None else []))+")"),

View File

@@ -124,8 +124,7 @@ spec_tensor = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
# CONST with a UNIQUE or DEVICE
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
# CONST with a UNIQUE and DEVICE
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
# BUFFER