mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
fuzz unique toposorts
This commit is contained in:
27
test/external/fuzz_schedule.py
vendored
27
test/external/fuzz_schedule.py
vendored
@@ -12,30 +12,34 @@ from tinygrad.tensor import Tensor
|
||||
ctx_vars = { MULTIOUTPUT: (0, 1) }
|
||||
|
||||
def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
toposorts: List[Tuple[Dict, List[LazyBuffer], Dict[LazyBuffer, _LBScheduleItem]]] = []
|
||||
toposorts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {}
|
||||
for combination in itertools.product(*ctx_vars.values()):
|
||||
for var, val in zip(ctx_vars, combination): var.value = val
|
||||
graph, in_degree, prescheduled = _graph_schedule(outs, set())
|
||||
for ts in find_all_toposorts(graph, in_degree): toposorts.append((dict(zip([v.key for v in ctx_vars], combination)), ts, prescheduled))
|
||||
for ts in find_all_toposorts(graph, in_degree):
|
||||
if ts not in toposorts: toposorts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled)
|
||||
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
|
||||
|
||||
# setup ground truth
|
||||
ground_truth: Dict[LazyBuffer, memoryview] = {}
|
||||
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
|
||||
prerealized: Dict[LazyBuffer, memoryview] = {}
|
||||
seed = Tensor._seed
|
||||
for key in toposorts[0][1]:
|
||||
for out in (ps:=toposorts[0][2][key]).outputs:
|
||||
seed, first = Tensor._seed, list(toposorts.items())[0]
|
||||
for key in first[0]:
|
||||
for out in (ps:=first[1][1][key]).outputs:
|
||||
# freeze assign state before exec
|
||||
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
|
||||
for x in ps.inputs:
|
||||
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
|
||||
si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0))
|
||||
_exec_si(si, seed)
|
||||
for out in ps.outputs: ground_truth[out] = out.buffer.as_buffer()
|
||||
for out in ps.outputs:
|
||||
ground_truth[out] = out.buffer.as_buffer()
|
||||
del out.srcs # only schedule the LazyBuffer in this fuzz run
|
||||
|
||||
# exec and validate each permutation with new Buffers
|
||||
for i, (ctx, ts, prescheduled) in enumerate(toposorts[1:]):
|
||||
for i, (ts, (ctx, prescheduled)) in enumerate(toposorts.items()):
|
||||
if i == 0: continue
|
||||
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
|
||||
rawbufs: Dict[LazyBuffer, Buffer] = {}
|
||||
for key in ts:
|
||||
@@ -50,7 +54,6 @@ def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs if x.size != 0), tuple(rawbufs[x] for x in ps.inputs if x.size != 0))
|
||||
_exec_si(si, seed)
|
||||
for out in ps.outputs:
|
||||
if hasattr(out, "srcs"): del out.srcs # can only schedule once
|
||||
outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np)
|
||||
try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], out.dtype.np), atol=1e-2, rtol=1e-2)
|
||||
except Exception as e:
|
||||
@@ -64,9 +67,9 @@ def _exec_si(si: ScheduleItem, seed:int):
|
||||
ei.run()
|
||||
|
||||
T = TypeVar("T")
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[List[T]]:
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[Tuple[T, ...]]:
|
||||
visited: Set[T] = set()
|
||||
ret: List[List[T]] = []
|
||||
ret: List[Tuple[T, ...]] = []
|
||||
path: List[T] = []
|
||||
|
||||
def recurse_paths(path:List[T]):
|
||||
@@ -81,10 +84,10 @@ def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, i
|
||||
for u in graph[v]: in_degree[u] += 1
|
||||
path.pop()
|
||||
visited.remove(v)
|
||||
if len(path) == len(in_degree): ret.append([*path])
|
||||
if len(path) == len(in_degree): ret.append(tuple([*path]))
|
||||
recurse_paths(path)
|
||||
|
||||
if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
|
||||
# verify all paths are unique
|
||||
assert len(ret) == len(set(map(tuple, ret)))
|
||||
assert len(ret) == len(set(ret))
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user