port x86 to new_style (fable slop) and now everything is new style (#16581)

* port x86 to new_style (fable slop)

* don't change ops

* port NIR to new_style (fable)

* lil cleanup

* fix tests, and remove new_style
This commit is contained in:
George Hotz
2026-06-11 21:09:34 -07:00
committed by GitHub
parent 762f50bd52
commit b8aec4cce7
12 changed files with 44 additions and 20 deletions

View File

@@ -21,9 +21,8 @@ def apply_rewrite_values(expr):
def evaluate_uop(uop, variables):
if uop.op == Ops.CONST:
return uop.arg
elif uop.op == Ops.DEFINE_VAR:
var_name = uop.arg[0]
return variables[var_name]
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is None):
return variables[uop.expr]
elif uop.op in GroupOp.ALU:
src_values = [evaluate_uop(src, variables) for src in uop.src]
return exec_alu(uop.op, uop.dtype, src_values)

View File

@@ -406,7 +406,7 @@ class TestUOpGraph(unittest.TestCase):
vc = v+c2
out = vc+c4
uops = to_uops_list([out])
self.assertEqual(len(uops), 4) # +1 for SINK
self.assertEqual(len(uops), 5) # +1 for SINK, +1 for the PARAM shape STACK
out = uops[-2] # -2 to skip SINK
self.assertEqual(out.op, Ops.ADD)
self.assertEqual(out.src[1].op, Ops.CONST)

View File

@@ -955,7 +955,9 @@ class TestSymbolic(unittest.TestCase):
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink())
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
# the vars are now scalar PARAMs
pvar = {u.expr: u for u in rewritten_uop.toposort() if u.op is Ops.PARAM}
self.assertEqual(rewritten_uop, (pvar['s']<2).where(pvar['a'].cast(dtypes.half), pvar['b'].cast(dtypes.half)))
def test_where_merge_branches(self):
cond1 = Variable("s", 0, 10) < 6

View File

@@ -126,11 +126,10 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
if ren.new_style:
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
# this was the linearizer
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)

View File

@@ -4,7 +4,7 @@ from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
from tinygrad.renderer.isa import ISARenderer, Register
from tinygrad.dtype import dtypes, PtrDType
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP}
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK}
class LinearScanRegallocContext:
# returns the uop that defines the virtual register

View File

@@ -68,7 +68,6 @@ class Renderer:
pre_matcher: PatternMatcher|None = None
extra_matcher: PatternMatcher|None = None
code_for_op: dict[Ops, Callable] = {}
new_style: bool = False
compiler: Compiler = Compiler()

View File

@@ -119,7 +119,6 @@ def wmma_args(uops:list[UOp]):
return dedup((uop.arg[0], uop.arg[1], uop.arg[2], uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA)
class CStyleLanguage(Renderer):
new_style = True
kernel_typedef: str = "void"
buffer_prefix: str = ""
buffer_suffix: str = ""

View File

@@ -17,7 +17,7 @@ class IselContext:
def __init__(self, sink:UOp):
self.uses = consumer_map_from_toposort(sink.toposort())
self.reg_n = itertools.count()
arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2}
arg_order = {Ops.PARAM: 0, Ops.SPECIAL: 1}
self.func_args = sorted([u for u in self.uses if u.op in arg_order], key=lambda k: (arg_order[k.op], k.arg))
def vreg(self, cons:tuple[Register, ...]|Register):

View File

@@ -171,6 +171,35 @@ extra_matcher = PatternMatcher([
(UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)),
])
# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) *****
pm_x86_style = PatternMatcher([
# buffers are pointers, scalar PARAMs (variables) keep their shape src
(UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \
if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL,
dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)),
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None),
# SHRINK is a vectorized INDEX
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \
.cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None),
# cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK)
(UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y:
y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
# INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x:
x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x:
y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None),
# restore vec dtypes from structure
(UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c:
x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None),
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \
and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
])
# ***** X86 pre instruction selection *****
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
@@ -185,7 +214,7 @@ def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
return ptr.cast(cast.dtype).store(val)
# these must be done in a separate matcher because they violate the spec
pre_isel_matcher = PatternMatcher([
pre_isel_matcher = pm_x86_style + PatternMatcher([
# zero extending scalar 32bit int is a noop
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
# cast between signed and unsigned int is a noop
@@ -379,7 +408,7 @@ isel_matcher = PatternMatcher([
x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in CALLEE_SAVED)),)) \
if not x.src or x.src[0].arg is not X86Ops.RET else None),
# function abi constraints
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi),
(UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi),
# these are treated the same for now
(UPat(Ops.DEFINE_REG, name="x"), lambda x:
x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None),
@@ -406,7 +435,7 @@ isel_matcher = PatternMatcher([
a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
# in this case we have a mask producing comparison whose user expects a bool, so we convert to bool
(UPat(GroupOp.Comparison, dtypes.bool, (UPat.var("y", (dtypes.float32, dtypes.float64)), UPat()), name="x"), lambda y,x:
x.replace(dtype=y.dtype).bitcast(to_int(y.dtype)).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)),
UOp(Ops.AND, dt:=to_int(y.dtype), (x.replace(dtype=y.dtype).bitcast(dt), UOp.const(dt, 1))).f(Ops.NOOP, dtype=dtypes.bool)),
# conditional moves that use flags
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))),

View File

@@ -120,7 +120,6 @@ base_rewrite = PatternMatcher([
])
class LLVMRenderer(Renderer):
new_style = True
supports_float4 = True
abi: str | None
string_rewrite: PatternMatcher

View File

@@ -132,7 +132,6 @@ string_rewrite = PatternMatcher([
])
class PTXRenderer(Renderer):
new_style = True
suffix = "PTX"
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]]

View File

@@ -203,7 +203,6 @@ class PythonCompiler(Compiler):
def compile(self, src:str) -> bytes: return base64.b64decode(src)
class PythonRenderer(Renderer):
new_style = True
code_for_op = python_alu
compiler = PythonCompiler()