from typing import cast from dataclasses import replace import itertools from tinygrad.helpers import DISABLE_FAST_IDIV, TRANSCENDENTAL, SPEC, DEBUG, VIZ, IMAGE, NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC from tinygrad.helpers import ALLOW_TF32, TracingKey, Context, panic from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, ProgramInfo from tinygrad.uop.render import pyrender from tinygrad.uop.spec import type_verify, spec_tensor, spec_program from tinygrad.renderer import Renderer, Estimates from tinygrad.renderer.isa import ISARenderer, IselContext, PreRegAllocContext from tinygrad.dtype import dtypes # import all pattern matchers here from tinygrad.codegen.gpudims import pm_add_gpudims from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads, pm_make_images from tinygrad.codegen.opt.postrange import apply_opts from tinygrad.codegen.late.gater import pm_move_gates_from_index from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar, pm_store_ranges from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize from tinygrad.codegen.late.regalloc import LinearScanRegallocContext, pm_regalloc_rewrite def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: if VIZ: graph_rewrite(ast, PatternMatcher([]), name="View Base AST") if DEBUG >= 5: print(pyrender(ast)) if SPEC: type_verify(ast, spec_tensor) # preprocess sink = graph_rewrite(ast, pm_mops+pm_syntactic_sugar+pm_store_ranges, ctx=itertools.count(1000), name="early movement ops", bottom_up=True) # first we optimize if optimize: # collapse loads reduce (indexing by a tensor) sink = graph_rewrite(sink, pm_load_collapse, name="load collapse") # split ranges sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges") # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct) sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic") # optimize (schedule) the AST sink = graph_rewrite(sink, pm_flatten_range+pm_simplify_ranges, ctx={}, name="simplify ranges") # do postrange optimization, BEAM or hand_coded_optimizations sink = apply_opts(sink, ren, beam=ast.arg.beam) # ** expander (expand_rewrite) ** sink = graph_rewrite(sink, sym+pm_move_where_on_load, name="postopt symbolic") # expand sink = graph_rewrite(sink, sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander") # add locals sink = graph_rewrite(sink, pm_add_buffers_local+rangeify_codegen, ctx=itertools.count(0), name="add local buffers") # ** devectorizer (full_graph_rewrite) ** # remove reduce sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce") # add gpu dims (late). this works after devectorize, but it's faster here sink = graph_rewrite(sink, pm_add_gpudims, ctx=ren, name="add gpudims") # **** optimizations are done, now we lower to actual code **** # add loads sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)") # create image buffers if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True, ctx=ren.target.arch) # devectorize sink = graph_rewrite(sink, sym+devectorize_alu+devectorize_buf_and_index+load_store_folding+correct_load_store+load_store_indexing, ctx=ren, name="devectorize") # lower the index dtype to a concrete int sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes") sink = graph_rewrite(sink, symbolic, name="post index symbolic") # optional pre matcher if ren.pre_matcher is not None: sink = graph_rewrite(sink, ren.pre_matcher, name="pre_matcher") # decompositions supported_ops = tuple(ren.code_for_op.keys()) pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, bool(DISABLE_FAST_IDIV)) pm_transcendental = symbolic_simple+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2) sink = graph_rewrite(sink, pm_decomp, ctx=ren, name="decompositions") sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes") sink = graph_rewrite(sink, pm_transcendental, name="transcendental") # move gates from unrenderable INVALID where sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index") # final rules for the renderer (without sym) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite") # this was the linearizer sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True) if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Output AST") # return the rewritten sink return sink # inject IF/ENDIF. only needed if device doesn't support gated stores pm_linearize_cleanups = PatternMatcher([ # if statements are not allowed in the graph (UPat((Ops.IF, Ops.ENDIF)), lambda: panic(RuntimeError, "if not allowed in graph")), # gated STORE becomes IF-STORE-ENDIF. this is the only use of IF-ENDIF (UPat(Ops.STORE, name="u", src=(UPat(Ops.INDEX).or_casted(), UPat(), UPat(name="gate", dtype=dtypes.bool))), lambda u, gate: (u, [mif:=UOp(Ops.IF, src=(gate, u.src[0])), u, UOp(Ops.ENDIF, src=(mif,))])) ]) # requires lst be toposorted. like graph rewrite, but for lines def line_rewrite(lst:list[UOp], pm:PatternMatcher, ctx=None) -> list[UOp]: newlst = [] replaced: dict[UOp, UOp] = {} for u in lst: nu = u.replace(src=tuple([replaced.get(x, x) for x in u.src])) ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu, ctx)) or (nu, [nu]) replaced[u] = ret[0] newlst.extend(ret[1]) return newlst def do_linearize(ctx:Renderer, prg:UOp, sink:UOp) -> UOp: if DEBUG >= 3 and sink.arg.applied_opts: print(f"{sink.arg.function_name:<25} opts: {sink.arg.applied_opts}") lst = line_rewrite(linearize(sink), pm_linearize_cleanups) if SPEC: type_verify(lst, spec_program) # isa renderers need to allocate registers if isinstance(ctx, ISARenderer): if ctx.pre_regalloc_matcher is not None: lst = line_rewrite(lst, ctx.pre_regalloc_matcher, PreRegAllocContext()) regalloc_ctx = LinearScanRegallocContext(lst, ctx) lst = line_rewrite(lst, pm_regalloc_rewrite, regalloc_ctx) lst = line_rewrite(lst, ctx.post_regalloc_matcher, regalloc_ctx) if DEBUG >= 4: print(ctx.asm_str(lst, sink.arg.function_name)) return prg.replace(src=prg.src + (UOp(Ops.LINEAR, src=tuple(lst)),)) def do_estimates(prg:UOp, sink:UOp, lin:UOp) -> UOp|None: if sink.arg.estimates is not None: return None return prg.replace(src=(sink.replace(arg=replace(sink.arg, estimates=Estimates.from_uops(lin.src, ignore_indexing=True))),)+prg.src[1:]) def do_assemble(ctx:Renderer, prg:UOp, lin:UOp) -> UOp: src = "\n".join(str(u.arg) for u in lin.src) if DEBUG >= 4: print(src) binary = ctx.asm(prg, lin) return prg.replace(src=prg.src[:3]+(UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=binary))) def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp: src = ctx.render(list(lin.src)) new_arg = replace(prg.arg, aux=tuple(ctx.aux(list(lin.src)))) if ctx.has_aux else prg.arg return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),), arg=new_arg) def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None: if DEBUG >= 4: print(source.arg) lib = ctx.compiler.compile_cached(source.arg) if DEBUG >= 7: ctx.compiler.disassemble(lib) return prg.replace(src=prg.src + (UOp(Ops.BINARY, arg=lib),)) pm_to_program = PatternMatcher([ (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize), (UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_estimates), (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, src=UPat(Ops.INS), name="lin")), name="prg"), do_assemble), (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render), (UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile), ]) @track_rewrites(name=lambda ast,renderer,ret,**kwargs: TracingKey(ret.src[0].arg.name,(ret.src[0].arg.function_name, ast), ret=renderer), replay=True) @Context(ALLOW_DEVICE_USAGE=0) def do_to_program(ast:UOp, renderer:Renderer) -> UOp: """ Transform an AST into a compiled PROGRAM. May trigger BEAM search. Args: ast: The Ops.SINK/Ops.PROGRAM rooted AST renderer: The renderer used to generate the code Returns: The Ops.PROGRAM with SINK/DEVICE/LINEAR/SOURCE/BINARY. """ if ast.op is Ops.PROGRAM: prg = ast elif ast.op is Ops.SINK: assert isinstance(ast.arg, KernelInfo), "requires KernelInfo on arg to to_program" full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None) prog_info = ProgramInfo.from_sink(full_sink) # instruction selection if isinstance(renderer, ISARenderer): full_sink = graph_rewrite(full_sink, renderer.pre_isel_matcher, ctx=itertools.count(-1, -1), name="pre instruction selection", bottom_up=True) full_sink = graph_rewrite(full_sink, renderer.isel_matcher, ctx=IselContext(full_sink), name="instruction selection", bottom_up=True) prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.target.device)), arg=prog_info) else: raise RuntimeError(f"can't call to_program on {ast.op}") if not isinstance(prg.arg, ProgramInfo): prg = prg.replace(arg=ProgramInfo.from_sink(prg.src[0])) prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render") if VIZ: graph_rewrite(prg, PatternMatcher([]), name="View Program") return prg to_program_cache: dict[tuple, UOp] = {} def to_program(ast:UOp, renderer:Renderer) -> UOp: config = (NOOPT, EMULATED_DTYPES, NOLOCALS, USE_TC, IMAGE, DISABLE_FAST_IDIV, TRANSCENDENTAL, ALLOW_TF32) key = (ast.key, type(renderer), renderer.target, *[x.value for x in config]) if (prg:=to_program_cache.get(key)) is None: to_program_cache[key] = prg = do_to_program(ast, renderer) return prg