mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
no args and kwargs for _multi_like [PR] (#16539)
This commit is contained in:
@@ -562,12 +562,10 @@ class Tensor(OpMixin):
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
def _multi_like(self, fxn, *args, **kwargs) -> Tensor:
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
|
||||
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
|
||||
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
|
||||
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device])
|
||||
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
|
||||
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
|
||||
return Tensor(stacked.multi(self.uop.axis))
|
||||
|
||||
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
|
||||
@@ -582,7 +580,9 @@ class Tensor(OpMixin):
|
||||
print(Tensor.full_like(t, 42).numpy())
|
||||
```
|
||||
"""
|
||||
if isinstance(self.device, tuple): return self._multi_like(Tensor.full, fill_value, dtype=dtype or self.dtype, device=device)
|
||||
if isinstance(self.device, tuple):
|
||||
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
|
||||
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
|
||||
|
||||
def rand_like(self, **kwargs) -> Tensor:
|
||||
@@ -597,7 +597,10 @@ class Tensor(OpMixin):
|
||||
print(Tensor.rand_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
|
||||
if isinstance(self.device, tuple):
|
||||
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
|
||||
dtype = kwargs.pop("dtype", self.dtype)
|
||||
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
|
||||
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
|
||||
|
||||
# ***** random functions *****
|
||||
|
||||
Reference in New Issue
Block a user