remove device= from UPat.const [PR] (#16530)

This commit is contained in:
chenyu
2026-06-07 16:38:43 -04:00
committed by GitHub
parent eb1238436a
commit 937aeaec60
3 changed files with 4 additions and 5 deletions

View File

@@ -17,9 +17,9 @@ ReductionStr = Literal["mean", "sum", "none"]
class OpMixin(ElementwiseMixin, ReduceMixin):
@staticmethod
def unique_const(fill_value:ConstType, **kwargs): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
def unique_const(fill_value:ConstType, **kwargs): raise NotImplementedError
@staticmethod
def const(dtype, b): raise NotImplementedError("creation helpers are only supported on Tensor and UOp")
def const(dtype, b): raise NotImplementedError
@property
def _uop(self) -> UOp: raise NotImplementedError
def _wrap_uop(self, u:UOp) -> Self: raise NotImplementedError

View File

@@ -161,8 +161,7 @@ class Tensor(OpMixin):
def _wrap_uop(self, u:UOp) -> Tensor: return Tensor(u)
def const_like(self, b:ConstType) -> Tensor: return Tensor(self.uop.const_like(b))
@staticmethod
def const(dtype:DType, b:ConstType|UOp) -> Tensor:
return Tensor(UOp.const(dtype, b))
def const(dtype:DType, b:ConstType|UOp) -> Tensor: return Tensor(UOp.const(dtype, b))
@staticmethod
def unique_const(fill_value:ConstType|UOp, **kwargs) -> Tensor:
if isinstance(fill_value, UOp): return Tensor(fill_value, **kwargs)

View File

@@ -1253,7 +1253,7 @@ class UPat(OpMixin):
@functools.cache
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, arg=None): return UPat(Ops.CONST, dtype, name=name, arg=arg)
@staticmethod
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType, device=None): return UPat(Ops.CONST, dtype=dtype, arg=b)
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# lil helper
def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)