diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index 26a50102af..5c7b888256 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -20,29 +20,30 @@ def fuzz_schedule(outs: List[LazyBuffer]): if ts not in toposorts: toposorts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled) if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow")) - # setup ground truth ground_truth: Dict[LazyBuffer, memoryview] = {} # IMPORTANT: freeze prerealized bufs before ScheduleItem exec prerealized: Dict[LazyBuffer, memoryview] = {} - seed, first = Tensor._seed, list(toposorts.items())[0] - for key in first[0]: - for out in (ps:=first[1][1][key]).outputs: - # freeze assign state before exec - if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer() - for x in ps.inputs: - if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer() - si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)) - _exec_si(si, seed) - for out in ps.outputs: - ground_truth[out] = out.buffer.as_buffer() - del out.srcs # only schedule the LazyBuffer in this fuzz run + seed = Tensor._seed - # exec and validate each permutation with new Buffers for i, (ts, (ctx, prescheduled)) in enumerate(toposorts.items()): - if i == 0: continue if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow")) rawbufs: Dict[LazyBuffer, Buffer] = {} for key in ts: + # setup ground truth + if i == 0: + for out in (ps:=prescheduled[key]).outputs: + # freeze assign state before exec + if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer() + for x in ps.inputs: + if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer() + si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)) + _exec_si(si, seed) + for out in ps.outputs: + ground_truth[out] = out.buffer.as_buffer() + del out.srcs # only schedule the LazyBuffer in this fuzz run + continue + + # exec and validate the permutation with new Buffers for out in (ps:=prescheduled[key]).outputs: rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype) if out.op is LoadOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])