move functional part of rand to RandMixin (#16551)

This commit is contained in:
chenyu
2026-06-09 09:40:48 -04:00
committed by GitHub
parent fa31c744b9
commit 3f053a3370
3 changed files with 13 additions and 9 deletions

View File

@@ -228,6 +228,10 @@ class TestTensorUOpRand(unittest.TestCase):
def test_threefry_random_bits(self):
key, c0, c1 = UOp.empty((2,), dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32), UOp.arange(4, dtype=dtypes.uint32)
self.assertIs(Tensor._threefry_random_bits(Tensor(key), Tensor(c0), Tensor(c1)).uop, UOp._threefry_random_bits(key, c0, c1))
def test_rand(self):
k, c = UOp.empty((2,), dtype=dtypes.uint32), UOp.zeros(2, dtype=dtypes.uint32)
self.assertIs(Tensor._rand(Tensor(k), Tensor(c), (2, 2), dtypes.float32).uop, UOp._rand(k, c, (2, 2), dtypes.float32))
self.assertIs(Tensor._rand(Tensor(k), Tensor(c), (0, 3), dtypes.float32).uop, UOp._rand(k, c, (0, 3), dtypes.float32))
class TestTensorUOpGather(unittest.TestCase):
def _check(self, t, dim, idx):

View File

@@ -24,7 +24,7 @@ class RandMixin(OpMixin):
counts0 = cls.arange(ceildiv(chunk_num, 2), dtype=dtypes.uint32)
counts1 = counts0 + ceildiv(chunk_num, 2)
bits.append(cls._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
return bits[0].cat(*bits[1:])
return bits[0].cat(*bits[1:]) if bits else counter[0:0]
@staticmethod
def _bits_to_rand(bits, shape:tuple[int, ...], dtype:DType):
@@ -33,3 +33,9 @@ class RandMixin(OpMixin):
uint_bits = bits.bitcast(uint_dtype)
float_one_bits = uint_bits.const_like(1).cast(dtype).bitcast(uint_dtype)
return uint_bits.rshift(dtype.bitsize - nmant).bitwise_or(float_one_bits).bitcast(dtype)[:prod(shape)].sub(1).reshape(shape)
@classmethod
def _rand(cls, key:Self, counter:Self, shape:tuple[int, ...], dtype:DType, contiguous:bool=True) -> Self:
bits = cls.random_bits(key, counter, ceildiv(prod(shape) * dtype.itemsize, 4))
out = cls._bits_to_rand(bits, shape, dtype)
return out.contiguous() if contiguous else out

View File

@@ -551,14 +551,8 @@ class Tensor(RandMixin):
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = cast(str, canonicalize_device(device))
# if shape has 0, return zero tensor
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dt)
num = ceildiv(numel * dt.itemsize, 4)
key, counter = Tensor._next_counter(device, num)
bits = Tensor.random_bits(key, counter, num)
out = Tensor._bits_to_rand(bits, shape, dt)
return out.contiguous() if contiguous else out
key, counter = Tensor._next_counter(device, ceildiv(prod(shape) * dt.itemsize, 4))
return Tensor._rand(key, counter, shape, dt, contiguous=contiguous)
# ***** creation helper functions *****