prebuild all rewrites [pr] (#10154)

* prebuild all rewrites [pr]

* fix that

* tests pass with linearizer
This commit is contained in:
George Hotz
2025-05-04 16:01:18 -04:00
committed by GitHub
parent 2b055cb59c
commit fe0724eebf
6 changed files with 131 additions and 36 deletions

View File

@@ -368,8 +368,8 @@ jobs:
# run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 3 --temperature 0 --timing
- name: Run GC tests
run: PYTHONPATH="." python test/external/external_uop_gc.py
- name: Repo line count < 12800 lines
run: MAX_LINE_COUNT=12800 python sz.py
- name: Repo line count < 13000 lines
run: MAX_LINE_COUNT=13000 python sz.py
fuzzing:
name: Fuzzing

View File

@@ -450,5 +450,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
# final rules for the renderer (without sym)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher, ctx=opts)
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher,
ctx=opts, name="final rewrite")
return sink

69
tinygrad/codegen/flow.py Normal file
View File

@@ -0,0 +1,69 @@
from typing import Any, Callable
import functools
from dataclasses import dataclass
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
from tinygrad.ops import PatternMatcher, graph_rewrite, UOp
from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
@dataclass
class RewriteStep:
  """One prebuilt rewrite pass: a PatternMatcher plus the keyword arguments passed to graph_rewrite."""
  pm: PatternMatcher
  # optional factory, called with the incoming sink, that produces the ctx for this pass
  ctx: Callable[[UOp], Any]|None = None
  # forwarded to graph_rewrite as name=
  name: str|None = None
  # forwarded to graph_rewrite as bottom_up=
  bottom_up: bool = False

  def __call__(self, sink:UOp):
    """Run this pass on `sink` and return the result of graph_rewrite."""
    # the ctx is built per call from the sink being rewritten; no factory means ctx=None
    rewrite_ctx = None if self.ctx is None else self.ctx(sink)
    return graph_rewrite(sink, self.pm, ctx=rewrite_ctx, name=self.name, bottom_up=self.bottom_up)
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]):
  """Feed `sink` through each step of `rewrites` in order; returns `sink` unchanged if the list is empty."""
  current = sink
  for step in rewrites:
    current = step(current)
  return current
def get_rewrites_for_renderer(opts:Renderer, linearizer=True) -> list[RewriteStep]:
  """Build the ordered list of rewrite passes for the renderer `opts`.

  The list covers the lowerer, expander and devectorizer stages, then the renderer's optional
  pre_matcher and the final renderer rules. The four linearizer passes are appended only when
  `linearizer` is true. QUANTIZE, DEVECTORIZE and TRANSCENDENTAL are read when this is called.
  """
  steps: list[RewriteStep] = []

  # ** lowerer (rewrite_shapetracker_with_index) **
  if QUANTIZE and opts.device in {"CPU", "DSP"}:
    steps.append(RewriteStep(pm_quant, name="quantize"))
  steps.append(RewriteStep(pm_lowerer, lambda ast: get_index(ast, opts), name="lowerer"))

  # ** expander (expand_rewrite) **
  steps += [
    RewriteStep(sym+migrate_indexing, name="initial symbolic"),
    # ignore (for masked stores)
    RewriteStep(pm_store_ignore, name="store_ignore"),
    RewriteStep(pm_move_ignore, name="move_ignore"),
    # expand + remove surviving ignores
    RewriteStep(pm_delete_ignore+sym+expander, name="expander"),
  ]

  # ** devectorizer (full_graph_rewrite) **
  # remove reduce
  steps.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

  # devectorize (TODO: does this need opts?)
  if DEVECTORIZE >= 2:
    pm_devectorize = sym+load_store_folding+load_store_indexing
  elif DEVECTORIZE:
    pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
  else:
    pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
  steps.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))

  supported_ops = tuple(opts.code_for_op.keys())
  extra_matcher = opts.extra_matcher
  if extra_matcher is None:
    extra_matcher = PatternMatcher([])

  # optional pre matcher
  if opts.pre_matcher is not None:
    steps.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

  # final rules for the renderer (without sym)
  pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher
  steps.append(RewriteStep(pm_final_rewrite, lambda _: opts, name="final rewrite"))

  if not linearizer:
    return steps

  # ** linearizer **
  steps += [
    RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True),
    RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"),
    RewriteStep(block_merge, name="Linearizer: Merge Blocks"),
    RewriteStep(pm_finalize, name="Linearizer: Finalize"),
  ]
  return steps

View File

@@ -14,10 +14,9 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, ro
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX, CAPTURE_PROCESS_REPLAY
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.devectorizer import full_graph_rewrite
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.engine.grouper import view_left
from tinygrad.codegen.flow import get_rewrites_for_renderer, apply_rewrites
class KernelOptError(Exception): pass
@@ -528,7 +527,7 @@ class Kernel:
return ret
fixed_ast = fixup_ast(self.ast)
del fixup_ast
return graph_rewrite(fixed_ast, view_left)
return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
# **** this is the lowerer ****
@@ -554,7 +553,8 @@ class Kernel:
#if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
try:
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
rewrite_list = get_rewrites_for_renderer(self.opts)
self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
except RuntimeError:
print("***** LINEARIZE FAILURE *****")
print(f"ast = {self.ast}")

View File

@@ -52,7 +52,7 @@ def disp(y:UOp) -> str:
@dataclass(frozen=True, eq=False)
class BasicBlock2:
lst: tuple[UOp, ...]
ctx: tuple[UOp, ...]
ctx: tuple[UOp, ...] = ()
end: UOp|None = None
cnt: int = 0
child_ctx: tuple[UOp, ...]|None = None
@@ -162,10 +162,32 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp):
bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
block_create = PatternMatcher([(
UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up)
block_create = PatternMatcher([
(UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
])
# ***** blockend merging ****
def merge_blockends(sink:UOp) -> UOp|None:
  """Merge BLOCKENDs that share the same `arg.end` into a single BLOCKEND.

  Returns None (no rewrite) when the last entry of `sink.arg.lst` is not a SINK, or when no
  `arg.end` key has more than one BLOCKEND. Otherwise returns `sink` with every merged
  BLOCKEND substituted by its combined replacement.
  """
  # only run on the final BLOCK with the SINK in it
  if sink.arg.lst[-1].op is not Ops.SINK: return None

  # group the BLOCKENDs by the RANGE UOp they end (keys are RANGE UOps, values are the BLOCKENDs)
  ends_by_range: dict[UOp, list[UOp]] = {}
  for u in sink.toposort():
    if u.op is not Ops.BLOCKEND: continue
    ends_by_range.setdefault(u.arg.end, []).append(u)

  replacements = {}
  for rng, ends in ends_by_range.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(ends) <= 1: continue
    merged_ctx = _sort_ctx(flatten([e.arg.ctx for e in ends]))
    total_cnt = sum(e.arg.cnt for e in ends)
    merged_src = tuple(flatten([e.src for e in ends]))
    # the merged block keeps the lst of the first BLOCKEND found for this range
    merged = UOp(Ops.BLOCKEND, src=merged_src, arg=BasicBlock2(ends[0].arg.lst, merged_ctx, rng, cnt=total_cnt))
    # NOTE: bb.ctx != u.arg.ctx can cause problems here
    replacements.update(dict.fromkeys(ends, merged))

  if not replacements: return None
  return sink.substitute(replacements)

pm_blockend_merge = PatternMatcher([
  (UPat(Ops.BLOCK, name="sink"), merge_blockends),
])
# ***** block merging ****
def merge_block(x:UOp):
@@ -209,6 +231,19 @@ block_merge = PatternMatcher([
# ****** finalize ******
def finalize(sink:UOp) -> UOp:
  """Turn the fully merged BLOCK into a BLOCKFINAL holding the final UOp list.

  Raises RuntimeError("linearize failure") if `sink` is not a BLOCK or if any of its sources
  has an op outside DONT_PLACE_IN_BLOCK.
  """
  if sink.op is not Ops.BLOCK or any(s.op not in DONT_PLACE_IN_BLOCK for s in sink.src):
    raise RuntimeError("linearize failure")

  # place the early things: deduplicated sources, ordered by tuplize, go before the block's own list
  early = sorted(dedup(sink.src), key=lambda u: u.tuplize)
  lst = [*early, *sink.arg.lst]

  if __debug__: type_verify(lst)
  return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))

pm_finalize = PatternMatcher([
  (UPat(Ops.BLOCK, name="sink"), finalize),
])
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
@@ -218,28 +253,16 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
# wrap all uops in blocks, already reordered
sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
blockends_to_arg: dict[UOp, list[UOp]] = {}
for be in sink.toposort():
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
new_forks = {}
for k,v in blockends_to_arg.items():
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
if len(v) > 1:
bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
# NOTE: bb.ctx != u.arg.ctx can cause problems here
for u in v: new_forks[u] = out
sink = sink.substitute(new_forks)
# merge blockends
sink = graph_rewrite(sink, pm_blockend_merge, name="Linearizer: Merge Blockends")
# merge blocks
sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src): raise RuntimeError("linearize failure")
# place the early things
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
# finalize
sink = graph_rewrite(sink, pm_finalize, name="Linearizer: Finalize")
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
if not skip_check: type_verify(lst)
from tinygrad.ops import print_uops
print_uops(sink.arg.lst)
return lst
return list(sink.arg.lst)

View File

@@ -1,6 +1,6 @@
from typing import Any, Callable
import itertools, inspect, functools, types
from tinygrad.helpers import partition, dedup
from tinygrad.helpers import partition, dedup, Context
from tinygrad.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
class UPatCompileError(Exception): pass
@@ -139,10 +139,12 @@ def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
def _get_code(self:UPat, has_ctx:bool):
ret = _get_clause(self, UOp(Ops.NOOP, arg="uop"))
try:
ret = graph_rewrite(ret, pm_proc, name="process UPat")
dyn_lookup: dict[str, Any] = {}
out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
rendered = _final_render(out, has_ctx)
# TODO: this should be tracked in a "system" rewrite, not untracked or tracked with kernel
with Context(TRACK_MATCH_STATS=0):
ret = graph_rewrite(ret, pm_proc, name="process UPat")
dyn_lookup: dict[str, Any] = {}
out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
rendered = _final_render(out, has_ctx)
except UPatCompileError:
#print("FAILED", self, self.location)
return None