mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 07:27:43 +08:00
more tensor and mixin cleanups [PR] (#16558)
This commit is contained in:
@@ -757,7 +757,7 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
base = ret[..., -1]._cumalu(-1, op)._pad_constant((None,)*(ret.ndim-2) + ((1, -1),), value)
|
||||
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
|
||||
def fix(x: Self) -> Self: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
|
||||
return getattr(fix(ret), {Ops.ADD: "add", Ops.MAX: "maximum", Ops.MUL: "mul"}[op])(fix(base))
|
||||
return fix(ret).alu(op, fix(base))
|
||||
|
||||
def cumsum(self, axis:int=0) -> Self:
|
||||
"""
|
||||
|
||||
@@ -185,11 +185,8 @@ class Tensor(RandMixin):
|
||||
# ***** data handlers ****
|
||||
|
||||
def as_param(self, slot:int):
|
||||
if self.uop.axis is not None:
|
||||
param = UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis)
|
||||
else:
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param)
|
||||
return Tensor(UOp.param(slot, self.dtype, self.uop.shard_shape, self.device, axis=self.uop.axis))
|
||||
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
fret = fxn._uop.call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
|
||||
return Tensor(fret.gettuple(0))
|
||||
@@ -449,26 +446,25 @@ class Tensor(RandMixin):
|
||||
# ***** creation entrypoint *****
|
||||
|
||||
@staticmethod
|
||||
def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, **kwargs) -> Tensor:
|
||||
def empty(*shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None) -> Tensor:
|
||||
"""
|
||||
Creates an empty tensor with the given shape.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.empty(2, 3)
|
||||
print(t.shape)
|
||||
```
|
||||
"""
|
||||
return Tensor(UOp.empty(argfix(*shape), dtype, device), **kwargs)
|
||||
return Tensor(UOp.empty(argfix(*shape), dtype, device))
|
||||
|
||||
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, **kwargs) -> Tensor:
|
||||
def empty_like(self, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> Tensor:
|
||||
"""
|
||||
Creates an empty tensor with the same shape as `self`.
|
||||
If `dtype` is not specified, the dtype of `self` is used.
|
||||
"""
|
||||
return Tensor(self.uop.empty_like(dtype, device), **kwargs)
|
||||
return Tensor(self.uop.empty_like(dtype, device))
|
||||
|
||||
@staticmethod
|
||||
def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user