mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
move fuzz_schedule.py to extra [pr] (#10444)
This commit is contained in:
68
test/external/fuzz_uops.py
vendored
68
test/external/fuzz_uops.py
vendored
@@ -1,68 +0,0 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import DefaultDict, Dict, List, Tuple
|
||||
from tinygrad.uop.ops import UOp, print_uops, Ops
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.uop.ops import Variable
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
|
||||
|
||||
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
|
||||
|
||||
def fuzz_uops(uops:List[UOp]) -> List[Tuple[UOp, ...]]:
|
||||
blocks: List[List[UOp]] = [[]]
|
||||
for u in uops:
|
||||
if u.op in END_FOR_UOP: blocks.append([u])
|
||||
elif u.op in {x[1] for x in END_FOR_UOP.values()}: blocks.extend([[u], []])
|
||||
else: blocks[-1].append(u)
|
||||
|
||||
paths_for_block: Dict[int, List[Tuple[UOp, ...]]] = {}
|
||||
for bi, bb in enumerate(blocks):
|
||||
children: DefaultDict[UOp, List[UOp]] = defaultdict(list)
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
for u in bb:
|
||||
in_degree[u] = 0
|
||||
for x in u.src:
|
||||
if x in bb:
|
||||
children[x].append(u)
|
||||
in_degree[u] += 1
|
||||
paths_for_block[bi] = find_all_toposorts(children, in_degree)
|
||||
paths: Dict[Tuple[UOp, ...], None] = {}
|
||||
for up in itertools.product(*paths_for_block.values()):
|
||||
paths[tuple(uop for path in up for uop in path)] = None
|
||||
if len(paths) >= FUZZ_SCHEDULE_MAX_PATHS: break
|
||||
return list(paths)
|
||||
|
||||
class UOpsFuzzerRunner(CompiledRunner):
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
|
||||
assert self.p.uops is not None
|
||||
fuzz_paths = fuzz_uops(self.p.uops)
|
||||
init_rawbufs, init_name = {x:x.as_buffer() for x in rawbufs}, self.p.function_name
|
||||
init_globals = dict(zip(self.p.globals, rawbufs))
|
||||
if DEBUG >= 1: print(colored(f"fuzzing {len(fuzz_paths)} uop permutations for {init_name}", "yellow"))
|
||||
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
ground_truth = {x:np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in rawbufs}
|
||||
|
||||
for i, path in enumerate(fuzz_paths):
|
||||
# setup prg
|
||||
uops = list(path)
|
||||
if DEBUG >= 5: print_uops(uops)
|
||||
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.device].renderer.render(uops), uops=uops)
|
||||
if DEBUG >= 4: print(self.p.src)
|
||||
self.lib = Device[self.p.device].compiler.compile_cached(self.p.src)
|
||||
self.clprg = Device[self.p.device].runtime(name, self.lib)
|
||||
for x in (rawbufs:=[init_globals[i] for i in self.p.globals]): x.copyin(init_rawbufs[x])
|
||||
# verify
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
for i, x in enumerate(rawbufs):
|
||||
try:
|
||||
np.testing.assert_allclose(np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)), ground_truth[x], atol=1e-6, rtol=1e-6)
|
||||
if DEBUG >= 2: print(colored(name, "green"))
|
||||
except AssertionError as e:
|
||||
print(colored(name, "red"))
|
||||
raise e
|
||||
Reference in New Issue
Block a user