From b45edeb9653d5e7adbf0fcd3cd14c48e8bfd2b6c Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Wed, 18 Mar 2026 06:45:41 +0800 Subject: [PATCH] fix: rand supports large tensors (#15329) --- extra/torch_backend/example.py | 4 ++-- extra/torch_backend/test_kernel_fusion.py | 26 ++++++++++---------- test/backend/test_randomness.py | 15 ++++++++++++ tinygrad/tensor.py | 29 ++++++++++++++++++----- 4 files changed, 53 insertions(+), 21 deletions(-) diff --git a/extra/torch_backend/example.py b/extra/torch_backend/example.py index 9a3cc2581e..7a609b7a17 100644 --- a/extra/torch_backend/example.py +++ b/extra/torch_backend/example.py @@ -24,7 +24,7 @@ if __name__ == "__main__": kernel_count = GlobalCounters.kernel_count assert kernel_count > 0, "No kernels, test failed" # NOTE: this is 124 on torch 2.10.0 - expected_kernels = 332 + expected_kernels = 334 expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected." if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning) - assert kernel_count <= expected_kernels, f"{expectation}" \ No newline at end of file + assert kernel_count <= expected_kernels, f"{expectation}" diff --git a/extra/torch_backend/test_kernel_fusion.py b/extra/torch_backend/test_kernel_fusion.py index dffcfe067f..f79733fd37 100644 --- a/extra/torch_backend/test_kernel_fusion.py +++ b/extra/torch_backend/test_kernel_fusion.py @@ -23,7 +23,7 @@ class TestKernelFusionRegression(unittest.TestCase): def fn(): x = torch.randn(128, 128, device=device) return (x + 1.0) * 2.0 - 0.5 - self._check_kernel_count(fn, 5) + self._check_kernel_count(fn, 7) def test_relu_fusion(self): def fn(): @@ -31,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase): conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device) with torch.no_grad(): return torch.nn.functional.relu(conv(x)) - self._check_kernel_count(fn, 6) + self._check_kernel_count(fn, 8) def test_batchnorm_fusion(self): def fn(): @@ -41,26 +41,26 @@ class TestKernelFusionRegression(unittest.TestCase): bn.eval() with torch.no_grad(): return torch.nn.functional.relu(bn(conv(x))) - self._check_kernel_count(fn, 10) + self._check_kernel_count(fn, 12) def test_reduce_fusion(self): def fn(): x = torch.randn(64, 64, device=device) return (x * 2.0).sum() - self._check_kernel_count(fn, 5) + self._check_kernel_count(fn, 7) def test_matmul_elementwise_fusion(self): def fn(): x = torch.randn(32, 32, device=device) w = torch.randn(32, 32, device=device) return torch.nn.functional.relu(x @ w + 1.0) - self._check_kernel_count(fn, 7) + self._check_kernel_count(fn, 9) def test_pooling_fusion(self): def fn(): x = torch.randn(1, 8, 16, 16, device=device) return torch.nn.functional.max_pool2d(x * 2.0, 2) - self._check_kernel_count(fn, 5) + self._check_kernel_count(fn, 7) def test_residual_add_relu_fusion(self): def fn(): @@ -68,7 +68,7 @@ class TestKernelFusionRegression(unittest.TestCase): identity = torch.randn(1, 8, 16, 16, device=device) out = x + identity return torch.nn.functional.relu(out) - self._check_kernel_count(fn, 7) + self._check_kernel_count(fn, 9) def test_inplace_add_relu_fusion(self): def fn(): @@ -76,7 +76,7 @@ class TestKernelFusionRegression(unittest.TestCase): y = torch.randn(1, 16, 32, 32, device=device) x += y return torch.nn.functional.relu(x) - self._check_kernel_count(fn, 7) + self._check_kernel_count(fn, 9) def test_conv_bn_add_relu_fusion(self): def fn(): @@ -89,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase): out = bn(conv(x)) out += identity return torch.nn.functional.relu(out) - self._check_kernel_count(fn, 12) + self._check_kernel_count(fn, 14) def test_multiple_inplace_ops_fusion(self): def fn(): @@ -97,7 +97,7 @@ class TestKernelFusionRegression(unittest.TestCase): x += 1.0 x *= 2.0 return torch.nn.functional.relu(x) - self._check_kernel_count(fn, 4) + self._check_kernel_count(fn, 6) def test_view_inplace_no_fusion_break(self): def fn(): @@ -105,7 +105,7 @@ class TestKernelFusionRegression(unittest.TestCase): view = x[1:3] view += 1.0 return x.sum() - self._check_kernel_count(fn, 8) + self._check_kernel_count(fn, 10) def test_batchnorm_running_stats_update(self): def fn(): @@ -114,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase): bn.train() with torch.no_grad(): return bn(x) - self._check_kernel_count(fn, 8) + self._check_kernel_count(fn, 10) # this is a minimal extra/other_mnist/beautiful_mnist_torch.py to cover fusion for training with optimizer def test_mnist_training_fusion(self): @@ -135,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase): loss.backward() optimizer.step() return loss - self._check_kernel_count(fn, 24) + self._check_kernel_count(fn, 26) if __name__ == "__main__": unittest.main() diff --git a/test/backend/test_randomness.py b/test/backend/test_randomness.py index 1d9aaec1ec..bebf0c54ba 100644 --- a/test/backend/test_randomness.py +++ b/test/backend/test_randomness.py @@ -384,6 +384,21 @@ class TestRandomness(unittest.TestCase): for _ in range(833): Tensor.rand(1) Tensor.rand(1).realize() + def test_random_counter_overflow(self): + device = Device.DEFAULT + Tensor.manual_seed(1337) + Tensor.rand(1).realize() + + Tensor._device_rng_counters[device].assign(Tensor([dtypes.uint32.max - 5, 0], device=device, dtype=dtypes.uint32)).realize() + + Tensor.rand(10).realize() + c = Tensor._device_rng_counters[device].numpy() + np.testing.assert_allclose(c, [4, 1]) + + Tensor.rand(10).realize() + c = Tensor._device_rng_counters[device].numpy() + np.testing.assert_allclose(c, [14, 1]) + # TODO: still fails with MAX_KERNEL_BUFFERS @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers") class TestSample(unittest.TestCase): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 88e111e62a..2da7ff16e6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -630,15 +630,32 @@ class Tensor(OpMixin): Tensor._device_seeds[device] = Tensor( [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed], device=device, dtype=dtypes.uint32, requires_grad=False) - Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous() + Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous() + # increment rng counter for devices - else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num) + new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff) + new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32) + Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high)) + + low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff) + high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32) # threefry random bits - bits_count = Tensor._device_rng_counters[device] - num - counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count) - counts1 = counts0 + ceildiv(num, 2) - bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num] + if num > dtypes.uint32.max: + bits_list = [] + for i in range(0, num, dtypes.uint32.max): + chunk_num = min(num - i, dtypes.uint32.max) + c_low = low + (i & 0xffffffff) + c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32) + new_key = Tensor._threefry_random_bits(Tensor._device_seeds[device], c_low, c_high) + counts0 = Tensor.arange(ceildiv(chunk_num, 2), device=device, dtype=dtypes.uint32, requires_grad=False) + counts1 = counts0 + ceildiv(chunk_num, 2) + bits_list.append(Tensor._threefry_random_bits(new_key, counts0, counts1)[:chunk_num]) + bits = Tensor.cat(*bits_list) + else: + counts0 = Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False) + low + counts1 = counts0 + ceildiv(num, 2) + bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num] # bitcast to uint with same number of bits _, nmant = dtypes.finfo(dt)