mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Merge branch 'master' into prealloc_bufs
This commit is contained in:
@@ -47,24 +47,21 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
|
||||
# delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
|
||||
delta_vec = _sharded_empty((B, H, 1, N), xq, axis=0, dtype=dtypes.float32)
|
||||
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch))[:2]
|
||||
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:2]
|
||||
|
||||
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch))[:3]
|
||||
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[:3]
|
||||
|
||||
# unshuffle dq
|
||||
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch))[0]
|
||||
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D))[0]
|
||||
|
||||
return None, None, dq.uop, dk.uop, dv.uop
|
||||
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch), grad_fxn=grad)[:2]
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=functools.partial(custom_fa_forward, device=single_device, arch=arch, B=B_local, N=N, H=H, H_KV=H_KV, D=D), grad_fxn=grad)[:2]
|
||||
|
||||
return attn.transpose(1, 2)
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str):
|
||||
B, N, H, D = q.shape
|
||||
H_KV = k.shape[2]
|
||||
|
||||
def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_fwd_causal.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
|
||||
@@ -95,9 +92,7 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
|
||||
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str):
|
||||
B, N, H, D = o.shape
|
||||
|
||||
def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_bwd_pre.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]
|
||||
@@ -128,10 +123,7 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc
|
||||
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str):
|
||||
B, N, H, D = q.shape
|
||||
H_KV = k.shape[2]
|
||||
|
||||
def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_vec:UOp, delta_vec:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_bwd_causal.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}", f"-DATTN_H_KV={H_KV}"]
|
||||
@@ -162,9 +154,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
|
||||
src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=code), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
@functools.cache
|
||||
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str):
|
||||
B, N, H, D = dq_out.shape
|
||||
|
||||
def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int, N:int, H:int, H_KV:int, D:int):
|
||||
code = (pathlib.Path(__file__).parent / "fa_bwd_post.cpp").read_text()
|
||||
compile_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-DHIP_ENABLE_WARP_SYNC_BUILTINS", "-ffast-math",
|
||||
f"-DATTN_B={B}", f"-DATTN_N={N}", f"-DATTN_H={H}"]
|
||||
|
||||
@@ -3,8 +3,7 @@ import unittest
|
||||
from dataclasses import replace
|
||||
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
|
||||
@@ -253,100 +252,6 @@ class TestLinearizer(unittest.TestCase):
|
||||
if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
|
||||
assert end_range < uops.index(u)
|
||||
|
||||
def test_grouped_dims(self):
|
||||
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
|
||||
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
|
||||
sizes = [x.src[0].arg for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
if assert_same_length:
|
||||
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
|
||||
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
|
||||
# TODO: add these back after uop symbolic
|
||||
# for i in range(len(dims)):
|
||||
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
|
||||
# for i in range(len(loop_idxs)):
|
||||
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
|
||||
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
|
||||
|
||||
# no-op
|
||||
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
|
||||
|
||||
# check reverse dims
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
|
||||
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
|
||||
|
||||
# test splitting globals: len(dims) == len(max)
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
|
||||
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
|
||||
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
|
||||
|
||||
# prefer group_dim strategy when possible
|
||||
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
|
||||
|
||||
# test splitting globals: len(dims) < len(max)
|
||||
# len(dim) -> len(limited)
|
||||
# 1 -> 2
|
||||
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
|
||||
# 1 -> 3
|
||||
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
|
||||
# 2 -> 3
|
||||
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
|
||||
# 2 -> 2
|
||||
_assert_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
|
||||
# test when the only divisor is the square root of dim
|
||||
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
|
||||
|
||||
# collapse on onto the left most axis
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
|
||||
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
|
||||
|
||||
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (23,), (16,16,16), False,)
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
|
||||
|
||||
# too large for sizes
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
|
||||
|
||||
# TODO: In the above cases we only test if the shape after reshape is correct, never the indices.
|
||||
# We should check if the returned indices are correct, for all cases.
|
||||
# (65536, 2) -> (32768, 4)
|
||||
dims, expected_limited_dims = (65536,2), (32768, 4)
|
||||
idxs = get_grouped_dims("gidx", dims, (65535,65535,65535))
|
||||
def match_div(): raise RuntimeError("match_div")
|
||||
def match_mod(): raise RuntimeError("match_mod")
|
||||
flat_idx_pattern = UPat(Ops.SPECIAL, arg='gidx0')*expected_limited_dims[1]+UPat(Ops.SPECIAL, arg='gidx1')
|
||||
pm = PatternMatcher([
|
||||
(flat_idx_pattern//dims[1], match_div),
|
||||
(flat_idx_pattern%dims[1], match_mod)
|
||||
])
|
||||
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
graph_rewrite(idxs[0], pm)
|
||||
self.assertIn("match_div", str(error.exception))
|
||||
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
graph_rewrite(idxs[1], pm)
|
||||
self.assertIn("match_mod", str(error.exception))
|
||||
|
||||
# # variable too large
|
||||
# with self.assertRaises(AssertionError):
|
||||
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_default_global_reversed(self):
|
||||
# shrink so that the dims do not collapse
|
||||
|
||||
101
test/null/test_gpudims.py
Normal file
101
test/null/test_gpudims.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import unittest, math
|
||||
import z3
|
||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.uop.validate import uops_to_z3
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import flatten, dedup
|
||||
|
||||
class TestGroupedDims(unittest.TestCase):
|
||||
def _check_grouped_dims(self, prefix, dims, max_sizes, reverse, expected_sizes, assert_same_length=True):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse)
|
||||
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg)
|
||||
sizes = [x.src[0].arg for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
if assert_same_length:
|
||||
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
|
||||
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
|
||||
self._verify_indices_z3(idxs, dims)
|
||||
|
||||
def _verify_indices_z3(self, idxs, dims):
|
||||
"""Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat)."""
|
||||
total = math.prod(dims)
|
||||
specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
|
||||
# build flat index and primed flat (same expression with renamed SPECIALs)
|
||||
flat = UOp.const(dtypes.index, 0)
|
||||
for i, idx in enumerate(idxs):
|
||||
flat = flat + idx * int(math.prod(dims[i+1:]))
|
||||
flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
|
||||
solver = z3.Solver()
|
||||
[z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
|
||||
# bounds
|
||||
self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
|
||||
self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
|
||||
# injectivity: flat == flat' but inputs differ => unsat
|
||||
inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials])
|
||||
self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")
|
||||
|
||||
def test_grouped_dims(self):
|
||||
# no-op
|
||||
self._check_grouped_dims("gidx", (2,), (16,16,16), False, [2])
|
||||
self._check_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
|
||||
|
||||
# check reverse dims
|
||||
self._check_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
|
||||
self._check_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
|
||||
|
||||
# test splitting globals: len(dims) == len(max)
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
|
||||
self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
|
||||
self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
|
||||
self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])
|
||||
|
||||
# prefer group_dim strategy when possible
|
||||
self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
|
||||
|
||||
# test splitting globals: len(dims) < len(max)
|
||||
# len(dim) -> len(limited)
|
||||
# 1 -> 2
|
||||
self._check_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
|
||||
# 1 -> 3
|
||||
self._check_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
|
||||
# 2 -> 2
|
||||
self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
|
||||
# test when the only divisor is the square root of dim
|
||||
self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
|
||||
# 2 -> 3
|
||||
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
|
||||
|
||||
# collapse on onto the left most axis
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
|
||||
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (23,), (16,16,16), False,)
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
|
||||
|
||||
# too large for sizes
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
|
||||
|
||||
def test_grouped_direct_dims_are_special(self):
|
||||
# when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod)
|
||||
idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False)
|
||||
assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}"
|
||||
assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}"
|
||||
|
||||
def test_max_sizes_none(self):
|
||||
self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])
|
||||
self._check_grouped_dims("gidx", (100,), None, False, [100])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -839,34 +839,33 @@ class TestSymbolicNumeric(unittest.TestCase):
|
||||
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
|
||||
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
|
||||
|
||||
class TestSymbolicVars(unittest.TestCase):
|
||||
class TestSymbolicVariables(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
z = uconst(0)
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert z.vars() == z.vars() == set()
|
||||
print(a.vars())
|
||||
assert a.vars() == a.vars() == {a}
|
||||
assert z.variables() == []
|
||||
assert a.variables() == [a]
|
||||
m = a * 3
|
||||
assert m.vars() == {a}
|
||||
assert m.variables() == [a]
|
||||
s = usum([a, b, c])
|
||||
assert s.vars() == {a, b, c}
|
||||
assert s.variables() == [a, b, c]
|
||||
|
||||
def test_compound(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert (a + b * c).vars() == {a, b, c}
|
||||
assert (a % 3 + b // 5).vars() == {a, b}
|
||||
assert (a + b * c).variables() == [a, b, c]
|
||||
assert (a % 3 + b // 5).variables() == [a, b]
|
||||
# TODO: fix me
|
||||
with self.assertRaises(AssertionError):
|
||||
assert (a + b + c - a).vars() == {b, c}
|
||||
assert (a + b + c - a).variables() == [b, c]
|
||||
|
||||
def test_dedup(self):
|
||||
a = Variable("a", 0, 10)
|
||||
assert (a * a).vars() == {a}
|
||||
assert (a//4 + a//6).vars() == {a}
|
||||
assert (a * a).variables() == [a]
|
||||
assert (a//4 + a//6).variables() == [a]
|
||||
|
||||
class TestSymInfer(unittest.TestCase):
|
||||
def test_sym_infer(self):
|
||||
|
||||
@@ -92,5 +92,13 @@ class TestCall(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.grad.numpy(), gt_b_grad, rtol=1e-5)
|
||||
|
||||
def test_call_plus_sharded(self):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
a = Tensor.ones(10, 10).shard(devs, axis=0)
|
||||
b = Tensor.ones(10, 10).shard(devs, axis=0)
|
||||
Tensor.realize(a, b)
|
||||
c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
|
||||
np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
28
test/unit/test_system_pci_scan_bus.py
Normal file
28
test/unit/test_system_pci_scan_bus.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
@pytest.mark.skipif(sys.platform != "linux", reason="uses linux sysfs layout")
|
||||
def test_pci_scan_bus_filters_vendor(monkeypatch):
|
||||
import tinygrad.runtime.support.system as system
|
||||
|
||||
fake = {
|
||||
"/sys/bus/pci/devices/0000:00:01.0/vendor": "0x1234",
|
||||
"/sys/bus/pci/devices/0000:00:01.0/device": "0x1111",
|
||||
"/sys/bus/pci/devices/0000:00:02.0/vendor": "0xabcd",
|
||||
"/sys/bus/pci/devices/0000:00:02.0/device": "0x1111",
|
||||
}
|
||||
|
||||
class FakeFileIOInterface:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.path = path
|
||||
|
||||
def listdir(self):
|
||||
assert self.path == "/sys/bus/pci/devices"
|
||||
return ["0000:00:01.0", "0000:00:02.0"]
|
||||
|
||||
def read(self, *args, **kwargs):
|
||||
return fake[self.path]
|
||||
|
||||
monkeypatch.setattr(system, "FileIOInterface", FakeFileIOInterface)
|
||||
|
||||
assert system.System.pci_scan_bus(0x1234, devices=[(0xffff, [0x1111])]) == ["0000:00:01.0"]
|
||||
@@ -48,10 +48,9 @@ def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|No
|
||||
elif (a:=len(limited)) > (b:=len(dims)):
|
||||
if a == 2 and b == 1: return [raw_idxs[0] * limited[1] + raw_idxs[1]]
|
||||
if a == 3 and b == 1: return [(raw_idxs[0] * limited[1] + raw_idxs[1]) * limited[2] + raw_idxs[2]]
|
||||
if a == 3 and b == 2: return [raw_idxs[0] * limited[1] + raw_idxs[1], raw_idxs[2]]
|
||||
elif limited != dims:
|
||||
if limited != dims:
|
||||
# Convert to 1D
|
||||
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(dims) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
|
||||
flat = raw_idxs[0]*limited[1]+raw_idxs[1] if len(limited) == 2 else raw_idxs[0]*(limited[1]*limited[2])+raw_idxs[1]*limited[2]+raw_idxs[2]
|
||||
# Get back original indices from 1D
|
||||
return [flat//dims[1], flat%dims[1]] if len(dims) == 2 else [flat//(dims[2]*dims[1]), (flat//dims[2])%dims[1], flat%dims[2]]
|
||||
return raw_idxs
|
||||
|
||||
@@ -209,7 +209,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
ubufs = tuple(b.buffer for b in buf_uops)
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
|
||||
dnums = [x for x in si.ast.variables() if x.expr == '_device_num']
|
||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
|
||||
else:
|
||||
@@ -222,5 +222,5 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
|
||||
used_vars = set().union(*[{v.arg[0] for v in si.ast.variables()} for si in schedule])
|
||||
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
|
||||
return buffer_map, schedule, {k:v for k,v in var_vals.items() if k in used_vars}
|
||||
@@ -86,7 +86,7 @@ class ProgramSpec:
|
||||
def function_name(self) -> str: return to_function_name(self.name)
|
||||
|
||||
@functools.cached_property
|
||||
def runtimevars(self) -> dict[str, int]: return {v.arg[0]: i for i, v in enumerate(self.vars) if v.arg[0] == 'core_id'}
|
||||
def runtimevars(self) -> dict[str, int]: return {v.expr: i for i, v in enumerate(self.vars) if v.expr == 'core_id'}
|
||||
|
||||
@property
|
||||
def applied_opts(self) -> tuple[Opt, ...]|None: return self.ast.arg.applied_opts if self.ast.arg is not None else None
|
||||
|
||||
@@ -169,7 +169,7 @@ class LLVMRenderer(Renderer):
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.arg[0]}"
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.PARAM else f"%{u.expr}"
|
||||
args.append((r[u], u.dtype))
|
||||
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
|
||||
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
|
||||
|
||||
@@ -134,7 +134,7 @@ string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
|
||||
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
|
||||
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.expr}+0];"),
|
||||
])
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
@@ -220,7 +220,7 @@ class PTXRenderer(Renderer):
|
||||
continue
|
||||
if u.op is Ops.INDEX: continue # other index we can skip
|
||||
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
|
||||
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
|
||||
elif u.op is Ops.LOAD:
|
||||
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
|
||||
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
|
||||
|
||||
@@ -81,7 +81,7 @@ class TLSFAllocator:
|
||||
# Round up the allocation size to the next bucket, so any entry there can fit the requested size.
|
||||
size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
|
||||
|
||||
# Search for the smallest block that can fit the requested size. Start with the it's bucket and go up until any block is found.
|
||||
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
|
||||
for l1 in range(self.lv1(size), len(self.storage)):
|
||||
if self.lv1_entries[l1] == 0: continue
|
||||
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
|
||||
@@ -105,7 +105,7 @@ class TLSFAllocator:
|
||||
def free(self, start:int):
|
||||
self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
|
||||
|
||||
# Memory Managment
|
||||
# Memory Management
|
||||
|
||||
class AddrSpace(enum.Enum): PHYS = enum.auto(); SYS = enum.auto(); PEER = enum.auto() # noqa: E702
|
||||
|
||||
@@ -221,7 +221,7 @@ class MemoryManager:
|
||||
|
||||
@classmethod
|
||||
def alloc_vaddr(cls, size:int, align=0x1000) -> int:
|
||||
assert cls.va_allocator is not None, "must be set it"
|
||||
assert cls.va_allocator is not None, "must be set"
|
||||
return cls.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
|
||||
|
||||
def valloc(self, size:int, align=0x1000, uncached=False, contiguous=False) -> VirtMapping:
|
||||
@@ -248,7 +248,7 @@ class MemoryManager:
|
||||
return self.map_range(va, size, paddrs, aspace=AddrSpace.PHYS, uncached=uncached)
|
||||
|
||||
def vfree(self, vm:VirtMapping):
|
||||
assert self.va_allocator is not None, "must be set it"
|
||||
assert self.va_allocator is not None, "must be set"
|
||||
self.unmap_range(vm.va_addr, vm.size)
|
||||
self.va_allocator.free(vm.va_addr)
|
||||
for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
|
||||
|
||||
@@ -77,7 +77,7 @@ class NVDev(PCIDevImplBase):
|
||||
self._early_ip_init()
|
||||
self._early_mmu_init()
|
||||
|
||||
# Turn the booting early, gsp client is loaded from the clean.
|
||||
# No booting state, gsp client is reinited every run.
|
||||
self.is_booting = False
|
||||
|
||||
for ip in [self.flcn, self.gsp]: ip.init_sw()
|
||||
|
||||
@@ -7,7 +7,8 @@ from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface, HCQBuff
|
||||
from tinygrad.runtime.support.memory import MemoryManager, VirtMapping, AddrSpace
|
||||
from tinygrad.runtime.support.usb import ASM24Controller, USBMMIOInterface
|
||||
|
||||
MAP_FIXED, MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0x10, 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
|
||||
MAP_FIXED, MAP_FIXED_NOREPLACE = 0x10, 0x100000
|
||||
MAP_LOCKED, MAP_POPULATE, MAP_NORESERVE = 0 if OSX else 0x2000, getattr(mmap, "MAP_POPULATE", 0 if OSX else 0x008000), 0x400
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PCIBarInfo: addr:int; size:int # noqa: E702
|
||||
@@ -69,7 +70,7 @@ class _System:
|
||||
all_devs.append((int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/vendor").read(), 16),
|
||||
int(FileIOInterface(f"/sys/bus/pci/devices/{pcibus}/device").read(), 16), pcibus))
|
||||
|
||||
return sorted([val for vendor, device, val in all_devs if vendor == vendor and any((device & mask) in devlist for mask, devlist in devices)])
|
||||
return sorted([val for vndr, device, val in all_devs if vndr == vendor and any((device & mask) in devlist for mask, devlist in devices)])
|
||||
|
||||
def pci_setup_usb_bars(self, usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, PCIBarInfo]:
|
||||
for bus in range(gpu_bus):
|
||||
@@ -219,7 +220,7 @@ class LNXPCIIfaceBase:
|
||||
cls.gpus = hcq_filter_visible_devices(System.pci_scan_bus(vendor, devices, base_class))
|
||||
|
||||
# Acquire va range to avoid collisions.
|
||||
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED, 0)
|
||||
FileIOInterface.anon_mmap(va_start, va_size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE | MAP_FIXED_NOREPLACE, 0)
|
||||
self.pci_dev, self.dev, self.vram_bar = PCIDevice(dev.__class__.__name__[:2], cls.gpus[dev_id], bars=bars, resize_bars=[vram_bar]), dev, vram_bar
|
||||
self.p2p_base_addr = self.pci_dev.bar_info[vram_bar].addr
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import cast
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
@@ -29,7 +28,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
|
||||
|
||||
# reduce-scatter
|
||||
reduced_chunks = []
|
||||
reduced_chunks:list[UOp] = []
|
||||
for i,(s,e) in enumerate(chunks):
|
||||
if use_all2all:
|
||||
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
|
||||
@@ -43,16 +42,15 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
reduced_chunks.append(reduced)
|
||||
|
||||
# allgather
|
||||
copied_chunks = []
|
||||
copied_chunks:list[UOp] = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
|
||||
else:
|
||||
this_chunk: list[UOp|None] = [None] * ndev
|
||||
this_chunk[(i+ndev-1)%ndev] = rc
|
||||
chain:list[UOp] = [rc]
|
||||
for step in range(ndev-1):
|
||||
this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev])
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
|
||||
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
|
||||
|
||||
# reassemble
|
||||
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
|
||||
@@ -63,7 +61,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
|
||||
ret:list[UOp] = []
|
||||
def apply_shrink(s:UOp, i:int) -> UOp:
|
||||
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
|
||||
(dvar:=[v for v in x.vars() if v.op is Ops.DEFINE_VAR and v.arg[0]=='_device_num']) else x for x in ss]) for ss in shrink.marg]
|
||||
(dvar:=[v for v in x.variables() if v.expr=='_device_num']) else x for x in ss]) for ss in shrink.marg]
|
||||
return s.shrink(tuple(new_arg))
|
||||
for i, x in enumerate(ms.src):
|
||||
if x.op is Ops.COPY:
|
||||
@@ -97,20 +95,20 @@ def alu_multi(root:UOp):
|
||||
axis = root.axis
|
||||
assert axis is not None
|
||||
|
||||
srcs = []
|
||||
srcs:list[UOp] = []
|
||||
for mlb in msrcs:
|
||||
if mlb.axis == axis:
|
||||
# same axis, just copy through
|
||||
assert mlb.op is Ops.MULTI
|
||||
srcs.append(mlb.src[0])
|
||||
elif mlb.axis is None:
|
||||
if mlb.axis is None:
|
||||
# no axis, shard it
|
||||
assert mlb.op is not Ops.MULTI
|
||||
srcs.append(mlb._shard(axis))
|
||||
else:
|
||||
# axis mismatch, unshard it, send it to all devices, and shard it correctly
|
||||
assert mlb.op is Ops.MULTI
|
||||
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
|
||||
if mlb.axis == axis:
|
||||
# same axis, just copy through
|
||||
srcs.append(mlb.src[0])
|
||||
else:
|
||||
# axis mismatch, unshard it, send it to all devices, and shard it correctly
|
||||
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
|
||||
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
|
||||
|
||||
def reduce_multi(root:UOp, multi:UOp):
|
||||
|
||||
@@ -721,12 +721,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return graph_rewrite(self, pm_unbind, ctx=ret), ret
|
||||
@property
|
||||
def val(self) -> int: return self.unbind()[1]
|
||||
def vars(self) -> set[UOp]:
|
||||
topo = self.toposort()
|
||||
bound = {x.src[0]: x for x in topo if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR}
|
||||
return {bound.get(x, x) for x in topo if x.op is Ops.DEFINE_VAR}
|
||||
def variables(self) -> list[Variable]:
|
||||
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR}, key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
||||
@@ -820,7 +816,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
|
||||
def _sym_fxn(self):
|
||||
sself = self.simplify()
|
||||
varnames = tuple(x.arg[0] for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
|
||||
varnames = tuple(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR)
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
return eval("lambda "+','.join(varnames)+": "+sself.render(pm=renderer_infer)), varnames # pylint: disable=eval-used
|
||||
|
||||
@@ -1387,7 +1383,7 @@ def bitcast(x, in_dtype:DType, out_dtype:DType):
|
||||
return ret[0] if out_count == 1 else ret
|
||||
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.arg[0]),
|
||||
(UPat((Ops.DEFINE_VAR,), name="x"), lambda x: x.expr),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
|
||||
|
||||
Reference in New Issue
Block a user