mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
move functional part of rand to RandMixin (#16551)
This commit is contained in:
@@ -228,6 +228,10 @@ class TestTensorUOpRand(unittest.TestCase):
|
||||
def test_threefry_random_bits(self):
|
||||
key, c0, c1 = UOp.empty((2,), dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32)
|
||||
self.assertIs(Tensor._threefry_random_bits(Tensor(key), Tensor(c0), Tensor(c1)).uop, UOp._threefry_random_bits(key, c0, c1))
|
||||
def test_rand(self):
|
||||
k, c = UOp.empty((2,), dtype=dtypes.uint32), UOp.zeros(2, dtype=dtypes.uint32)
|
||||
self.assertIs(Tensor._rand(Tensor(k), Tensor(c), (2, 2), dtypes.float32).uop, UOp._rand(k, c, (2, 2), dtypes.float32))
|
||||
self.assertIs(Tensor._rand(Tensor(k), Tensor(c), (0, 3), dtypes.float32).uop, UOp._rand(k, c, (0, 3), dtypes.float32))
|
||||
|
||||
class TestTensorUOpGather(unittest.TestCase):
|
||||
def _check(self, t, dim, idx):
|
||||
|
||||
@@ -24,7 +24,7 @@ class RandMixin(OpMixin):
|
||||
counts0 = cls.arange(ceildiv(chunk_num, 2), dtype=dtypes.uint32)
|
||||
counts1 = counts0 + ceildiv(chunk_num, 2)
|
||||
bits.append(cls._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
|
||||
return bits[0].cat(*bits[1:])
|
||||
return bits[0].cat(*bits[1:]) if bits else counter[0:0]
|
||||
|
||||
@staticmethod
|
||||
def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType):
|
||||
@@ -33,3 +33,9 @@ class RandMixin(OpMixin):
|
||||
uint_bits = bits.bitcast(uint_dtype)
|
||||
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
|
||||
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
|
||||
|
||||
@classmethod
|
||||
def _rand(cls, key:Self, counter:Self, shape:tuple[int, ...], dtype:DType, contiguous:bool=True) -> Self:
|
||||
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
|
||||
out = cls._bits_to_rand(bits, shape, dtype)
|
||||
return out.contiguous() if contiguous else out
|
||||
|
||||
@@ -551,14 +551,8 @@ class Tensor(RandMixin):
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||
device = cast(str, canonicalize_device(device))
|
||||
|
||||
# if shape has 0, return zero tensor
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dt)
|
||||
num = ceildiv(numel * dt.itemsize, 4)
|
||||
key, counter = Tensor._next_counter(device, num)
|
||||
bits = Tensor.random_bits(key, counter, num)
|
||||
out = Tensor._bits_to_rand(bits, shape, dt)
|
||||
return out.contiguous() if contiguous else out
|
||||
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
|
||||
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user