From f4e83b30b414a551e70a00da3068cedf9dd55cd9 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 9 Sep 2024 16:59:04 +0800 Subject: [PATCH] cleanup process_replay/* namings [run_process_replay] (#6429) --- extra/optimization/extract_dataset.py | 6 ++--- extra/optimization/generate_dataset.sh | 2 +- test/external/process_replay/diff_schedule.py | 23 +++++++++++-------- .../external/process_replay/process_replay.py | 14 ++++++----- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/extra/optimization/extract_dataset.py b/extra/optimization/extract_dataset.py index 174c276e37..7a4df252e9 100755 --- a/extra/optimization/extract_dataset.py +++ b/extra/optimization/extract_dataset.py @@ -2,14 +2,14 @@ # extract asts from process replay artifacts import os, pickle from tinygrad.helpers import db_connection, getenv, VERSION -from test.external.process_replay.process_replay import _run_differ +from test.external.process_replay.process_replay import _pmap PAGE_SIZE = 100 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD") TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}" LOGOPS = os.getenv("LOGOPS", "/tmp/ops") -def extract_ast(offset:int): +def extract_ast(offset:int) -> bool: logops = open(LOGOPS, "a") conn = db_connection() for row in conn.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall(): @@ -19,4 +19,4 @@ def extract_ast(offset:int): if __name__ == "__main__": conn = db_connection() row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0] - _run_differ(row_count, extract_ast) + _pmap(row_count, extract_ast) diff --git a/extra/optimization/generate_dataset.sh b/extra/optimization/generate_dataset.sh index 36f57890f0..905d7d589b 100755 --- a/extra/optimization/generate_dataset.sh +++ b/extra/optimization/generate_dataset.sh @@ -1,5 +1,5 @@ #!/bin/bash -export LOGOPS=/tmp/sops +export LOGOPS=/tmp/ops export RUN_PROCESS_REPLAY=1 rm $LOGOPS test/external/process_replay/reset.py diff --git a/test/external/process_replay/diff_schedule.py b/test/external/process_replay/diff_schedule.py index 25df93342d..cb2ac33002 100644 --- a/test/external/process_replay/diff_schedule.py +++ b/test/external/process_replay/diff_schedule.py @@ -40,24 +40,27 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]] seen_diffs.add(cache_key) changed += 1 if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), [ref.ast.key, compare.ast.key])) - if not CI: print_si_diff(si[0], si[1]) + if not CI: print_si_diff(ref, compare) if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed") return changed -def print_si_diff(si0:ScheduleItem, si1:ScheduleItem): +def print_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None: logging.basicConfig(level=logging.INFO) - print_diff(si0.ast, si1.ast) + print_diff(ref.ast, compare.ast) # skip lowering/runtime error - with contextlib.suppress(Exception): - ei0 = lower_schedule_item(si0) - ei1 = lower_schedule_item(si1) - assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner) - if DEBUG >= 4: print_diff(ei0.prg.p.src, ei1.prg.p.src) + with contextlib.suppress(Exception): lower_si_diff(ref, compare) + +def lower_si_diff(ref:ScheduleItem, compare:ScheduleItem) -> None: + if DEBUG >= 4: + ref_ei = lower_schedule_item(ref) + compare_ei = lower_schedule_item(compare) + assert isinstance(ref_ei.prg, CompiledRunner) and isinstance(compare_ei.prg, CompiledRunner) + print_diff(ref_ei.prg.p.src, compare_ei.prg.p.src) # TODO: create new Buffers for process replay to test correctness if getenv("TIMING"): with Context(DEBUG=2): - tm0 = ei0.run(wait=True) - tm1 = ei1.run(wait=True) + tm0 = ref_ei.run(wait=True) + tm1 = compare_ei.run(wait=True) assert tm0 is not None and tm1 is not None tm_diff = ((tm0 - tm1) / tm0) * 100 if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green")) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index edabbe9b42..9345ee87df 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -82,17 +82,19 @@ def diff_kernel(offset:int) -> bool: cur.close() return bool(changed) -# *** differ runners with multiprocessing +# *** generic runner for executing fxn across all rows of a table in parallel -def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None: - with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool: +def _pmap(row_count:int, fxn:Callable[[int], bool], maxtasksperchild:int=16) -> None: + with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool: inputs = list(range(0, row_count, PAGE_SIZE)) - changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs))) + changed: List[bool] = list(tqdm(pool.imap_unordered(fxn, inputs), total=len(inputs))) pool.close() pool.join() pool.terminate() if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes") +# *** process replay parallel differ runners + def process_replay_schedule() -> None: conn = db_connection() cur = conn.cursor() @@ -105,7 +107,7 @@ def process_replay_schedule() -> None: if row_count != 0: logging.info("***** schedule diff") conn.commit() cur.close() - _run_differ(row_count, diff_schedule) + _pmap(row_count, diff_schedule) def process_replay_kernel() -> None: conn = db_connection() @@ -116,7 +118,7 @@ def process_replay_kernel() -> None: return None conn.commit() cur.close() - _run_differ(row_count, diff_kernel) + _pmap(row_count, diff_kernel) # *** main loop