mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
process_replay for get_rangeify_map (#12624)
This commit is contained in:
@@ -42,13 +42,13 @@ class ProcessReplayWarning(Warning): pass
|
||||
|
||||
# *** replay the function and convert return values to string
|
||||
|
||||
def replay_kernelize(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
|
||||
def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
|
||||
UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
|
||||
new_sink = big_sink.substitute(get_rangeify_map(big_sink))
|
||||
def to_str(ret:UOp) -> str:
|
||||
asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.KERNEL]
|
||||
return "\n".join([f"{len(asts)} kernels", *asts])
|
||||
return to_str(new_sink), to_str(ret[big_sink]), (big_sink,)
|
||||
return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
|
||||
|
||||
def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
|
||||
# NOTE: this always uses the opts_to_apply path
|
||||
@@ -65,7 +65,7 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts
|
||||
ast_repr = codecs.decode(str(input_ast), "unicode_escape")
|
||||
return to_str(p2), to_str(p), (ast_repr, renderer)
|
||||
|
||||
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_kernelize_map":replay_kernelize, "get_program":replay_get_program}
|
||||
replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {"get_rangeify_map":replay_get_rangeify_map, "get_program":replay_get_program}
|
||||
|
||||
# *** run replayers on captured rows and print diffs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user