fix: rand supports large tensors (#15329)

This commit is contained in:
wozeparrot
2026-03-18 06:45:41 +08:00
committed by GitHub
parent 00817cf65e
commit b45edeb965
4 changed files with 53 additions and 21 deletions

View File

@@ -24,7 +24,7 @@ if __name__ == "__main__":
kernel_count = GlobalCounters.kernel_count
assert kernel_count > 0, "No kernels, test failed"
# NOTE: this is 124 on torch 2.10.0
expected_kernels = 332
expected_kernels = 334
expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
assert kernel_count <= expected_kernels, f"{expectation}"
assert kernel_count <= expected_kernels, f"{expectation}"

View File

@@ -23,7 +23,7 @@ class TestKernelFusionRegression(unittest.TestCase):
def fn():
x = torch.randn(128, 128, device=device)
return (x + 1.0) * 2.0 - 0.5
self._check_kernel_count(fn, 5)
self._check_kernel_count(fn, 7)
def test_relu_fusion(self):
def fn():
@@ -31,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase):
conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
with torch.no_grad():
return torch.nn.functional.relu(conv(x))
self._check_kernel_count(fn, 6)
self._check_kernel_count(fn, 8)
def test_batchnorm_fusion(self):
def fn():
@@ -41,26 +41,26 @@ class TestKernelFusionRegression(unittest.TestCase):
bn.eval()
with torch.no_grad():
return torch.nn.functional.relu(bn(conv(x)))
self._check_kernel_count(fn, 10)
self._check_kernel_count(fn, 12)
def test_reduce_fusion(self):
def fn():
x = torch.randn(64, 64, device=device)
return (x * 2.0).sum()
self._check_kernel_count(fn, 5)
self._check_kernel_count(fn, 7)
def test_matmul_elementwise_fusion(self):
def fn():
x = torch.randn(32, 32, device=device)
w = torch.randn(32, 32, device=device)
return torch.nn.functional.relu(x @ w + 1.0)
self._check_kernel_count(fn, 7)
self._check_kernel_count(fn, 9)
def test_pooling_fusion(self):
def fn():
x = torch.randn(1, 8, 16, 16, device=device)
return torch.nn.functional.max_pool2d(x * 2.0, 2)
self._check_kernel_count(fn, 5)
self._check_kernel_count(fn, 7)
def test_residual_add_relu_fusion(self):
def fn():
@@ -68,7 +68,7 @@ class TestKernelFusionRegression(unittest.TestCase):
identity = torch.randn(1, 8, 16, 16, device=device)
out = x + identity
return torch.nn.functional.relu(out)
self._check_kernel_count(fn, 7)
self._check_kernel_count(fn, 9)
def test_inplace_add_relu_fusion(self):
def fn():
@@ -76,7 +76,7 @@ class TestKernelFusionRegression(unittest.TestCase):
y = torch.randn(1, 16, 32, 32, device=device)
x += y
return torch.nn.functional.relu(x)
self._check_kernel_count(fn, 7)
self._check_kernel_count(fn, 9)
def test_conv_bn_add_relu_fusion(self):
def fn():
@@ -89,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase):
out = bn(conv(x))
out += identity
return torch.nn.functional.relu(out)
self._check_kernel_count(fn, 12)
self._check_kernel_count(fn, 14)
def test_multiple_inplace_ops_fusion(self):
def fn():
@@ -97,7 +97,7 @@ class TestKernelFusionRegression(unittest.TestCase):
x += 1.0
x *= 2.0
return torch.nn.functional.relu(x)
self._check_kernel_count(fn, 4)
self._check_kernel_count(fn, 6)
def test_view_inplace_no_fusion_break(self):
def fn():
@@ -105,7 +105,7 @@ class TestKernelFusionRegression(unittest.TestCase):
view = x[1:3]
view += 1.0
return x.sum()
self._check_kernel_count(fn, 8)
self._check_kernel_count(fn, 10)
def test_batchnorm_running_stats_update(self):
def fn():
@@ -114,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase):
bn.train()
with torch.no_grad():
return bn(x)
self._check_kernel_count(fn, 8)
self._check_kernel_count(fn, 10)
# this is a minimal extra/other_mnist/beautiful_mnist_torch.py to cover fusion for training with optimizer
def test_mnist_training_fusion(self):
@@ -135,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase):
loss.backward()
optimizer.step()
return loss
self._check_kernel_count(fn, 24)
self._check_kernel_count(fn, 26)
if __name__ == "__main__":
unittest.main()

View File

@@ -384,6 +384,21 @@ class TestRandomness(unittest.TestCase):
for _ in range(833): Tensor.rand(1)
Tensor.rand(1).realize()
def test_random_counter_overflow(self):
device = Device.DEFAULT
Tensor.manual_seed(1337)
Tensor.rand(1).realize()
Tensor._device_rng_counters[device].assign(Tensor([dtypes.uint32.max - 5, 0], device=device, dtype=dtypes.uint32)).realize()
Tensor.rand(10).realize()
c = Tensor._device_rng_counters[device].numpy()
np.testing.assert_allclose(c, [4, 1])
Tensor.rand(10).realize()
c = Tensor._device_rng_counters[device].numpy()
np.testing.assert_allclose(c, [14, 1])
# TODO: still fails with MAX_KERNEL_BUFFERS
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
class TestSample(unittest.TestCase):

View File

@@ -630,15 +630,32 @@ class Tensor(OpMixin):
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
# increment rng counter for devices
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff)
new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high))
low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff)
high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
# threefry random bits
bits_count = Tensor._device_rng_counters[device] - num
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count)
counts1 = counts0 + ceildiv(num, 2)
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
if num > dtypes.uint32.max:
bits_list = []
for i in range(0, num, dtypes.uint32.max):
chunk_num = min(num - i, dtypes.uint32.max)
c_low = low + (i & 0xffffffff)
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
new_key = Tensor._threefry_random_bits(Tensor._device_seeds[device], c_low, c_high)
counts0 = Tensor.arange(ceildiv(chunk_num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)
counts1 = counts0 + ceildiv(chunk_num, 2)
bits_list.append(Tensor._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
bits = Tensor.cat(*bits_list)
else:
counts0 = Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False) + low
counts1 = counts0 + ceildiv(num, 2)
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
# bitcast to uint with same number of bits
_, nmant = dtypes.finfo(dt)