Files
tinygrad/test/null/test_gpudims.py
George Hotz 1e7f1dcf49 add ParamArgs [pr] (#16421)
* add ParamArgs

* fix export

* cleanups

* fixes

* simpler
2026-05-28 19:17:17 -07:00

111 lines
6.1 KiB
Python

import unittest, math
import z3
from tinygrad.codegen.gpudims import get_grouped_dims, add_gpudims
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.uop.validate import uops_to_z3
from tinygrad.dtype import dtypes
from tinygrad.renderer import Renderer
from tinygrad.helpers import flatten, dedup, Target
class TestGroupedDims(unittest.TestCase):
  """Checks the SPECIAL UOps produced by get_grouped_dims and add_gpudims."""

  def _check_grouped_dims(self, prefix, dims, max_sizes, reverse, expected_sizes, assert_same_length=True):
    """Run get_grouped_dims and compare the sizes of the SPECIAL UOps it produced with expected_sizes.

    prefix, dims, max_sizes and reverse are forwarded unchanged to get_grouped_dims.
    expected_sizes is the expected list of SPECIAL sizes, ordered by the SPECIAL's arg.
    When assert_same_length is true, the number of SPECIALs is additionally checked against the number of dims.
    The returned indices are then handed to _verify_indices_z3.
    """
    idxs = get_grouped_dims(prefix, dims, max_sizes, reverse)
    # every distinct SPECIAL reachable from any returned index, in a deterministic (arg-sorted) order
    loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
    loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
    # the size of a SPECIAL is read from the arg of its first source
    sizes = [x.src[0].arg for x in loop_idxs]
    # unittest assertions instead of bare `assert`, so the checks still run under `python -O`
    self.assertEqual(len(idxs), len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}")
    if assert_same_length:
      # NOTE: sizes has one entry per loop_idx, so this effectively checks len(loop_idxs) <= len(dims)
      expected_len = min(len(sizes), len(dims))
      self.assertEqual(len(loop_idxs), expected_len, f"expected idxs to have length {expected_len}, got {len(loop_idxs)}")
    self.assertEqual(sizes, expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}")
    self._verify_indices_z3(idxs, dims)

  def _verify_indices_z3(self, idxs, dims):
    """Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat).

    idxs are the per-dim index UOps and dims the matching sizes; flat is their combination
    sum(idxs[i] * prod(dims[i+1:])).
    """
    total = math.prod(dims)
    specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
    # build flat index and primed flat (same expression with renamed SPECIALs)
    flat = UOp.const(dtypes.weakint, 0)
    for i, idx in enumerate(idxs):
      flat = flat + idx * int(math.prod(dims[i+1:]))
    flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
    solver = z3.Solver()
    [z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
    # bounds
    self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
    self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
    # injectivity: flat == flat' but inputs differ => unsat
    # NOTE(review): presumably uops_to_z3 names each SPECIAL's z3 variable after its arg, so z3.Int(s.arg) refers to it - confirm
    inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials])
    self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")

  def test_grouped_dims(self):
    # no-op
    self._check_grouped_dims("gidx", (2,), (16,16,16), False, [2])
    self._check_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])

    # check reverse dims
    self._check_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
    self._check_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])

    # test splitting globals: len(dims) == len(max)
    self._check_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
    self._check_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
    self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
    self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
    self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
    self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])

    # prefer group_dim strategy when possible
    self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])

    # test splitting globals: len(dims) < len(max)
    # len(dim) -> len(limited)
    # 1 -> 2
    self._check_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
    # 1 -> 3
    self._check_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
    # 2 -> 2
    self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
    # test when the only divisor is the square root of dim
    self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
    # 2 -> 3
    self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)

    # collapse onto the left-most axis
    self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
    self._check_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])

    # collapse on left-most available axis (the left most is too small)
    self._check_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
    self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])

    # dim too large and not factorable
    with self.assertRaises(RuntimeError):
      get_grouped_dims("gidx", (23,), (16,16,16), False)
    with self.assertRaises(RuntimeError):
      get_grouped_dims("gidx", (128,3,4), (16,2,2), False)

    # too large for sizes
    with self.assertRaises(RuntimeError):
      get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))

  def test_grouped_direct_dims_are_special(self):
    # when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod)
    idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False)
    self.assertIs(idxs[2].op, Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}")
    self.assertIs(idxs[3].op, Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}")

  def test_global_prod_max(self):
    # one GLOBAL and one LOCAL range of size 256, run through add_gpudims with a renderer whose
    # local_max and global_prod_max entries (128) are smaller than the ranges
    g, l = UOp.range(256, 0, AxisType.GLOBAL), UOp.range(256, 1, AxisType.LOCAL)
    sink = UOp.param(0, dtypes.float.ptr()).index(g + l).store(UOp.const(dtypes.float, 1.0)).end(g, l).sink(arg=KernelInfo())
    class R(Renderer): global_max, local_max, global_prod_max = (256, 256, 256), (128, 128, 128), (128, 128, 128)
    specials = [u for u in add_gpudims(R(Target()), sink).toposort() if u.op is Ops.SPECIAL]
    # both the local and the global range are expected to end up on more than one SPECIAL
    self.assertGreater(len([s for s in specials if "lidx" in s.arg]), 1)
    self.assertGreater(len([s for s in specials if "gidx" in s.arg]), 1)

  def test_max_sizes_none(self):
    # with max_sizes=None the expected SPECIAL sizes equal dims, including a dim as large as 100
    self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])
    self._check_grouped_dims("gidx", (100,), None, False, [100])
# run the tests in this module when the file is executed directly
if __name__ == '__main__':
  unittest.main()