Merge branch 'master' into prealloc_bufs

This commit is contained in:
George Hotz
2026-02-20 09:11:44 +08:00
committed by GitHub
16 changed files with 191 additions and 166 deletions

View File

@@ -47,24 +47,21 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
# delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32)
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch))[:2]
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2]
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch))[:3]
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3]
# unshuffle dq
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch))[0]
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0]
return None, None, dq.uop, dk.uop, dv.uop
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch), grad_fxn=grad)[:2]
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2]
return attn.transpose(1, 2)
@functools.cache
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str):
B, N, H, D = q.shape
H_KV = k.shape[2]
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_fwd_causal.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
@@ -95,9 +92,7 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str):
B, N, H, D = o.shape
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_bwd_pre.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]
@@ -128,10 +123,7 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str):
B, N, H, D = q.shape
H_KV = k.shape[2]
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_bwd_causal.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
@@ -162,9 +154,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
@functools.cache
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str):
B, N, H, D = dq_out.shape
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
code = (pathlib.Path(__file__).parent / "fa_bwd_post.cpp").read_text()
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]

View File

@@ -3,8 +3,7 @@ import unittest
from dataclasses import replace
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
@@ -253,100 +252,6 @@ class TestLinearizer(unittest.TestCase):
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
assert end_range < uops.index(u)
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
sizes = [x.src[0].arg for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
if assert_same_length:
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
# TODO: add these back after uop symbolic
# for i in range(len(dims)):
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
# for i in range(len(loop_idxs)):
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# no-op
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
# check reverse dims
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals: len(dims) == len(max)
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
# prefer group_dim strategy when possible
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# len(dim) -> len(limited)
# 1 -> 2
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 3
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# 2 -> 2
_assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
# test when the only divisor is the square root of dim
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# collapse on onto the left most axis
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
# collapse on left-most available axis (the left most is too small)
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
# TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
# We should check if the returned indices are correct, for all cases.
# (65536, 2) -> (32768, 4)
dims, expected_limited_dims = (65536,2), (32768, 4)
idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
def match_div(): raise RuntimeError("match_div")
def match_mod(): raise RuntimeError("match_mod")
flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
pm = PatternMatcher([
(flat_idx_pattern//dims[1], match_div),
(flat_idx_pattern%dims[1], match_mod)
])
with self.assertRaises(RuntimeError) as error:
graph_rewrite(idxs[0], pm)
self.assertIn("match_div", str(error.exception))
with self.assertRaises(RuntimeError) as error:
graph_rewrite(idxs[1], pm)
self.assertIn("match_mod", str(error.exception))
# # variable too large
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
def test_default_global_reversed(self):
# shrink so that the dims do not collapse

101
test/null/test_gpudims.py Normal file
View File

@@ -0,0 +1,101 @@
import unittest, math
import z3
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.validate import uops_to_z3
from tinygrad.dtype import dtypes
from tinygrad.helpers import flatten, dedup
class TestGroupedDims(unittest.TestCase):
def _check_grouped_dims(self, prefix, dims, max_sizes, reverse, expected_sizes, assert_same_length=True):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse)
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
sizes = [x.src[0].arg for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
if assert_same_length:
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
self._verify_indices_z3(idxs, dims)
def _verify_indices_z3(self, idxs, dims):
"""Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat)."""
total = math.prod(dims)
specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
# build flat index and primed flat (same expression with renamed SPECIALs)
flat = UOp.const(dtypes.index, 0)
for i, idx in enumerate(idxs):
flat = flat + idx * int(math.prod(dims[i+1:]))
flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
solver = z3.Solver()
[z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
# bounds
self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
# injectivity: flat == flat' but inputs differ => unsat
inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials])
self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")
def test_grouped_dims(self):
# no-op
self._check_grouped_dims("gidx", (2,), (16,16,16), False, [2])
self._check_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
# check reverse dims
self._check_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
self._check_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
# test splitting globals: len(dims) == len(max)
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
self._check_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])
# prefer group_dim strategy when possible
self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
# test splitting globals: len(dims) < len(max)
# len(dim) -> len(limited)
# 1 -> 2
self._check_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
# 1 -> 3
self._check_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
# 2 -> 2
self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
# test when the only divisor is the square root of dim
self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# 2 -> 3
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# collapse on onto the left most axis
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
self._check_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
# collapse on left-most available axis (the left most is too small)
self._check_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
# dim too large and not factorable
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (23,), (16,16,16), False,)
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
# too large for sizes
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
def test_grouped_direct_dims_are_special(self):
# when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod)
idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False)
assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}"
assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}"
def test_max_sizes_none(self):
self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])
self._check_grouped_dims("gidx", (100,), None, False, [100])
if __name__ == '__main__':
unittest.main()

View File

@@ -839,34 +839,33 @@ class TestSymbolicNumeric(unittest.TestCase):
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
class TestSymbolicVars(unittest.TestCase):
class TestSymbolicVariables(unittest.TestCase):
def test_simple(self):
z = uconst(0)
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert z.vars() == z.vars() == set()
print(a.vars())
assert a.vars() == a.vars() == {a}
assert z.variables() == []
assert a.variables() == [a]
m = a * 3
assert m.vars() == {a}
assert m.variables() == [a]
s = usum([a, b, c])
assert s.vars() == {a, b, c}
assert s.variables() == [a, b, c]
def test_compound(self):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
assert (a + b * c).vars() == {a, b, c}
assert (a % 3 + b // 5).vars() == {a, b}
assert (a + b * c).variables() == [a, b, c]
assert (a % 3 + b // 5).variables() == [a, b]
# TODO: fix me
with self.assertRaises(AssertionError):
assert (a + b + c - a).vars() == {b, c}
assert (a + b + c - a).variables() == [b, c]
def test_dedup(self):
a = Variable("a", 0, 10)
assert (a * a).vars() == {a}
assert (a//4 + a//6).vars() == {a}
assert (a * a).variables() == [a]
assert (a//4 + a//6).variables() == [a]
class TestSymInfer(unittest.TestCase):
def test_sym_infer(self):

View File

@@ -92,5 +92,13 @@ class TestCall(unittest.TestCase):
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
def test_call_plus_sharded(self):
devs = ("CPU:0", "CPU:1")
a = Tensor.ones(10, 10).shard(devs, axis=0)
b = Tensor.ones(10, 10).shard(devs, axis=0)
Tensor.realize(a, b)
c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,28 @@
import sys
import pytest
@pytest.mark.skipif(sys.platform != "linux", reason="uses linux sysfs layout")
def test_pci_scan_bus_filters_vendor(monkeypatch):
import tinygrad.runtime.support.system as system
fake = {
"/sys/bus/pci/devices/0000:00:01.0/vendor": "0x1234",
"/sys/bus/pci/devices/0000:00:01.0/device": "0x1111",
"/sys/bus/pci/devices/0000:00:02.0/vendor": "0xabcd",
"/sys/bus/pci/devices/0000:00:02.0/device": "0x1111",
}
class FakeFileIOInterface:
def __init__(self, path, *args, **kwargs):
self.path = path
def listdir(self):
assert self.path == "/sys/bus/pci/devices"
return ["0000:00:01.0", "0000:00:02.0"]
def read(self, *args, **kwargs):
return fake[self.path]
monkeypatch.setattr(system, "FileIOInterface", FakeFileIOInterface)
assert system.System.pci_scan_bus(0x1234, devices=[(0xffff, [0x1111])]) == ["0000:00:01.0"]

View File

@@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
elif (a:=len(limited)) > (b:=len(dims)):
if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]]
if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]]
if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
elif limited != dims:
if limited != dims:
# Convert to 1D
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
# Get back original indices from 1D
return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
return raw_idxs

View File

@@ -209,7 +209,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
ubufs = tuple(b.buffer for b in buf_uops)
if any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
dnums = [x for x in si.ast.variables() if x.expr == '_device_num']
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
else:
@@ -222,5 +222,5 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}

View File

@@ -86,7 +86,7 @@ class ProgramSpec:
def function_name(self) -> str: return to_function_name(self.name)
@functools.cached_property
def runtimevars(self) -> dict[str, int]: return {v.arg[0]: i for i, v in enumerate(self.vars) if v.arg[0] == 'core_id'}
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
@property
def applied_opts(self) -> tuple[Opt, ...]|None: return self.ast.arg.applied_opts if self.ast.arg is not None else None

View File

@@ -169,7 +169,7 @@ class LLVMRenderer(Renderer):
if u.arg is not None: name = u.arg.function_name
continue
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.arg[0]}"
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}"
args.append((r[u], u.dtype))
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"

View File

@@ -134,7 +134,7 @@ string_rewrite = PatternMatcher([
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.expr}+0];"),
])
class PTXRenderer(Renderer):
@@ -220,7 +220,7 @@ class PTXRenderer(Renderer):
continue
if u.op is Ops.INDEX: continue # other index we can skip
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
elif u.op is Ops.LOAD:
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)

View File

@@ -81,7 +81,7 @@ class TLSFAllocator:
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
# Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
for l1 in range(self.lv1(size), len(self.storage)):
if self.lv1_entries[l1] == 0: continue
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
@@ -105,7 +105,7 @@ class TLSFAllocator:
def free(self, start:int):
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
# Memory Managment
# Memory Management
class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.auto() # noqa: E702
@@ -221,7 +221,7 @@ class MemoryManager:
@classmethod
def alloc_vaddr(cls, size:int, align=0x1000) -> int:
assert cls.va_allocator is not None, "must be set it"
assert cls.va_allocator is not None, "must be set"
return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
@@ -248,7 +248,7 @@ class MemoryManager:
return self.map_range(va, size, paddrs, aspace=AddrSpace.PHYS, uncached=uncached)
def vfree(self, vm:VirtMapping):
assert self.va_allocator is not None, "must be set it"
assert self.va_allocator is not None, "must be set"
self.unmap_range(vm.va_addr, vm.size)
self.va_allocator.free(vm.va_addr)
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)

View File

@@ -77,7 +77,7 @@ class NVDev(PCIDevImplBase):
self._early_ip_init()
self._early_mmu_init()
# Turn the booting early, gsp client is loaded from the clean.
# No booting state, gsp client is reinited every run.
self.is_booting = False
for ip in [self.flcn, self.gsp]: ip.init_sw()

View File

@@ -7,7 +7,8 @@ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuff
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping, AddrSpace
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000
MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
@dataclasses.dataclass(frozen=True)
class PCIBarInfo: addr:int; size:int # noqa: E702
@@ -69,7 +70,7 @@ class _System:
all_devs.append((int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16),
int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16), pcibus))
return sorted([val for vendor, device, val in all_devs if vendor == vendor and any((device & mask) in devlist for mask, devlist in devices)])
return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)])
def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, PCIBarInfo]:
for bus in range(gpu_bus):
@@ -219,7 +220,7 @@ class LNXPCIIfaceBase:
cls.gpus = hcq_filter_visible_devices(System.pci_scan_bus(vendor, devices, base_class))
# Acquire va range to avoid collisions.
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED_NOREPLACE, 0)
self.pci_dev, self.dev, self.vram_bar = PCIDevice(dev.__class__.__name__[:2], cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
self.p2p_base_addr = self.pci_dev.bar_info[vram_bar].addr

View File

@@ -1,4 +1,3 @@
from typing import cast
import functools, itertools
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
@@ -29,7 +28,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
# reduce-scatter
reduced_chunks = []
reduced_chunks:list[UOp] = []
for i,(s,e) in enumerate(chunks):
if use_all2all:
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
@@ -43,16 +42,15 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
reduced_chunks.append(reduced)
# allgather
copied_chunks = []
copied_chunks:list[UOp] = []
for i,rc in enumerate(reduced_chunks):
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
else:
this_chunk: list[UOp|None] = [None] * ndev
this_chunk[(i+ndev-1)%ndev] = rc
chain:list[UOp] = [rc]
for step in range(ndev-1):
this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
# reassemble
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
@@ -63,7 +61,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
ret:list[UOp] = []
def apply_shrink(s:UOp, i:int) -> UOp:
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
(dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.marg]
(dvar:=[v for v in x.variables() if v.expr=='_device_num']) else x for x in ss]) for ss in shrink.marg]
return s.shrink(tuple(new_arg))
for i, x in enumerate(ms.src):
if x.op is Ops.COPY:
@@ -97,20 +95,20 @@ def alu_multi(root:UOp):
axis = root.axis
assert axis is not None
srcs = []
srcs:list[UOp] = []
for mlb in msrcs:
if mlb.axis == axis:
# same axis, just copy through
assert mlb.op is Ops.MULTI
srcs.append(mlb.src[0])
elif mlb.axis is None:
if mlb.axis is None:
# no axis, shard it
assert mlb.op is not Ops.MULTI
srcs.append(mlb._shard(axis))
else:
# axis mismatch, unshard it, send it to all devices, and shard it correctly
assert mlb.op is Ops.MULTI
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
if mlb.axis == axis:
# same axis, just copy through
srcs.append(mlb.src[0])
else:
# axis mismatch, unshard it, send it to all devices, and shard it correctly
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
def reduce_multi(root:UOp, multi:UOp):

View File

@@ -721,12 +721,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return graph_rewrite(self, pm_unbind, ctx=ret), ret
@property
def val(self) -> int: return self.unbind()[1]
def vars(self) -> set[UOp]:
topo = self.toposort()
bound = {x.src[0]: x for x in topo if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR}
return {bound.get(x, x) for x in topo if x.op is Ops.DEFINE_VAR}
def variables(self) -> list[Variable]:
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg)
# *** uop symbolic stuff ***
@@ -820,7 +816,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@functools.cached_property
def _sym_fxn(self):
sself = self.simplify()
varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used
@@ -1387,7 +1383,7 @@ def bitcast(x, in_dtype:DType, out_dtype:DType):
return ret[0] if out_count == 1 else ret
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.arg[0]),
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),