diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py index 9acc456799..1ab37a9b6d 100644 --- a/test/null/test_tensor_uop_mixin.py +++ b/test/null/test_tensor_uop_mixin.py @@ -228,6 +228,10 @@ class TestTensorUOpRand(unittest.TestCase): def test_threefry_random_bits(self): key, c0, c1 = UOp.empty((2,), dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32) self.assertIs(Tensor._threefry_random_bits(Tensor(key), Tensor(c0), Tensor(c1)).uop, UOp._threefry_random_bits(key, c0, c1)) + def test_rand(self): + k, c = UOp.empty((2,), dtype=dtypes.uint32), UOp.zeros(2, dtype=dtypes.uint32) + self.assertIs(Tensor._rand(Tensor(k), Tensor(c), (2, 2), dtypes.float32).uop, UOp._rand(k, c, (2, 2), dtypes.float32)) + self.assertIs(Tensor._rand(Tensor(k), Tensor(c), (0, 3), dtypes.float32).uop, UOp._rand(k, c, (0, 3), dtypes.float32)) class TestTensorUOpGather(unittest.TestCase): def _check(self, t, dim, idx): diff --git a/tinygrad/mixin/rand.py b/tinygrad/mixin/rand.py index c846830bc0..52c6621a15 100644 --- a/tinygrad/mixin/rand.py +++ b/tinygrad/mixin/rand.py @@ -24,7 +24,7 @@ class RandMixin(OpMixin): counts0 = cls.arange(ceildiv(chunk_num, 2), dtype=dtypes.uint32) counts1 = counts0 + ceildiv(chunk_num, 2) bits.append(cls._threefry_random_bits(new_key, counts0, counts1)[:chunk_num]) - return bits[0].cat(*bits[1:]) + return bits[0].cat(*bits[1:]) if bits else counter[0:0] @staticmethod def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType): @@ -33,3 +33,9 @@ class RandMixin(OpMixin): uint_bits = bits.bitcast(uint_dtype) float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype) return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape) + + @classmethod + def _rand(cls, key:Self, counter:Self, shape:tuple[int, ...], dtype:DType, contiguous:bool=True) -> Self: + bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4)) + out = cls._bits_to_rand(bits, shape, dtype) + return out.contiguous() if contiguous else out diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 043eb1292f..d7c66b4ae4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -551,14 +551,8 @@ class Tensor(RandMixin): if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}") if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}") device = cast(str, canonicalize_device(device)) - - # if shape has 0, return zero tensor - if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dt) - num = ceildiv(numel * dt.itemsize, 4) - key, counter = Tensor._next_counter(device, num) - bits = Tensor.random_bits(key, counter, num) - out = Tensor._bits_to_rand(bits, shape, dt) - return out.contiguous() if contiguous else out + key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4)) + return Tensor._rand(key, counter, shape, dt, contiguous=contiguous) # ***** creation helper functions *****