From 623b66e0e4e8e519038f9f5cd86a8ab6976032c8 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 10 Jun 2026 00:39:33 -0400 Subject: [PATCH] more tensor and mixin cleanups [PR] (#16558) --- tinygrad/mixin/__init__.py | 2 +- tinygrad/tensor.py | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index 1ea01748de..dd3cd04edc 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -757,7 +757,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin): base = ret[..., -1]._cumalu(-1, op)._pad_constant((None,)*(ret.ndim-2) + ((1, -1),), value) base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1]) def fix(x: Self) -> Self: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) - return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base)) + return fix(ret).alu(op, fix(base)) def cumsum(self, axis:int=0) -> Self: """ diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7a05445316..2b4056be9d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -185,11 +185,8 @@ class Tensor(RandMixin): # ***** data handlers **** def as_param(self, slot:int): - if self.uop.axis is not None: - param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis) - else: - param = UOp.param(slot, self.dtype, self.shape, self.device) - return Tensor(param) + return Tensor(UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis)) + def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor: fret = fxn._uop.call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn) return Tensor(fret.gettuple(0)) @@ -449,26 +446,25 @@ class Tensor(RandMixin): # ***** creation entrypoint ***** @staticmethod - def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, **kwargs) -> Tensor: + def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor: """ Creates an empty tensor with the given shape. You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. - Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" t = Tensor.empty(2, 3) print(t.shape) ``` """ - return Tensor(UOp.empty(argfix(*shape), dtype, device), **kwargs) + return Tensor(UOp.empty(argfix(*shape), dtype, device)) - def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor: + def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor: """ Creates an empty tensor with the same shape as `self`. If `dtype` is not specified, the dtype of `self` is used. """ - return Tensor(self.uop.empty_like(dtype, device), **kwargs) + return Tensor(self.uop.empty_like(dtype, device)) @staticmethod def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor: