mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
fix: rand supports large tensors (#15329)
This commit is contained in:
@@ -24,7 +24,7 @@ if __name__ == "__main__":
|
||||
kernel_count = GlobalCounters.kernel_count
|
||||
assert kernel_count > 0, "No kernels, test failed"
|
||||
# NOTE: this is 124 on torch 2.10.0
|
||||
expected_kernels = 332
|
||||
expected_kernels = 334
|
||||
expectation = f"ResNet18 kernels are {kernel_count} vs {expected_kernels} expected."
|
||||
if kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
|
||||
assert kernel_count <= expected_kernels, f"{expectation}"
|
||||
assert kernel_count <= expected_kernels, f"{expectation}"
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
def fn():
|
||||
x = torch.randn(128, 128, device=device)
|
||||
return (x + 1.0) * 2.0 - 0.5
|
||||
self._check_kernel_count(fn, 5)
|
||||
self._check_kernel_count(fn, 7)
|
||||
|
||||
def test_relu_fusion(self):
|
||||
def fn():
|
||||
@@ -31,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.relu(conv(x))
|
||||
self._check_kernel_count(fn, 6)
|
||||
self._check_kernel_count(fn, 8)
|
||||
|
||||
def test_batchnorm_fusion(self):
|
||||
def fn():
|
||||
@@ -41,26 +41,26 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
bn.eval()
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.relu(bn(conv(x)))
|
||||
self._check_kernel_count(fn, 10)
|
||||
self._check_kernel_count(fn, 12)
|
||||
|
||||
def test_reduce_fusion(self):
# Kernel-count regression check for an elementwise multiply followed by a full sum.
|
||||
def fn():
|
||||
x = torch.randn(64, 64, device=device)
|
||||
return (x * 2.0).sum()
|
||||
# NOTE(review): the two calls below differ only in the expected count (5 vs 7).
# In this diff rendering they look like the before/after pair of one changed line;
# confirm that only one of them exists in the real file.
self._check_kernel_count(fn, 5)
|
||||
self._check_kernel_count(fn, 7)
|
||||
|
||||
def test_matmul_elementwise_fusion(self):
# Kernel-count regression check for a matmul followed by an add and a relu.
|
||||
def fn():
|
||||
x = torch.randn(32, 32, device=device)
|
||||
w = torch.randn(32, 32, device=device)
|
||||
return torch.nn.functional.relu(x @ w + 1.0)
|
||||
# NOTE(review): the two calls below differ only in the expected count (7 vs 9).
# In this diff rendering they look like the before/after pair of one changed line;
# confirm that only one of them exists in the real file.
self._check_kernel_count(fn, 7)
|
||||
self._check_kernel_count(fn, 9)
|
||||
|
||||
def test_pooling_fusion(self):
# Kernel-count regression check for an elementwise multiply feeding a 2x2 max pool.
|
||||
def fn():
|
||||
x = torch.randn(1, 8, 16, 16, device=device)
|
||||
return torch.nn.functional.max_pool2d(x * 2.0, 2)
|
||||
# NOTE(review): the two calls below differ only in the expected count (5 vs 7).
# In this diff rendering they look like the before/after pair of one changed line;
# confirm that only one of them exists in the real file.
self._check_kernel_count(fn, 5)
|
||||
self._check_kernel_count(fn, 7)
|
||||
|
||||
def test_residual_add_relu_fusion(self):
|
||||
def fn():
|
||||
@@ -68,7 +68,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
identity = torch.randn(1, 8, 16, 16, device=device)
|
||||
out = x + identity
|
||||
return torch.nn.functional.relu(out)
|
||||
self._check_kernel_count(fn, 7)
|
||||
self._check_kernel_count(fn, 9)
|
||||
|
||||
def test_inplace_add_relu_fusion(self):
|
||||
def fn():
|
||||
@@ -76,7 +76,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
y = torch.randn(1, 16, 32, 32, device=device)
|
||||
x += y
|
||||
return torch.nn.functional.relu(x)
|
||||
self._check_kernel_count(fn, 7)
|
||||
self._check_kernel_count(fn, 9)
|
||||
|
||||
def test_conv_bn_add_relu_fusion(self):
|
||||
def fn():
|
||||
@@ -89,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
out = bn(conv(x))
|
||||
out += identity
|
||||
return torch.nn.functional.relu(out)
|
||||
self._check_kernel_count(fn, 12)
|
||||
self._check_kernel_count(fn, 14)
|
||||
|
||||
def test_multiple_inplace_ops_fusion(self):
|
||||
def fn():
|
||||
@@ -97,7 +97,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
x += 1.0
|
||||
x *= 2.0
|
||||
return torch.nn.functional.relu(x)
|
||||
self._check_kernel_count(fn, 4)
|
||||
self._check_kernel_count(fn, 6)
|
||||
|
||||
def test_view_inplace_no_fusion_break(self):
|
||||
def fn():
|
||||
@@ -105,7 +105,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
view = x[1:3]
|
||||
view += 1.0
|
||||
return x.sum()
|
||||
self._check_kernel_count(fn, 8)
|
||||
self._check_kernel_count(fn, 10)
|
||||
|
||||
def test_batchnorm_running_stats_update(self):
|
||||
def fn():
|
||||
@@ -114,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
bn.train()
|
||||
with torch.no_grad():
|
||||
return bn(x)
|
||||
self._check_kernel_count(fn, 8)
|
||||
self._check_kernel_count(fn, 10)
|
||||
|
||||
# this is a minimal extra/other_mnist/beautiful_mnist_torch.py to cover fusion for training with optimizer
|
||||
def test_mnist_training_fusion(self):
|
||||
@@ -135,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss
|
||||
self._check_kernel_count(fn, 24)
|
||||
self._check_kernel_count(fn, 26)
|
||||
|
||||
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -384,6 +384,21 @@ class TestRandomness(unittest.TestCase):
|
||||
for _ in range(833): Tensor.rand(1)
|
||||
Tensor.rand(1).realize()
|
||||
|
||||
def test_random_counter_overflow(self):
# Checks that the per-device RNG counter, stored as two uint32 values, carries from
# the first element into the second when the first one wraps past uint32.max.
|
||||
device = Device.DEFAULT
|
||||
Tensor.manual_seed(1337)
|
||||
# Presumably this first draw makes sure Tensor._device_rng_counters[device] exists
# before it is overwritten below -- TODO confirm against Tensor.rand.
Tensor.rand(1).realize()
|
||||
|
||||
# Force the counter to 5 below the uint32 maximum, with the second element at 0.
Tensor._device_rng_counters[device].assign(Tensor([dtypes.uint32.max - 5, 0], device=device, dtype=dtypes.uint32)).realize()
|
||||
|
||||
# Drawing 10 values: (uint32.max - 5) + 10 wraps to 4 and the second element becomes 1.
Tensor.rand(10).realize()
|
||||
c = Tensor._device_rng_counters[device].numpy()
|
||||
np.testing.assert_allclose(c, [4, 1])
|
||||
|
||||
# A further 10 values advance only the first element (4 -> 14); the carry stays at 1.
Tensor.rand(10).realize()
|
||||
c = Tensor._device_rng_counters[device].numpy()
|
||||
np.testing.assert_allclose(c, [14, 1])
|
||||
|
||||
# TODO: still fails with MAX_KERNEL_BUFFERS
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
|
||||
class TestSample(unittest.TestCase):
|
||||
|
||||
@@ -630,15 +630,32 @@ class Tensor(OpMixin):
|
||||
Tensor._device_seeds[device] = Tensor(
|
||||
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
||||
device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([num], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
|
||||
Tensor._device_rng_counters[device] = Tensor([0, 0], device=device, dtype=dtypes.uint32, requires_grad=False).contiguous()
|
||||
|
||||
# increment rng counter for devices
|
||||
else: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
|
||||
new_low = Tensor._device_rng_counters[device][0] + (num & 0xffffffff)
|
||||
new_high = Tensor._device_rng_counters[device][1] + (num >> 32) + (new_low < Tensor._device_rng_counters[device][0]).cast(dtypes.uint32)
|
||||
Tensor._device_rng_counters[device].assign(Tensor.stack(new_low, new_high))
|
||||
|
||||
low = Tensor._device_rng_counters[device][0] - (num & 0xffffffff)
|
||||
high = Tensor._device_rng_counters[device][1] - (num >> 32) - (Tensor._device_rng_counters[device][0] < (num & 0xffffffff)).cast(dtypes.uint32)
|
||||
|
||||
# threefry random bits
|
||||
bits_count = Tensor._device_rng_counters[device] - num
|
||||
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+bits_count)
|
||||
counts1 = counts0 + ceildiv(num, 2)
|
||||
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
|
||||
if num > dtypes.uint32.max:
|
||||
bits_list = []
|
||||
for i in range(0, num, dtypes.uint32.max):
|
||||
chunk_num = min(num - i, dtypes.uint32.max)
|
||||
c_low = low + (i & 0xffffffff)
|
||||
c_high = high + (i >> 32) + (c_low < low).cast(dtypes.uint32)
|
||||
new_key = Tensor._threefry_random_bits(Tensor._device_seeds[device], c_low, c_high)
|
||||
counts0 = Tensor.arange(ceildiv(chunk_num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
counts1 = counts0 + ceildiv(chunk_num, 2)
|
||||
bits_list.append(Tensor._threefry_random_bits(new_key, counts0, counts1)[:chunk_num])
|
||||
bits = Tensor.cat(*bits_list)
|
||||
else:
|
||||
counts0 = Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False) + low
|
||||
counts1 = counts0 + ceildiv(num, 2)
|
||||
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
|
||||
|
||||
# bitcast to uint with same number of bits
|
||||
_, nmant = dtypes.finfo(dt)
|
||||
|
||||
Reference in New Issue
Block a user