fuzz unique toposorts

This commit is contained in:
qazal
2024-04-19 10:32:33 +03:00
parent c47c5ae464
commit 8023748e28

View File

@@ -12,30 +12,34 @@ from tinygrad.tensor import Tensor
ctx_vars = { MULTIOUTPUT: (0, 1) }
def fuzz_schedule(outs: List[LazyBuffer]):
toposorts: List[Tuple[Dict, List[LazyBuffer], Dict[LazyBuffer, _LBScheduleItem]]] = []
toposorts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {}
for combination in itertools.product(*ctx_vars.values()):
for var, val in zip(ctx_vars, combination): var.value = val
graph, in_degree, prescheduled = _graph_schedule(outs, set())
for ts in find_all_toposorts(graph, in_degree): toposorts.append((dict(zip([v.key for v in ctx_vars], combination)), ts, prescheduled))
for ts in find_all_toposorts(graph, in_degree):
if ts not in toposorts: toposorts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled)
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
# setup ground truth
ground_truth: Dict[LazyBuffer, memoryview] = {}
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
prerealized: Dict[LazyBuffer, memoryview] = {}
seed = Tensor._seed
for key in toposorts[0][1]:
for out in (ps:=toposorts[0][2][key]).outputs:
seed, first = Tensor._seed, list(toposorts.items())[0]
for key in first[0]:
for out in (ps:=first[1][1][key]).outputs:
# freeze assign state before exec
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
for x in ps.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs: ground_truth[out] = out.buffer.as_buffer()
for out in ps.outputs:
ground_truth[out] = out.buffer.as_buffer()
del out.srcs # only schedule the LazyBuffer in this fuzz run
# exec and validate each permutation with new Buffers
for i, (ctx, ts, prescheduled) in enumerate(toposorts[1:]):
for i, (ts, (ctx, prescheduled)) in enumerate(toposorts.items()):
if i == 0: continue
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
rawbufs: Dict[LazyBuffer, Buffer] = {}
for key in ts:
@@ -50,7 +54,6 @@ def fuzz_schedule(outs: List[LazyBuffer]):
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs if x.size != 0), tuple(rawbufs[x] for x in ps.inputs if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
if hasattr(out, "srcs"): del out.srcs # can only schedule once
outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np)
try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], out.dtype.np), atol=1e-2, rtol=1e-2)
except Exception as e:
@@ -64,9 +67,9 @@ def _exec_si(si: ScheduleItem, seed:int):
ei.run()
T = TypeVar("T")
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[List[T]]:
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
visited: Set[T] = set()
ret: List[List[T]] = []
ret: List[Tuple[T, ...]] = []
path: List[T] = []
def recurse_paths(path:List[T]):
@@ -81,10 +84,10 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i
for u in graph[v]: in_degree[u] += 1
path.pop()
visited.remove(v)
if len(path) == len(in_degree): ret.append([*path])
if len(path) == len(in_degree): ret.append(tuple([*path]))
recurse_paths(path)
if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
# verify all paths are unique
assert len(ret) == len(set(map(tuple, ret)))
assert len(ret) == len(set(ret))
return ret