process_replay for get_rangeify_map (#12624)

This commit is contained in:
qazal
2025-10-12 15:14:40 +03:00
committed by GitHub
parent b5afa3848e
commit fd51ecf983

View File

@@ -42,13 +42,13 @@ class ProcessReplayWarning(Warning): pass
# *** replay the function and convert return values to string
def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
def to_str(ret:UOp) -> str:
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
return "\n".join([f"{len(asts)} kernels", *asts])
return to_str(new_sink), to_str(ret[big_sink]), (big_sink,)
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
# NOTE: this always uses the opts_to_apply path
@@ -65,7 +65,7 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts
ast_repr = codecs.decode(str(input_ast), "unicode_escape")
return to_str(p2), to_str(p), (ast_repr, renderer)
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "get_program":replay_get_program}
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_rangeify_map":replay_get_rangeify_map, "get_program":replay_get_program}
# *** run replayers on captured rows and print diffs