From 8023748e285c88c021cbb826863a0be3f441bfdd Mon Sep 17 00:00:00 2001 From: qazal Date: Fri, 19 Apr 2024 10:32:33 +0300 Subject: [PATCH] fuzz unique toposorts --- test/external/fuzz_schedule.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index 0d469df4dc..26a50102af 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -12,30 +12,34 @@ from tinygrad.tensor import Tensor ctx_vars = { MULTIOUTPUT: (0, 1) } def fuzz_schedule(outs: List[LazyBuffer]): - toposorts: List[Tuple[Dict, List[LazyBuffer], Dict[LazyBuffer, _LBScheduleItem]]] = [] + toposorts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {} for combination in itertools.product(*ctx_vars.values()): for var, val in zip(ctx_vars, combination): var.value = val graph, in_degree, prescheduled = _graph_schedule(outs, set()) - for ts in find_all_toposorts(graph, in_degree): toposorts.append((dict(zip([v.key for v in ctx_vars], combination)), ts, prescheduled)) + for ts in find_all_toposorts(graph, in_degree): + if ts not in toposorts: toposorts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled) if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow")) # setup ground truth ground_truth: Dict[LazyBuffer, memoryview] = {} # IMPORTANT: freeze prerealized bufs before ScheduleItem exec prerealized: Dict[LazyBuffer, memoryview] = {} - seed = Tensor._seed - for key in toposorts[0][1]: - for out in (ps:=toposorts[0][2][key]).outputs: + seed, first = Tensor._seed, list(toposorts.items())[0] + for key in first[0]: + for out in (ps:=first[1][1][key]).outputs: # freeze assign state before exec if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer() for x in ps.inputs: if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer() si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)) _exec_si(si, seed) - for out in ps.outputs: ground_truth[out] = out.buffer.as_buffer() + for out in ps.outputs: + ground_truth[out] = out.buffer.as_buffer() + del out.srcs # only schedule the LazyBuffer in this fuzz run # exec and validate each permutation with new Buffers - for i, (ctx, ts, prescheduled) in enumerate(toposorts[1:]): + for i, (ts, (ctx, prescheduled)) in enumerate(toposorts.items()): + if i == 0: continue if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow")) rawbufs: Dict[LazyBuffer, Buffer] = {} for key in ts: @@ -50,7 +54,6 @@ def fuzz_schedule(outs: List[LazyBuffer]): si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs if x.size != 0), tuple(rawbufs[x] for x in ps.inputs if x.size != 0)) _exec_si(si, seed) for out in ps.outputs: - if hasattr(out, "srcs"): del out.srcs # can only schedule once outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np) try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], out.dtype.np), atol=1e-2, rtol=1e-2) except Exception as e: @@ -64,9 +67,9 @@ def _exec_si(si: ScheduleItem, seed:int): ei.run() T = TypeVar("T") -def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[List[T]]: +def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]: visited: Set[T] = set() - ret: List[List[T]] = [] + ret: List[Tuple[T, ...]] = [] path: List[T] = [] def recurse_paths(path:List[T]): @@ -81,10 +84,10 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i for u in graph[v]: in_degree[u] += 1 path.pop() visited.remove(v) - if len(path) == len(in_degree): ret.append([*path]) + if len(path) == len(in_degree): ret.append(tuple([*path])) recurse_paths(path) if len(ret) == 0: raise RuntimeError("detected cycle in the graph") # verify all paths are unique - assert len(ret) == len(set(map(tuple, ret))) + assert len(ret) == len(set(ret)) return ret