From 5864627abe6dffdccf3ebafb230eb65c1bd40e2c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:43:43 +0200 Subject: [PATCH] process replay filter warnings [pr] (#8199) --- test/external/process_replay/process_replay.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 95e56de660..c0d80fee1f 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import os, multiprocessing, logging, pickle, sqlite3, difflib, functools +import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings from typing import Callable, List, Set, Tuple, Union, cast from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.engine.schedule import ScheduleContext, full_ast_rewrite @@ -25,6 +25,7 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in if not getenv("ASSERT_PROCESS_REPLAY", 1): ASSERT_DIFF = 0 SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") if REF == "master": SKIP_PROCESS_REPLAY = True +class ProcessReplayWarning(Warning): pass # *** recreators @@ -56,9 +57,8 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]: with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2]) if good is None: continue except Exception as e: - logging.warning(f"FAILED TO RECREATE KERNEL {e}") + warnings.warn(f"FAILED TO RECREATE KERNEL {e}", ProcessReplayWarning) for x in args[:-1]: logging.info(x) - if ASSERT_DIFF: return True continue # diff kernels try: assert args[-1] == good @@ -85,7 +85,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: cur = conn.cursor() try: row_count = cur.execute(f"select count(*) from '{name}_{TABLE_NAME}'").fetchone()[0] except sqlite3.OperationalError: - logging.warning(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?") + warnings.warn(f"{name}_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?", ProcessReplayWarning) return None conn.commit() cur.close() @@ -100,7 +100,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: logging.info(f"{sum(changed)} kernels changed") if sum(insertion) != 0: logging.info(colored(f"{sum(insertion)} insertions(+)", "green")) if sum(deletions) != 0: logging.info(colored(f"{sum(deletions)} deletions(-)", "red")) - if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes") + if any(changed): warnings.warn("process replay detected changes", ProcessReplayWarning) # *** main loop @@ -109,6 +109,7 @@ if __name__ == "__main__": logging.info("skipping process replay.") exit(0) + if ASSERT_DIFF: warnings.filterwarnings("error", category=ProcessReplayWarning) for name,fxn in [("schedule", recreate_sched), ("kernel", recreate_kernel)]: logging.info(f"***** {name} diff") try: _pmap(name, fxn)