mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
search: skip timing the unoptimized kernel (#4395)
* search: skip timing the unoptimized kernel also ensure the return the unoptimized kernel if no opts are valid and refactor debugging to a single BEAM_DEBUG variable * stop early on fast kernels that can't improve enough
This commit is contained in:
@@ -100,7 +100,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
|
||||
except KernelOptError: pass
|
||||
return acted_lins
|
||||
|
||||
beam_pool = None
|
||||
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
|
||||
def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
|
||||
global beam_pool
|
||||
key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
@@ -109,14 +109,16 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
return ret
|
||||
|
||||
beam: List[Tuple[Linearizer, float]] = []
|
||||
beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
|
||||
seen_libs = set()
|
||||
|
||||
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
|
||||
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
|
||||
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
|
||||
|
||||
min_progress_micros = getenv("BEAM_MIN_PROGRESS", 0.01)
|
||||
min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
|
||||
if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
|
||||
if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
|
||||
|
||||
try:
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
@@ -124,7 +126,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[lin.opts.device]
|
||||
while not exiting:
|
||||
acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
|
||||
acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
|
||||
timed_lins: List[Tuple[Linearizer, float]] = []
|
||||
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
||||
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
|
||||
@@ -136,21 +138,21 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
|
||||
try: tms = _time_program(uops, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
||||
except RuntimeError: continue # for runtime issues
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if getenv("BEAM_LOG") > 0: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(uops.uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(uops.uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
|
||||
# done
|
||||
opts = sorted(timed_lins, key=lambda x: x[1])
|
||||
exiting = len(opts) == 0 or (len(beam) > 0 and ((beam[0][1]-opts[0][1])*1e6 < min_progress_micros))
|
||||
exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
|
||||
if not exiting: beam = opts[:amt]
|
||||
elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
|
||||
assert len(beam) > 0, "no BEAM items succeeded?!?" # this asserts in unet3d multi-gpu, need to figure out why
|
||||
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
|
||||
except KeyboardInterrupt as e:
|
||||
if beam_pool is not None: beam_pool.terminate()
|
||||
raise e
|
||||
|
||||
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
|
||||
if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
|
||||
return beam[0][0]
|
||||
|
||||
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
|
||||
|
||||
Reference in New Issue
Block a user