mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 17:40:13 +08:00
port x86 to new_style (fable slop) and now everything is new style (#16581)
* port x86 to new_style (fable slop) * don't change ops * port NIR to new_style (fable) * lil cleanup * fix tests, and remove new_style
This commit is contained in:
@@ -21,9 +21,8 @@ def apply_rewrite_values(expr):
|
||||
def evaluate_uop(uop, variables):
|
||||
if uop.op == Ops.CONST:
|
||||
return uop.arg
|
||||
elif uop.op == Ops.DEFINE_VAR:
|
||||
var_name = uop.arg[0]
|
||||
return variables[var_name]
|
||||
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is None):
|
||||
return variables[uop.expr]
|
||||
elif uop.op in GroupOp.ALU:
|
||||
src_values = [evaluate_uop(src, variables) for src in uop.src]
|
||||
return exec_alu(uop.op, uop.dtype, src_values)
|
||||
|
||||
@@ -406,7 +406,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
vc = v+c2
|
||||
out = vc+c4
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 4) # +1 for SINK
|
||||
self.assertEqual(len(uops), 5) # +1 for SINK, +1 for the PARAM shape STACK
|
||||
out = uops[-2] # -2 to skip SINK
|
||||
self.assertEqual(out.op, Ops.ADD)
|
||||
self.assertEqual(out.src[1].op, Ops.CONST)
|
||||
|
||||
@@ -955,7 +955,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0), ptr=True), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
# the vars are now scalar PARAMs
|
||||
pvar = {u.expr: u for u in rewritten_uop.toposort() if u.op is Ops.PARAM}
|
||||
self.assertEqual(rewritten_uop, (pvar['s']<2).where(pvar['a'].cast(dtypes.half), pvar['b'].cast(dtypes.half)))
|
||||
|
||||
def test_where_merge_branches(self):
|
||||
cond1 = Variable("s", 0, 10) < 6
|
||||
|
||||
@@ -126,11 +126,10 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
pm_final_rewrite = pm_decomp+pm_render+extra_matcher+pm_split_ends
|
||||
sink = graph_rewrite(sink, pm_final_rewrite, ctx=ren, name="final rewrite")
|
||||
|
||||
if ren.new_style:
|
||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
|
||||
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
|
||||
sink = graph_rewrite(sink, pm_index_is_shrink, name="index is shrink")
|
||||
num_params = len([x for x in sink.toposort() if x.op is Ops.PARAM])
|
||||
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
|
||||
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
|
||||
|
||||
# this was the linearizer
|
||||
sink = graph_rewrite(sink, pm_add_control_flow, ctx=CFGContext(sink), name="add control flow", bottom_up=True)
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.renderer.isa import ISARenderer, Register
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
|
||||
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP}
|
||||
PSEUDO_OPS = {Ops.CONST, Ops.NOOP, Ops.AFTER, Ops.BARRIER, Ops.GROUP, Ops.STACK}
|
||||
|
||||
class LinearScanRegallocContext:
|
||||
# returns the uop that defines the virtual register
|
||||
|
||||
@@ -68,7 +68,6 @@ class Renderer:
|
||||
pre_matcher: PatternMatcher|None = None
|
||||
extra_matcher: PatternMatcher|None = None
|
||||
code_for_op: dict[Ops, Callable] = {}
|
||||
new_style: bool = False
|
||||
|
||||
compiler: Compiler = Compiler()
|
||||
|
||||
|
||||
@@ -119,7 +119,6 @@ def wmma_args(uops:list[UOp]):
|
||||
return dedup((uop.arg[0], uop.arg[1], uop.arg[2], uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA)
|
||||
|
||||
class CStyleLanguage(Renderer):
|
||||
new_style = True
|
||||
kernel_typedef: str = "void"
|
||||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
|
||||
@@ -17,7 +17,7 @@ class IselContext:
|
||||
def __init__(self, sink:UOp):
|
||||
self.uses = consumer_map_from_toposort(sink.toposort())
|
||||
self.reg_n = itertools.count()
|
||||
arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2}
|
||||
arg_order = {Ops.PARAM: 0, Ops.SPECIAL: 1}
|
||||
self.func_args = sorted([u for u in self.uses if u.op in arg_order], key=lambda k: (arg_order[k.op], k.arg))
|
||||
|
||||
def vreg(self, cons:tuple[Register, ...]|Register):
|
||||
|
||||
@@ -171,6 +171,35 @@ extra_matcher = PatternMatcher([
|
||||
(UPat(Ops.CMOD, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x - y * x.alu(Ops.CDIV, y)),
|
||||
])
|
||||
|
||||
# ***** X86 new style -> x86 internal style (pointers, vec dtypes, GEP) *****
|
||||
|
||||
pm_x86_style = PatternMatcher([
|
||||
# buffers are pointers, scalar PARAMs (variables) keep their shape src
|
||||
(UPat(Ops.PARAM, name="x"), lambda x: x.replace(dtype=x.dtype.ptr(x.src[0].arg), src=()) \
|
||||
if x.arg.addrspace is AddrSpace.GLOBAL and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG if x.arg.addrspace == AddrSpace.REG else Ops.DEFINE_LOCAL,
|
||||
dtype=x.dtype.ptr(x.src[0].arg, x.arg.addrspace), src=(), arg=x.arg.slot)),
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(dtype=x.src[0].dtype) if x.dtype != x.src[0].dtype else None),
|
||||
# SHRINK is a vectorized INDEX
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("idx"), UPat.cvar("c"))), lambda buf,idx,c: buf.index(idx, ptr=True) \
|
||||
.cast(buf.ptrdtype.base.vec(c.arg).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace)) if isinstance(buf.dtype, PtrDType) else None),
|
||||
# cast of a pointer is a noop in new style (any reinterpreting cast was absorbed into SHRINK)
|
||||
(UPat(Ops.CAST, src=(UPat.var("y"),), name="x"), lambda x,y:
|
||||
y if isinstance(y.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
|
||||
# INDEX on a pointer has pointer dtype, INDEX on a register value is a GEP
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat()), name="x"), lambda buf,x:
|
||||
x.replace(dtype=buf.dtype) if isinstance(buf.dtype, PtrDType) and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("y"), UPat.cvar("c")), name="x"), lambda y,c,x:
|
||||
y.gep(c.arg) if not isinstance(y.dtype, PtrDType) and y.op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER} else None),
|
||||
# restore vec dtypes from structure
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.CAST, name="c"),), allow_any_len=True, name="x"), lambda x,c:
|
||||
x.replace(dtype=x.dtype.scalar().vec(c.ptrdtype.base.count)) if isinstance(c.dtype, PtrDType) and c.ptrdtype.base.count > x.dtype.count else None),
|
||||
(UPat(Ops.STACK, name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(len(x.src))) if 1 < len(x.src) != x.dtype.count else None),
|
||||
(UPat(GroupOp.ALU.union({Ops.CAST, Ops.BITCAST}), name="x"), lambda x: x.replace(dtype=x.dtype.scalar().vec(c)) \
|
||||
if not isinstance(x.dtype, PtrDType) and not any(isinstance(s.dtype, PtrDType) for s in x.src) \
|
||||
and (c:=max([s.dtype.count for s in x.src], default=1)) > x.dtype.count else None),
|
||||
])
|
||||
|
||||
# ***** X86 pre instruction selection *****
|
||||
|
||||
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
|
||||
@@ -185,7 +214,7 @@ def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
|
||||
return ptr.cast(cast.dtype).store(val)
|
||||
|
||||
# these must be done in a separate matcher because they violate the spec
|
||||
pre_isel_matcher = PatternMatcher([
|
||||
pre_isel_matcher = pm_x86_style + PatternMatcher([
|
||||
# zero extending scalar 32bit int is a noop
|
||||
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
|
||||
# cast between signed and unsigned int is a noop
|
||||
@@ -379,7 +408,7 @@ isel_matcher = PatternMatcher([
|
||||
x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in CALLEE_SAVED)),)) \
|
||||
if not x.src or x.src[0].arg is not X86Ops.RET else None),
|
||||
# function abi constraints
|
||||
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi),
|
||||
(UPat((Ops.PARAM, Ops.SPECIAL), name="x"), abi),
|
||||
# these are treated the same for now
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x:
|
||||
x.replace(op=Ops.DEFINE_LOCAL, dtype=x.dtype.base.ptr(x.dtype.size, AddrSpace.LOCAL)) if isinstance(x.arg, int) else None),
|
||||
@@ -406,7 +435,7 @@ isel_matcher = PatternMatcher([
|
||||
a.ins(X86Ops.VBLENDVPD, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
|
||||
# in this case we have a mask producing comparison whose user expects a bool, so we convert to bool
|
||||
(UPat(GroupOp.Comparison, dtypes.bool, (UPat.var("y", (dtypes.float32, dtypes.float64)), UPat()), name="x"), lambda y,x:
|
||||
x.replace(dtype=y.dtype).bitcast(to_int(y.dtype)).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)),
|
||||
UOp(Ops.AND, dt:=to_int(y.dtype), (x.replace(dtype=y.dtype).bitcast(dt), UOp.const(dt, 1))).f(Ops.NOOP, dtype=dtypes.bool)),
|
||||
# conditional moves that use flags
|
||||
(UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b:
|
||||
a.ins(X86Ops.CMOVL, src=(b, a, cmp(m)))),
|
||||
|
||||
@@ -120,7 +120,6 @@ base_rewrite = PatternMatcher([
|
||||
])
|
||||
|
||||
class LLVMRenderer(Renderer):
|
||||
new_style = True
|
||||
supports_float4 = True
|
||||
abi: str | None
|
||||
string_rewrite: PatternMatcher
|
||||
|
||||
@@ -132,7 +132,6 @@ string_rewrite = PatternMatcher([
|
||||
])
|
||||
|
||||
class PTXRenderer(Renderer):
|
||||
new_style = True
|
||||
suffix = "PTX"
|
||||
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
|
||||
tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]]
|
||||
|
||||
@@ -203,7 +203,6 @@ class PythonCompiler(Compiler):
|
||||
def compile(self, src:str) -> bytes: return base64.b64decode(src)
|
||||
|
||||
class PythonRenderer(Renderer):
|
||||
new_style = True
|
||||
code_for_op = python_alu
|
||||
compiler = PythonCompiler()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user