mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
@@ -502,7 +502,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
else: had_counter = True
|
||||
|
||||
# if shape has 0, return zero tensor
|
||||
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
num = ceildiv(numel * dtype.itemsize, 4)
|
||||
|
||||
# increment rng counter for devices
|
||||
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
|
||||
@@ -520,7 +521,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
|
||||
bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
|
||||
# bitcast back to the original dtype and reshape
|
||||
out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape)
|
||||
out = bits.bitcast(dtype)[:numel].sub(1).reshape(shape)
|
||||
|
||||
# move back to the original device if we were using MOCKGPU
|
||||
if getenv("MOCKGPU") and _device: out = out.to(_device)
|
||||
|
||||
Reference in New Issue
Block a user