mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
prebuild all rewrites [pr] (#10154)
* prebuild all rewrites [pr] * fix that * tests pass with linearizer
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -368,8 +368,8 @@ jobs:
|
||||
# run: NULL=1 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 3 --temperature 0 --timing
|
||||
- name: Run GC tests
|
||||
run: PYTHONPATH="." python test/external/external_uop_gc.py
|
||||
- name: Repo line count < 12800 lines
|
||||
run: MAX_LINE_COUNT=12800 python sz.py
|
||||
- name: Repo line count < 13000 lines
|
||||
run: MAX_LINE_COUNT=13000 python sz.py
|
||||
|
||||
fuzzing:
|
||||
name: Fuzzing
|
||||
|
||||
@@ -450,5 +450,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
if opts is not None and opts.pre_matcher is not None: sink = graph_rewrite(sink, opts.pre_matcher)
|
||||
|
||||
# final rules for the renderer (without sym)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher, ctx=opts)
|
||||
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher,
|
||||
ctx=opts, name="final rewrite")
|
||||
return sink
|
||||
|
||||
69
tinygrad/codegen/flow.py
Normal file
69
tinygrad/codegen/flow.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Any, Callable
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL
|
||||
from tinygrad.ops import PatternMatcher, graph_rewrite, UOp
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
|
||||
from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
|
||||
from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
|
||||
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
|
||||
pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
|
||||
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
|
||||
@dataclass
class RewriteStep:
  """One graph_rewrite pass: a PatternMatcher plus the options it is applied with.

  Instances are callable; calling one with a sink UOp runs graph_rewrite on it
  and returns whatever graph_rewrite returns.
  """
  # the pattern matcher handed to graph_rewrite
  pm: PatternMatcher
  # optional factory, called with the sink, that builds the ctx passed to graph_rewrite (None -> ctx=None)
  ctx: Callable[[UOp], Any]|None = None
  # forwarded to graph_rewrite as name=
  name: str|None = None
  # forwarded to graph_rewrite as bottom_up=
  bottom_up: bool = False
  def __call__(self, sink:UOp):
    """Apply this rewrite step to `sink`."""
    # the ctx is built per call from the sink, so the factory always sees the graph being rewritten
    return graph_rewrite(sink, self.pm, ctx=self.ctx(sink) if self.ctx is not None else None, name=self.name, bottom_up=self.bottom_up)
|
||||
|
||||
def apply_rewrites(sink:UOp, rewrites:list[RewriteStep]):
  """Run every step in `rewrites` over `sink`, in list order, and return the final result.

  Each step is called with the result of the previous one; an empty list returns `sink` unchanged.
  """
  current = sink
  for step in rewrites: current = step(current)
  return current
|
||||
|
||||
def get_rewrites_for_renderer(opts:Renderer, linearizer=True) -> list[RewriteStep]:
  """Build the ordered list of RewriteSteps to apply for the renderer `opts`.

  Args:
    opts: the renderer; its `device`, `code_for_op`, `extra_matcher` and `pre_matcher` are read here,
      and it is also passed as the ctx of the "devectorize" and "final rewrite" steps.
    linearizer: when true, the four "Linearizer: ..." steps are appended at the end.

  Returns:
    The steps in the order they should be applied (see apply_rewrites).

  QUANTIZE, DEVECTORIZE and TRANSCENDENTAL are read when the list is built, not when the steps run.
  """
  # ** lowerer (rewrite_shapetracker_with_index) **
  ret: list[RewriteStep] = []
  if QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
  ret.append(RewriteStep(pm_lowerer, lambda ast: get_index(ast, opts), name="lowerer"))

  # ** expander (expand_rewrite) **
  ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))

  # ignore (for masked stores)
  ret.append(RewriteStep(pm_store_ignore, name="store_ignore"))
  ret.append(RewriteStep(pm_move_ignore, name="move_ignore"))

  # expand + remove surviving ignores
  ret.append(RewriteStep(pm_delete_ignore+sym+expander, name="expander"))

  # ** devectorizer (full_graph_rewrite) **
  # remove reduce (a fresh ReduceContext is created each time the step runs)
  ret.append(RewriteStep(pm_reduce+gep_pushing, lambda _: ReduceContext(), name="remove_reduce"))

  # devectorize (TODO: does this need opts?)
  if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
  elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
  else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
  ret.append(RewriteStep(pm_devectorize, lambda _: opts, name="devectorize"))

  supported_ops = tuple(opts.code_for_op.keys())
  extra_matcher = opts.extra_matcher if opts.extra_matcher is not None else PatternMatcher([])

  # optional pre matcher
  if opts.pre_matcher is not None: ret.append(RewriteStep(opts.pre_matcher, name="pre_matcher"))

  # final rules for the renderer (without sym)
  pm_final_rewrite = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher
  ret.append(RewriteStep(pm_final_rewrite, lambda _: opts, name="final rewrite"))

  # ** linearizer **
  if linearizer:
    ret.append(RewriteStep(block_create, ctx=BlockContext.from_sink, name="Linearizer: Create Blocks", bottom_up=True))
    ret.append(RewriteStep(pm_blockend_merge, name="Linearizer: Merge Blockends"))
    ret.append(RewriteStep(block_merge, name="Linearizer: Merge Blocks"))
    ret.append(RewriteStep(pm_finalize, name="Linearizer: Finalize"))
  return ret
|
||||
@@ -14,10 +14,9 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, ro
|
||||
from tinygrad.helpers import DEBUG, TC_SELECT, TC_OPT, AMX, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index, get_contraction
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.engine.grouper import view_left
|
||||
from tinygrad.codegen.flow import get_rewrites_for_renderer, apply_rewrites
|
||||
|
||||
class KernelOptError(Exception): pass
|
||||
|
||||
@@ -528,7 +527,7 @@ class Kernel:
|
||||
return ret
|
||||
fixed_ast = fixup_ast(self.ast)
|
||||
del fixup_ast
|
||||
return graph_rewrite(fixed_ast, view_left)
|
||||
return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
|
||||
|
||||
# **** this is the lowerer ****
|
||||
|
||||
@@ -554,7 +553,8 @@ class Kernel:
|
||||
#if __debug__: type_verify(list(modified_ast.toposort()), shape_spec)
|
||||
|
||||
try:
|
||||
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
|
||||
rewrite_list = get_rewrites_for_renderer(self.opts)
|
||||
self.uops:list[UOp] = list(apply_rewrites(modified_ast, rewrite_list).arg.lst)
|
||||
except RuntimeError:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(f"ast = {self.ast}")
|
||||
|
||||
@@ -52,7 +52,7 @@ def disp(y:UOp) -> str:
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class BasicBlock2:
|
||||
lst: tuple[UOp, ...]
|
||||
ctx: tuple[UOp, ...]
|
||||
ctx: tuple[UOp, ...] = ()
|
||||
end: UOp|None = None
|
||||
cnt: int = 0
|
||||
child_ctx: tuple[UOp, ...]|None = None
|
||||
@@ -162,10 +162,32 @@ def make_block_bottom_up(ctx:BlockContext, x:UOp):
|
||||
bb = BasicBlock2(tuple(lst), ctx=current_ctx, cnt=child_count, child_ctx=child_ctx)
|
||||
return UOp(Ops.BLOCK, src=tuple(srcs), arg=bb)
|
||||
|
||||
block_create = PatternMatcher([(
|
||||
UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up)
|
||||
block_create = PatternMatcher([
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_BLOCK.union({Ops.BLOCK, Ops.BLOCKEND}), name="x"), make_block_bottom_up),
|
||||
])
|
||||
|
||||
# ***** blockend merging ****
|
||||
|
||||
def merge_blockends(sink:UOp) -> UOp|None:
  """Combine BLOCKEND UOps that share the same `arg.end` into a single BLOCKEND.

  Returns None (no rewrite) when the last uop in `sink.arg.lst` is not a SINK, or when no
  group of BLOCKENDs has more than one member; otherwise returns `sink` with every merged
  BLOCKEND substituted by its combined replacement.
  """
  # only run on the final BLOCK with the SINK in it
  if sink.arg.lst[-1].op is not Ops.SINK: return None
  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
  blockends_to_arg: dict[UOp, list[UOp]] = {}
  for be in sink.toposort():
    if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
  new_forks = {}
  for k,v in blockends_to_arg.items():
    # NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
    if len(v) > 1:
      # the merged block keeps the first BLOCKEND's lst, unions the ctxs, and sums the counts
      bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
      out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
      # NOTE: bb.ctx != u.arg.ctx can cause problems here
      for u in v: new_forks[u] = out
  # nothing merged: return None so the pattern matcher reports no change
  if not new_forks: return None
  return sink.substitute(new_forks)
|
||||
|
||||
# matches every Ops.BLOCK UOp (bound as "sink") and hands it to merge_blockends
pm_blockend_merge = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), merge_blockends)])
|
||||
|
||||
# ***** block merging ****
|
||||
|
||||
def merge_block(x:UOp):
|
||||
@@ -209,6 +231,19 @@ block_merge = PatternMatcher([
|
||||
|
||||
# ****** finalize ******
|
||||
|
||||
def finalize(sink:UOp) -> UOp:
  """Turn the single remaining BLOCK into a BLOCKFINAL UOp carrying the complete uop list.

  Raises:
    RuntimeError: ("linearize failure") if `sink` is not a BLOCK or any of its sources has an op
      outside DONT_PLACE_IN_BLOCK.
  """
  if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
    raise RuntimeError("linearize failure")

  # place the early things: deduplicated sources, ordered by tuplize, go before the block's own list
  lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)

  # verification only runs when Python is not started with -O
  if __debug__: type_verify(lst)

  return UOp(Ops.BLOCKFINAL, arg=BasicBlock2(tuple(lst)))
|
||||
|
||||
# matches every Ops.BLOCK UOp (bound as "sink") and hands it to finalize
pm_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="sink"), finalize)])
|
||||
|
||||
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
|
||||
@@ -218,28 +253,16 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
# wrap all uops in blocks, already reordered
|
||||
sink = graph_rewrite(sink, block_create, ctx=ctx, name="Linearizer: Create Blocks", bottom_up=True)
|
||||
|
||||
# combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
|
||||
blockends_to_arg: dict[UOp, list[UOp]] = {}
|
||||
for be in sink.toposort():
|
||||
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
|
||||
new_forks = {}
|
||||
for k,v in blockends_to_arg.items():
|
||||
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
|
||||
if len(v) > 1:
|
||||
bb = BasicBlock2(v[0].arg.lst, _sort_ctx(flatten([y.arg.ctx for y in v])), k, cnt=sum(y.arg.cnt for y in v))
|
||||
out = UOp(Ops.BLOCKEND, src=tuple(flatten([x.src for x in v])), arg=bb)
|
||||
# NOTE: bb.ctx != u.arg.ctx can cause problems here
|
||||
for u in v: new_forks[u] = out
|
||||
sink = sink.substitute(new_forks)
|
||||
|
||||
# merge blockends
|
||||
sink = graph_rewrite(sink, pm_blockend_merge, name="Linearizer: Merge Blockends")
|
||||
|
||||
# merge blocks
|
||||
sink = graph_rewrite(sink, block_merge, name="Linearizer: Merge Blocks")
|
||||
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src): raise RuntimeError("linearize failure")
|
||||
|
||||
# place the early things
|
||||
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
||||
# finalize
|
||||
sink = graph_rewrite(sink, pm_finalize, name="Linearizer: Finalize")
|
||||
|
||||
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
|
||||
if not skip_check: type_verify(lst)
|
||||
from tinygrad.ops import print_uops
|
||||
print_uops(sink.arg.lst)
|
||||
|
||||
return lst
|
||||
return list(sink.arg.lst)
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Callable
|
||||
import itertools, inspect, functools, types
|
||||
from tinygrad.helpers import partition, dedup
|
||||
from tinygrad.helpers import partition, dedup, Context
|
||||
from tinygrad.ops import UPat, UPatAny, UOp, Ops, PatternMatcher, graph_rewrite, deconstruct_function
|
||||
|
||||
class UPatCompileError(Exception): pass
|
||||
@@ -139,10 +139,12 @@ def _final_render(x:UOp, has_ctx:bool, depth=1) -> list[str]:
|
||||
def _get_code(self:UPat, has_ctx:bool):
|
||||
ret = _get_clause(self, UOp(Ops.NOOP, arg="uop"))
|
||||
try:
|
||||
ret = graph_rewrite(ret, pm_proc, name="process UPat")
|
||||
dyn_lookup: dict[str, Any] = {}
|
||||
out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
|
||||
rendered = _final_render(out, has_ctx)
|
||||
# TODO: this should be tracked in a "system" rewrite, not untracked or tracked with kernel
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
ret = graph_rewrite(ret, pm_proc, name="process UPat")
|
||||
dyn_lookup: dict[str, Any] = {}
|
||||
out = graph_rewrite(ret, pm_renderer, ctx=dyn_lookup, name="compile UPat")
|
||||
rendered = _final_render(out, has_ctx)
|
||||
except UPatCompileError:
|
||||
#print("FAILED", self, self.location)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user