Files
StarPilot/tinygrad_repo/tinygrad/codegen/opt/postrange.py
T
2025-09-27 12:00:00 -07:00

19 lines
712 B
Python

from dataclasses import replace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo
from tinygrad.helpers import colored
from tinygrad.codegen.opt.kernel import axis_colors
def rename_sink(s:UOp):
  """Build a kernel name for sink `s` from the RANGE uops among its parents.

  A sink whose arg is set and whose arg.name is anything other than "test" is skipped:
  the function returns None (presumably read by the PatternMatcher as "no rewrite" -- confirm).
  Otherwise a copy of `s` is returned whose arg carries the new name: a fresh KernelInfo when
  `s` had no arg, or the existing arg with only `name` replaced.
  """
  info = s.arg
  if info is not None and info.name != "test": return None

  # every RANGE parent, ordered by its arg with the last element dropped
  ranges = sorted((u for u in s.parents if u.op is Ops.RANGE), key=lambda r: r.arg[0:-1])

  # name is "k" plus one "_<rendered src[0]>" piece per range; the last arg element picks the color
  sep = colored('_', 'BLACK')
  pieces = [colored(r.src[0].render(), axis_colors[r.arg[-1]]) for r in ranges]
  name = "k" + sep.join([''] + pieces)

  if info is None: return s.replace(arg=KernelInfo(name=name))
  return s.replace(arg=replace(info, name=name))
# single-rule matcher: each Ops.SINK uop is bound to `s` and handed to rename_sink
pm_postrange_opt = PatternMatcher([(UPat(Ops.SINK, name="s"), rename_sink)])