more tensor and mixin cleanups [PR] (#16558)

This commit is contained in:
chenyu
2026-06-10 00:39:33 -04:00
committed by GitHub
parent 7366d32247
commit 623b66e0e4
2 changed files with 7 additions and 11 deletions

View File

@@ -757,7 +757,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
base = ret[..., -1]._cumalu(-1, op)._pad_constant((None,)*(ret.ndim-2) + ((1, -1),), value)
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
def fix(x: Self) -> Self: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base))
return fix(ret).alu(op, fix(base))
def cumsum(self, axis:int=0) -> Self:
"""

View File

@@ -185,11 +185,8 @@ class Tensor(RandMixin):
# ***** data handlers ****
def as_param(self, slot:int):
if self.uop.axis is not None:
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis)
else:
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param)
return Tensor(UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis))
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
fret = fxn._uop.call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
return Tensor(fret.gettuple(0))
@@ -449,26 +446,25 @@ class Tensor(RandMixin):
# ***** creation entrypoint *****
@staticmethod
def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
"""
Creates an empty tensor with the given shape.
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
Additionally, all other keyword arguments are passed to the constructor of the tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 3)
print(t.shape)
```
"""
return Tensor(UOp.empty(argfix(*shape), dtype, device), **kwargs)
return Tensor(UOp.empty(argfix(*shape), dtype, device))
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
"""
Creates an empty tensor with the same shape as `self`.
If `dtype` is not specified, the dtype of `self` is used.
"""
return Tensor(self.uop.empty_like(dtype, device), **kwargs)
return Tensor(self.uop.empty_like(dtype, device))
@staticmethod
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor: