From 08706c2ea4ba015b75bbe0f3ba2dbc427be6dd99 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 12 Nov 2024 19:02:27 -0500 Subject: [PATCH] more readable rand [pr] (#7659) no walrus inside walrus --- tinygrad/tensor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ed94e33d0d..05ab20fb4f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -502,7 +502,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method else: had_counter = True # if shape has 0, return zero tensor - if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs) + if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs) + num = ceildiv(numel * dtype.itemsize, 4) # increment rng counter for devices if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous() @@ -520,7 +521,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype) bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one) # bitcast back to the original dtype and reshape - out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape) + out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape) # move back to the original device if we were using MOCKGPU if getenv("MOCKGPU") and _device: out = out.to(_device)