no args and kwargs for _multi_like [PR] (#16539)

This commit is contained in:
chenyu
2026-06-08 17:35:15 -04:00
committed by GitHub
parent 12764161c9
commit e2ef5cf5c9

View File

@@ -562,12 +562,10 @@ class Tensor(OpMixin):
# ***** creation helper functions *****
def _multi_like(self, fxn, *args, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", self.dtype)
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
def _multi_like(self, fxn:Callable[[tuple[sint, ...], str|None], Tensor]) -> Tensor:
assert isinstance(self.device, tuple), f"_multi_like needs a multi device tensor, got {self.device}"
if self.uop.axis is None: return fxn(self.shape, *args, dtype=dtype, **kwargs).shard(self.device)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device])
if self.uop.axis is None: return fxn(self.shape, None).shard(self.device)
stacked = UOp.mstack(*[fxn(self.uop.shard_shape, d).uop for d in self.device])
return Tensor(stacked.multi(self.uop.axis))
def full_like(self, fill_value:ConstType, dtype=None, device=None) -> Tensor:
@@ -582,7 +580,9 @@ class Tensor(OpMixin):
print(Tensor.full_like(t, 42).numpy())
```
"""
if isinstance(self.device, tuple): return self._multi_like(Tensor.full, fill_value, dtype=dtype or self.dtype, device=device)
if isinstance(self.device, tuple):
if device is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
return self._multi_like(lambda shape, dev: Tensor.full(shape, fill_value, dtype=dtype or self.dtype, device=dev))
return Tensor.full(self.shape, fill_value, dtype=dtype or self.dtype, device=self.device if device is None else device)
def rand_like(self, **kwargs) -> Tensor:
@@ -597,7 +597,10 @@ class Tensor(OpMixin):
print(Tensor.rand_like(t).numpy())
```
"""
if isinstance(self.device, tuple): return self._multi_like(Tensor.rand, **kwargs)
if isinstance(self.device, tuple):
if kwargs.pop("device", None) is not None: raise RuntimeError("cannot specify `device` on `*_like` of a multi device tensor")
dtype = kwargs.pop("dtype", self.dtype)
return self._multi_like(lambda shape, dev: Tensor.rand(*shape, dtype=dtype, device=dev, **kwargs))
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=kwargs.pop("dtype", self.dtype), **kwargs)
# ***** random functions *****