diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 3f09df6d86..de16b111ca 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # compare kernels created by HEAD against master -import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64, codecs +import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64 from typing import Callable, Any from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.kernelize.kernelize import get_kernelize_map @@ -20,8 +20,7 @@ early_stop = multiprocessing.Event() logging.basicConfig(level=logging.INFO, format="%(message)s") MAX_LINES = 500 def trunc_log(x): - if len(lines:=(x if isinstance(x, str) else repr(x)).splitlines()) > MAX_LINES: - lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"] + if len(lines:=repr(x).splitlines()) > MAX_LINES: lines = lines[:MAX_LINES]+[f"WARN: truncated string with {len(lines)} lines"] logging.info("\n".join(lines)) # user config @@ -48,7 +47,7 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer) -> tuple[str, # PYTHON renderer pickles UOps, first unpickle and decode here if p.device.startswith("PYTHON"): return "\n".join([str(x) for x in pickle.loads(base64.b64decode(ret.src))]) return ret.src - return to_str(p2), to_str(p), (codecs.decode(str(input_ast), "unicode_escape"), renderer) + return to_str(p2), to_str(p), (input_ast, renderer) replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "get_program":replay_get_program}