diff --git a/test/null/test_graph_rewrite.py b/test/null/test_graph_rewrite.py index fb9dac854a..10bf46e7b4 100644 --- a/test/null/test_graph_rewrite.py +++ b/test/null/test_graph_rewrite.py @@ -21,9 +21,8 @@ def apply_rewrite_values(expr): def evaluate_uop(uop, variables): if uop.op == Ops.CONST: return uop.arg - elif uop.op == Ops.DEFINE_VAR: - var_name = uop.arg[0] - return variables[var_name] + elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is None): + return variables[uop.expr] elif uop.op in GroupOp.ALU: src_values = [evaluate_uop(src, variables) for src in uop.src] return exec_alu(uop.op, uop.dtype, src_values) diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index f05528ce19..fac87f2203 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -406,7 +406,7 @@ class TestUOpGraph(unittest.TestCase): vc = v+c2 out = vc+c4 uops = to_uops_list([out]) - self.assertEqual(len(uops), 4) # +1 for SINK + self.assertEqual(len(uops), 5) # +1 for SINK, +1 for the PARAM shape STACK out = uops[-2] # -2 to skip SINK self.assertEqual(out.op, Ops.ADD) self.assertEqual(out.src[1].op, Ops.CONST) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index b3fa4862df..90fb9e3b2c 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -955,7 +955,9 @@ class TestSymbolic(unittest.TestCase): uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink()) rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1] - self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half))) + # the vars are now scalar PARAMs + pvar = {u.expr: u for u in rewritten_uop.toposort() if u.op is Ops.PARAM} + self.assertEqual(rewritten_uop, (pvar['s']<2).where(pvar['a'].cast(dtypes.half), pvar['b'].cast(dtypes.half))) def test_where_merge_branches(self): cond1 = Variable("s", 0, 10) < 6 diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 1d8f972128..84020e046f 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -126,11 +126,10 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite") - if ren.new_style: - sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") - num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM]) - name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))} - sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style") + sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink") + num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM]) + name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))} + sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style") # this was the linearizer sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True) diff --git a/tinygrad/codegen/late/regalloc.py b/tinygrad/codegen/late/regalloc.py index deb993671a..5dd63709a1 100644 --- a/tinygrad/codegen/late/regalloc.py +++ b/tinygrad/codegen/late/regalloc.py @@ -4,7 +4,7 @@ from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat from tinygrad.renderer.isa import ISARenderer, Register from tinygrad.dtype import dtypes, PtrDType -PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP} +PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK} class LinearScanRegallocContext: # returns the uop that defines the virtual register diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 5fdcd47974..df2a8349b3 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -68,7 +68,6 @@ class Renderer: pre_matcher: PatternMatcher|None = None extra_matcher: PatternMatcher|None = None code_for_op: dict[Ops, Callable] = {} - new_style: bool = False compiler: Compiler = Compiler() diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 3b64361eb6..9d5e257545 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -119,7 +119,6 @@ def wmma_args(uops:list[UOp]): return dedup((uop.arg[0], uop.arg[1], uop.arg[2], uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA) class CStyleLanguage(Renderer): - new_style = True kernel_typedef: str = "void" buffer_prefix: str = "" buffer_suffix: str = "" diff --git a/tinygrad/renderer/isa/__init__.py b/tinygrad/renderer/isa/__init__.py index de214c76e3..88b8c372f6 100644 --- a/tinygrad/renderer/isa/__init__.py +++ b/tinygrad/renderer/isa/__init__.py @@ -17,7 +17,7 @@ class IselContext: def __init__(self, sink:UOp): self.uses = consumer_map_from_toposort(sink.toposort()) self.reg_n = itertools.count() - arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2} + arg_order = {Ops.PARAM: 0, Ops.SPECIAL: 1} self.func_args = sorted([u for u in self.uses if u.op in arg_order], key=lambda k: (arg_order[k.op], k.arg)) def vreg(self, cons:tuple[Register, ...]|Register): diff --git a/tinygrad/renderer/isa/x86.py b/tinygrad/renderer/isa/x86.py index 5fe2ac1029..41cf706403 100644 --- a/tinygrad/renderer/isa/x86.py +++ b/tinygrad/renderer/isa/x86.py @@ -171,6 +171,35 @@ extra_matcher = PatternMatcher([ (UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)), ]) +# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) ***** + +pm_x86_style = PatternMatcher([ + # buffers are pointers, scalar PARAMs (variables) keep their shape src + (UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \ + if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None), + (UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL, + dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)), + (UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None), + # SHRINK is a vectorized INDEX + (UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \ + .cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None), + # cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK) + (UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y: + y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None), + # INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x: + x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None), + (UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x: + y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None), + # restore vec dtypes from structure + (UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c: + x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None), + (UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None), + (UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \ + if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \ + and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None), +]) + # ***** X86 pre instruction selection ***** def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp): @@ -185,7 +214,7 @@ def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp): return ptr.cast(cast.dtype).store(val) # these must be done in a separate matcher because they violate the spec -pre_isel_matcher = PatternMatcher([ +pre_isel_matcher = pm_x86_style + PatternMatcher([ # zero extending scalar 32bit int is a noop (UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None), # cast between signed and unsigned int is a noop @@ -379,7 +408,7 @@ isel_matcher = PatternMatcher([ x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in CALLEE_SAVED)),)) \ if not x.src or x.src[0].arg is not X86Ops.RET else None), # function abi constraints - (UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi), + (UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi), # these are treated the same for now (UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None), @@ -406,7 +435,7 @@ isel_matcher = PatternMatcher([ a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))), # in this case we have a mask producing comparison whose user expects a bool, so we convert to bool (UPat(GroupOp.Comparison, dtypes.bool, (UPat.var("y", (dtypes.float32, dtypes.float64)), UPat()), name="x"), lambda y,x: - x.replace(dtype=y.dtype).bitcast(to_int(y.dtype)).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)), + UOp(Ops.AND, dt:=to_int(y.dtype), (x.replace(dtype=y.dtype).bitcast(dt), UOp.const(dt, 1))).f(Ops.NOOP, dtype=dtypes.bool)), # conditional moves that use flags (UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))), diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index c404861695..80f3027a4f 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -120,7 +120,6 @@ base_rewrite = PatternMatcher([ ]) class LLVMRenderer(Renderer): - new_style = True supports_float4 = True abi: str | None string_rewrite: PatternMatcher diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 2ff8162bef..9258670f20 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -132,7 +132,6 @@ string_rewrite = PatternMatcher([ ]) class PTXRenderer(Renderer): - new_style = True suffix = "PTX" global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]] diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index dc649377c8..2a0fca6be4 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -203,7 +203,6 @@ class PythonCompiler(Compiler): def compile(self, src:str) -> bytes: return base64.b64decode(src) class PythonRenderer(Renderer): - new_style = True code_for_op = python_alu compiler = PythonCompiler()