more readable rand [pr] (#7659)

no walrus inside walrus
This commit is contained in:
chenyu
2024-11-12 19:02:27 -05:00
committed by GitHub
parent 1884f021e3
commit 08706c2ea4

View File

@@ -502,7 +502,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
else: had_counter = True
# if shape has 0, return zero tensor
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
num = ceildiv(numel * dtype.itemsize, 4)
# increment rng counter for devices
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
@@ -520,7 +521,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
# bitcast back to the original dtype and reshape
out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape)
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
# move back to the original device if we were using MOCKGPU
if getenv("MOCKGPU") and _device: out = out.to(_device)