merge ground truth with the rest

This commit is contained in:
qazal
2024-04-19 10:40:10 +03:00
parent 8023748e28
commit 1f3463bb57

View File

@@ -20,29 +20,30 @@ def fuzz_schedule(outs: List[LazyBuffer]):
if ts not in toposorts: toposorts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled)
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
# setup ground truth
ground_truth: Dict[LazyBuffer, memoryview] = {}
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
prerealized: Dict[LazyBuffer, memoryview] = {}
seed, first = Tensor._seed, list(toposorts.items())[0]
for key in first[0]:
for out in (ps:=first[1][1][key]).outputs:
# freeze assign state before exec
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
for x in ps.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
ground_truth[out] = out.buffer.as_buffer()
del out.srcs # only schedule the LazyBuffer in this fuzz run
seed = Tensor._seed
# exec and validate each permutation with new Buffers
for i, (ts, (ctx, prescheduled)) in enumerate(toposorts.items()):
if i == 0: continue
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
rawbufs: Dict[LazyBuffer, Buffer] = {}
for key in ts:
# setup ground truth
if i == 0:
for out in (ps:=prescheduled[key]).outputs:
# freeze assign state before exec
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
for x in ps.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
ground_truth[out] = out.buffer.as_buffer()
del out.srcs # only schedule the LazyBuffer in this fuzz run
continue
# exec and validate the permutation with new Buffers
for out in (ps:=prescheduled[key]).outputs:
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
if out.op is LoadOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])