diff --git a/extra/thunder/amd/fa.py b/extra/thunder/amd/fa.py index b3f5c47139..698f6d1b7a 100644 --- a/extra/thunder/amd/fa.py +++ b/extra/thunder/amd/fa.py @@ -47,24 +47,21 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False # delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach() delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32) - delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch))[:2] + delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2] - dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch))[:3] + dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3] # unshuffle dq - dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch))[0] + dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0] return None, None, dq.uop, dk.uop, dv.uop - attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch), grad_fxn=grad)[:2] + attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2] return attn.transpose(1, 2) @functools.cache -def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str): - B, N, H, D = q.shape - H_KV = k.shape[2] - +def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int): code = (pathlib.Path(__file__).parent / "fa_fwd_causal.cpp").read_text() compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math", f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"] @@ -95,9 +92,7 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib))) @functools.cache -def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str): - B, N, H, D = o.shape - +def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int): code = (pathlib.Path(__file__).parent / "fa_bwd_pre.cpp").read_text() compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math", f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"] @@ -128,10 +123,7 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib))) @functools.cache -def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str): - B, N, H, D = q.shape - H_KV = k.shape[2] - +def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int): code = (pathlib.Path(__file__).parent / "fa_bwd_causal.cpp").read_text() compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math", f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"] @@ -162,9 +154,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib))) @functools.cache -def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str): - B, N, H, D = dq_out.shape - +def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int): code = (pathlib.Path(__file__).parent / "fa_bwd_post.cpp").read_text() compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math", f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"] diff --git a/test/backend/test_linearizer.py b/test/backend/test_linearizer.py index 4e7e0e108d..d9688e8ca7 100644 --- a/test/backend/test_linearizer.py +++ b/test/backend/test_linearizer.py @@ -3,8 +3,7 @@ import unittest from dataclasses import replace from tinygrad.codegen.opt import Opt, OptOps -from tinygrad.codegen.gpudims import get_grouped_dims -from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat +from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType from tinygrad.device import Device, Buffer, is_dtype_supported from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program @@ -253,100 +252,6 @@ class TestLinearizer(unittest.TestCase): if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src): assert end_range < uops.index(u) - def test_grouped_dims(self): - def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True): - idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims) - loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])) - loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg) - sizes = [x.src[0].arg for x in loop_idxs] - assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}" - if assert_same_length: - assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}" - assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}" - # TODO: add these back after uop symbolic - # for i in range(len(dims)): - # assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}" - # for i in range(len(loop_idxs)): - # assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}" - # assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}" - - # no-op - _assert_grouped_dims("gidx", (2,), (16,16,16), False, [2]) - _assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3]) - - # check reverse dims - _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2]) - _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4]) - - # test splitting globals: len(dims) == len(max) - _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4]) - _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16]) - _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16]) - _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32]) - _assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256]) - - # prefer group_dim strategy when possible - _assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2]) - - # test splitting globals: len(dims) < len(max) - # len(dim) -> len(limited) - # 1 -> 2 - _assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False) - # 1 -> 3 - _assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False) - # 2 -> 3 - _assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) - # 2 -> 2 - _assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False) - # test when the only divisor is the square root of dim - _assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False) - - # collapse on onto the left most axis - _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5]) - _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2]) - # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)]) - - # collapse on left-most available axis (the left most is too small) - _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5]) - _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2]) - - # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5]) - - # dim too large and not factorable - with self.assertRaises(RuntimeError): - get_grouped_dims("gidx", (23,), (16,16,16), False,) - with self.assertRaises(RuntimeError): - get_grouped_dims("gidx", (128,3,4), (16,2,2), False,) - - # too large for sizes - with self.assertRaises(RuntimeError): - get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) - - # TODO: In the above cases we only test if the shape after reshape is correct, never the indices. - # We should check if the returned indices are correct, for all cases. - # (65536, 2) -> (32768, 4) - dims, expected_limited_dims = (65536,2), (32768, 4) - idxs = get_grouped_dims("gidx", dims, (65535,65535,65535)) - def match_div(): raise RuntimeError("match_div") - def match_mod(): raise RuntimeError("match_mod") - flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1') - pm = PatternMatcher([ - (flat_idx_pattern//dims[1], match_div), - (flat_idx_pattern%dims[1], match_mod) - ]) - - with self.assertRaises(RuntimeError) as error: - graph_rewrite(idxs[0], pm) - self.assertIn("match_div", str(error.exception)) - - with self.assertRaises(RuntimeError) as error: - graph_rewrite(idxs[1], pm) - self.assertIn("match_mod", str(error.exception)) - - # # variable too large - # with self.assertRaises(AssertionError): - # get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,) - @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_default_global_reversed(self): # shrink so that the dims do not collapse diff --git a/test/null/test_gpudims.py b/test/null/test_gpudims.py new file mode 100644 index 0000000000..b18ff6c1ac --- /dev/null +++ b/test/null/test_gpudims.py @@ -0,0 +1,101 @@ +import unittest, math +import z3 +from tinygrad.codegen.gpudims import get_grouped_dims +from tinygrad.uop.ops import UOp, Ops +from tinygrad.uop.validate import uops_to_z3 +from tinygrad.dtype import dtypes +from tinygrad.helpers import flatten, dedup + +class TestGroupedDims(unittest.TestCase): + def _check_grouped_dims(self, prefix, dims, max_sizes, reverse, expected_sizes, assert_same_length=True): + idxs = get_grouped_dims(prefix, dims, max_sizes, reverse) + loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])) + loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg) + sizes = [x.src[0].arg for x in loop_idxs] + assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}" + if assert_same_length: + assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}" + assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}" + self._verify_indices_z3(idxs, dims) + + def _verify_indices_z3(self, idxs, dims): + """Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat).""" + total = math.prod(dims) + specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg) + # build flat index and primed flat (same expression with renamed SPECIALs) + flat = UOp.const(dtypes.index, 0) + for i, idx in enumerate(idxs): + flat = flat + idx * int(math.prod(dims[i+1:])) + flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials}) + solver = z3.Solver() + [z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p) + # bounds + self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}") + self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}") + # injectivity: flat == flat' but inputs differ => unsat + inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials]) + self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}") + + def test_grouped_dims(self): + # no-op + self._check_grouped_dims("gidx", (2,), (16,16,16), False, [2]) + self._check_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3]) + + # check reverse dims + self._check_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2]) + self._check_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4]) + + # test splitting globals: len(dims) == len(max) + self._check_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4]) + self._check_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16]) + self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16]) + self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32]) + self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256]) + self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14]) + + # prefer group_dim strategy when possible + self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2]) + + # test splitting globals: len(dims) < len(max) + # len(dim) -> len(limited) + # 1 -> 2 + self._check_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False) + # 1 -> 3 + self._check_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False) + # 2 -> 2 + self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False) + # test when the only divisor is the square root of dim + self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False) + # 2 -> 3 + self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) + + # collapse on onto the left most axis + self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5]) + self._check_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2]) + + # collapse on left-most available axis (the left most is too small) + self._check_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5]) + self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2]) + + # dim too large and not factorable + with self.assertRaises(RuntimeError): + get_grouped_dims("gidx", (23,), (16,16,16), False,) + with self.assertRaises(RuntimeError): + get_grouped_dims("gidx", (128,3,4), (16,2,2), False,) + + # too large for sizes + with self.assertRaises(RuntimeError): + get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) + + def test_grouped_direct_dims_are_special(self): + # when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod) + idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False) + assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}" + assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}" + + def test_max_sizes_none(self): + self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4]) + self._check_grouped_dims("gidx", (100,), None, False, [100]) + +if __name__ == '__main__': + unittest.main() diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 0c8848f2a7..12c4158787 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -839,34 +839,33 @@ class TestSymbolicNumeric(unittest.TestCase): def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4) def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4) -class TestSymbolicVars(unittest.TestCase): +class TestSymbolicVariables(unittest.TestCase): def test_simple(self): z = uconst(0) a = Variable("a", 0, 10) b = Variable("b", 0, 10) c = Variable("c", 0, 10) - assert z.vars() == z.vars() == set() - print(a.vars()) - assert a.vars() == a.vars() == {a} + assert z.variables() == [] + assert a.variables() == [a] m = a * 3 - assert m.vars() == {a} + assert m.variables() == [a] s = usum([a, b, c]) - assert s.vars() == {a, b, c} + assert s.variables() == [a, b, c] def test_compound(self): a = Variable("a", 0, 10) b = Variable("b", 0, 10) c = Variable("c", 0, 10) - assert (a + b * c).vars() == {a, b, c} - assert (a % 3 + b // 5).vars() == {a, b} + assert (a + b * c).variables() == [a, b, c] + assert (a % 3 + b // 5).variables() == [a, b] # TODO: fix me with self.assertRaises(AssertionError): - assert (a + b + c - a).vars() == {b, c} + assert (a + b + c - a).variables() == [b, c] def test_dedup(self): a = Variable("a", 0, 10) - assert (a * a).vars() == {a} - assert (a//4 + a//6).vars() == {a} + assert (a * a).variables() == [a] + assert (a//4 + a//6).variables() == [a] class TestSymInfer(unittest.TestCase): def test_sym_infer(self): diff --git a/test/unit/test_call.py b/test/unit/test_call.py index f2d9434183..be342d2239 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -92,5 +92,13 @@ class TestCall(unittest.TestCase): np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5) np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5) + def test_call_plus_sharded(self): + devs = ("CPU:0", "CPU:1") + a = Tensor.ones(10, 10).shard(devs, axis=0) + b = Tensor.ones(10, 10).shard(devs, axis=0) + Tensor.realize(a, b) + c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1)) + np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10))) + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_system_pci_scan_bus.py b/test/unit/test_system_pci_scan_bus.py new file mode 100644 index 0000000000..2854089d9a --- /dev/null +++ b/test/unit/test_system_pci_scan_bus.py @@ -0,0 +1,28 @@ +import sys +import pytest + +@pytest.mark.skipif(sys.platform != "linux", reason="uses linux sysfs layout") +def test_pci_scan_bus_filters_vendor(monkeypatch): + import tinygrad.runtime.support.system as system + + fake = { + "/sys/bus/pci/devices/0000:00:01.0/vendor": "0x1234", + "/sys/bus/pci/devices/0000:00:01.0/device": "0x1111", + "/sys/bus/pci/devices/0000:00:02.0/vendor": "0xabcd", + "/sys/bus/pci/devices/0000:00:02.0/device": "0x1111", + } + + class FakeFileIOInterface: + def __init__(self, path, *args, **kwargs): + self.path = path + + def listdir(self): + assert self.path == "/sys/bus/pci/devices" + return ["0000:00:01.0", "0000:00:02.0"] + + def read(self, *args, **kwargs): + return fake[self.path] + + monkeypatch.setattr(system, "FileIOInterface", FakeFileIOInterface) + + assert system.System.pci_scan_bus(0x1234, devices=[(0xffff, [0x1111])]) == ["0000:00:01.0"] diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py index 4c9decdcfc..11fa8b7c72 100644 --- a/tinygrad/codegen/gpudims.py +++ b/tinygrad/codegen/gpudims.py @@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No elif (a:=len(limited)) > (b:=len(dims)): if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]] if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]] - if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]] - elif limited != dims: + if limited != dims: # Convert to 1D - flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2] + flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2] # Get back original indices from 1D return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]] return raw_idxs diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 845fbf1376..c62d3db391 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -209,7 +209,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li ubufs = tuple(b.buffer for b in buf_uops) if any(isinstance(x, MultiBuffer) for x in ubufs): assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" - dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num'] + dnums = [x for x in si.ast.variables() if x.expr == '_device_num'] for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}))) else: @@ -222,5 +222,5 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\ f" | {len(UOpMetaClass.ucache)} uops in cache") - used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule]) - return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars} + used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule]) + return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars} \ No newline at end of file diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 1778efa4fe..c417f817d1 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -86,7 +86,7 @@ class ProgramSpec: def function_name(self) -> str: return to_function_name(self.name) @functools.cached_property - def runtimevars(self) -> dict[str, int]: return {v.arg[0]: i for i, v in enumerate(self.vars) if v.arg[0] == 'core_id'} + def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'} @property def applied_opts(self) -> tuple[Opt, ...]|None: return self.ast.arg.applied_opts if self.ast.arg is not None else None diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index a2d1a1be82..8adafe0fbf 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -169,7 +169,7 @@ class LLVMRenderer(Renderer): if u.arg is not None: name = u.arg.function_name continue if u.op in (Ops.PARAM, Ops.DEFINE_VAR): - r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.arg[0]}" + r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}" args.append((r[u], u.dtype)) elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG): r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}" diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 91c519dac9..2b766a715a 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -134,7 +134,7 @@ string_rewrite = PatternMatcher([ (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"), (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))), (UPat(Ops.BARRIER), lambda ctx: ctx.barrier), - (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"), + (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.expr}+0];"), ]) class PTXRenderer(Renderer): @@ -220,7 +220,7 @@ class PTXRenderer(Renderer): continue if u.op is Ops.INDEX: continue # other index we can skip if u.op is Ops.SPECIAL: r[u] = "%" + u.arg - elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype)) + elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype)) elif u.op is Ops.LOAD: assert u.src[0].dtype == dtypes.int64, "load isn't int64" r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u) diff --git a/tinygrad/runtime/support/memory.py b/tinygrad/runtime/support/memory.py index 4c75b16719..a63f90c8d2 100644 --- a/tinygrad/runtime/support/memory.py +++ b/tinygrad/runtime/support/memory.py @@ -81,7 +81,7 @@ class TLSFAllocator: # Round up the allocation size to the next bucket, so any entry there can fit the requested size. size = round_up(size, (1 << size.bit_length() - self.l2_cnt)) - # Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found. + # Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found. for l1 in range(self.lv1(size), len(self.storage)): if self.lv1_entries[l1] == 0: continue for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)): @@ -105,7 +105,7 @@ class TLSFAllocator: def free(self, start:int): self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base) -# Memory Managment +# Memory Management class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.auto() # noqa: E702 @@ -221,7 +221,7 @@ class MemoryManager: @classmethod def alloc_vaddr(cls, size:int, align=0x1000) -> int: - assert cls.va_allocator is not None, "must be set it" + assert cls.va_allocator is not None, "must be set" return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align)) def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping: @@ -248,7 +248,7 @@ class MemoryManager: return self.map_range(va, size, paddrs, aspace=AddrSpace.PHYS, uncached=uncached) def vfree(self, vm:VirtMapping): - assert self.va_allocator is not None, "must be set it" + assert self.va_allocator is not None, "must be set" self.unmap_range(vm.va_addr, vm.size) self.va_allocator.free(vm.va_addr) for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr) diff --git a/tinygrad/runtime/support/nv/nvdev.py b/tinygrad/runtime/support/nv/nvdev.py index 9a2aceed88..e516edea60 100644 --- a/tinygrad/runtime/support/nv/nvdev.py +++ b/tinygrad/runtime/support/nv/nvdev.py @@ -77,7 +77,7 @@ class NVDev(PCIDevImplBase): self._early_ip_init() self._early_mmu_init() - # Turn the booting early, gsp client is loaded from the clean. + # No booting state, gsp client is reinited every run. self.is_booting = False for ip in [self.flcn, self.gsp]: ip.init_sw() diff --git a/tinygrad/runtime/support/system.py b/tinygrad/runtime/support/system.py index 04ad686788..af814d4fb2 100644 --- a/tinygrad/runtime/support/system.py +++ b/tinygrad/runtime/support/system.py @@ -7,7 +7,8 @@ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuff from tinygrad.runtime.support.memory import MemoryManager, VirtMapping, AddrSpace from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface -MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400 +MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000 +MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400 @dataclasses.dataclass(frozen=True) class PCIBarInfo: addr:int; size:int # noqa: E702 @@ -69,7 +70,7 @@ class _System: all_devs.append((int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16), int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16), pcibus)) - return sorted([val for vendor, device, val in all_devs if vendor == vendor and any((device & mask) in devlist for mask, devlist in devices)]) + return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)]) def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, PCIBarInfo]: for bus in range(gpu_bus): @@ -219,7 +220,7 @@ class LNXPCIIfaceBase: cls.gpus = hcq_filter_visible_devices(System.pci_scan_bus(vendor, devices, base_class)) # Acquire va range to avoid collisions. - FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0) + FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED_NOREPLACE, 0) self.pci_dev, self.dev, self.vram_bar = PCIDevice(dev.__class__.__name__[:2], cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar self.p2p_base_addr = self.pci_dev.bar_info[vram_bar].addr diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 7d3b3fada3..b1493dbc58 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,4 +1,3 @@ -from typing import cast import functools, itertools from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite @@ -29,7 +28,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0))) # reduce-scatter - reduced_chunks = [] + reduced_chunks:list[UOp] = [] for i,(s,e) in enumerate(chunks): if use_all2all: chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)] @@ -43,16 +42,15 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: reduced_chunks.append(reduced) # allgather - copied_chunks = [] + copied_chunks:list[UOp] = [] for i,rc in enumerate(reduced_chunks): if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg)) elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev)))) else: - this_chunk: list[UOp|None] = [None] * ndev - this_chunk[(i+ndev-1)%ndev] = rc + chain:list[UOp] = [rc] for step in range(ndev-1): - this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev]) - copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk)))) + chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev])) + copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev)))) # reassemble return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape) @@ -63,7 +61,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp): ret:list[UOp] = [] def apply_shrink(s:UOp, i:int) -> UOp: new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and - (dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.marg] + (dvar:=[v for v in x.variables() if v.expr=='_device_num']) else x for x in ss]) for ss in shrink.marg] return s.shrink(tuple(new_arg)) for i, x in enumerate(ms.src): if x.op is Ops.COPY: @@ -97,20 +95,20 @@ def alu_multi(root:UOp): axis = root.axis assert axis is not None - srcs = [] + srcs:list[UOp] = [] for mlb in msrcs: - if mlb.axis == axis: - # same axis, just copy through - assert mlb.op is Ops.MULTI - srcs.append(mlb.src[0]) - elif mlb.axis is None: + if mlb.axis is None: # no axis, shard it assert mlb.op is not Ops.MULTI srcs.append(mlb._shard(axis)) else: - # axis mismatch, unshard it, send it to all devices, and shard it correctly assert mlb.op is Ops.MULTI - srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis)) + if mlb.axis == axis: + # same axis, just copy through + srcs.append(mlb.src[0]) + else: + # axis mismatch, unshard it, send it to all devices, and shard it correctly + srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis)) return srcs[0].alu(root.op, *srcs[1:]).multi(axis) def reduce_multi(root:UOp, multi:UOp): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index b8047e8fe8..512b92e0c1 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -721,12 +721,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return graph_rewrite(self, pm_unbind, ctx=ret), ret @property def val(self) -> int: return self.unbind()[1] - def vars(self) -> set[UOp]: - topo = self.toposort() - bound = {x.src[0]: x for x in topo if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR} - return {bound.get(x, x) for x in topo if x.op is Ops.DEFINE_VAR} def variables(self) -> list[Variable]: - return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg) + return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg) # *** uop symbolic stuff *** @@ -820,7 +816,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): @functools.cached_property def _sym_fxn(self): sself = self.simplify() - varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR) + varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR) # TODO: sanitize varnames, or don't use naked eval while staying fast return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used @@ -1387,7 +1383,7 @@ def bitcast(x, in_dtype:DType, out_dtype:DType): return ret[0] if out_count == 1 else ret renderer = PatternMatcher([ - (UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.arg[0]), + (UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr), (UPat((Ops.SPECIAL), name="x"), lambda x: x.arg), (UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"), (UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),