new style cleanups (#16584)

* spec tighten

* revert

* lin fix

* lin fix

* needed for x86

* revert
This commit is contained in:
George Hotz
2026-06-12 08:10:38 -07:00
committed by GitHub
parent 76c10cd635
commit 51100d2c5c
5 changed files with 12 additions and 11 deletions

View File

@@ -39,7 +39,7 @@ class Estimates:
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif (u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.arg.addrspace is None)) and u.expr == 'core_id': mults *= int(u.vmax) + 1
elif u.op is Ops.PARAM and u.arg.addrspace is None and u.expr == 'core_id': mults *= int(u.vmax) + 1
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:

View File

@@ -36,7 +36,8 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
# ** scan sink for metadata
sink, n_bufs, n_vars, lds_size, gids = prg.src[0], 0, 0, 0, set()
for u in sink.toposort():
if u.op is Ops.PARAM: n_bufs += 1
if u.op is Ops.PARAM and u.addrspace is not None: n_bufs += 1
elif u.op is Ops.PARAM and u.addrspace is None: n_vars += 1
elif u.op is Ops.DEFINE_VAR: n_vars += 1
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))

View File

@@ -230,22 +230,21 @@ class CStyleLanguage(Renderer):
if u.op is Ops.SPECIAL: r[u] = u.arg
elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
else:
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", Ops.BUFFER: "buf",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
prefix = {Ops.WMMA: "wmma", Ops.CONST: "const", Ops.BUFFER: "buf", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.STACK: "cast",
Ops.INDEX: "bidx", Ops.LOAD: "val"}.get(u.op, "alu")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l
else:
if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG, Ops.BUFFER} and u.dtype != dtypes.void:
if u.op not in {Ops.RANGE, Ops.STORE, Ops.BUFFER} and u.dtype != dtypes.void:
l = f"{self.render_type(u)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment

View File

@@ -115,7 +115,6 @@ def nidx(b:mesa.nir_builder, buf, off, space, itemsize, gate=None) -> mesa.nir_d
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
class NIRRenderer(Renderer):
new_style = True
suffix = "NIR"
nir_options: bytes
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max

View File

@@ -80,7 +80,7 @@ spec_shared = PatternMatcher([
# TODO: remove UNROLL here, it's for SPEC=2
(UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True),
# TOOD: these should be buffer with different addrspace
# TODO: these should be buffer with different addrspace everywhere.
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True),
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
@@ -192,11 +192,14 @@ spec_tensor = PatternMatcher([
# these ops can exist in programs but not the tensor spec. example: LOAD
spec_program = PatternMatcher([
# no more of these in programs
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
# allow special SHRINK
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
# movement ops are not allowed in programs
(UPat(GroupOp.Movement), lambda: False),
@@ -213,7 +216,6 @@ spec_program = PatternMatcher([
# STACK/GEP in program. TODO: this should match Tensor
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 or len(x.src) == 0),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
# if has a <gate, index_for_dedup>
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),