mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
new style cleanups (#16584)
* spec tighten * revert * lin fix * lin fix * needed for x86 * revert
This commit is contained in:
@@ -39,7 +39,7 @@ class Estimates:
|
||||
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
||||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif (u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.arg.addrspace is None)) and u.expr == 'core_id': mults *= int(u.vmax) + 1
|
||||
elif u.op is Ops.PARAM and u.arg.addrspace is None and u.expr == 'core_id': mults *= int(u.vmax) + 1
|
||||
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
|
||||
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
|
||||
|
||||
@@ -36,7 +36,8 @@ def assemble_linear(prg:UOp, lin:UOp, arch:str) -> bytes:
|
||||
# ** scan sink for metadata
|
||||
sink, n_bufs, n_vars, lds_size, gids = prg.src[0], 0, 0, 0, set()
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.PARAM: n_bufs += 1
|
||||
if u.op is Ops.PARAM and u.addrspace is not None: n_bufs += 1
|
||||
elif u.op is Ops.PARAM and u.addrspace is None: n_vars += 1
|
||||
elif u.op is Ops.DEFINE_VAR: n_vars += 1
|
||||
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
|
||||
|
||||
@@ -230,22 +230,21 @@ class CStyleLanguage(Renderer):
|
||||
if u.op is Ops.SPECIAL: r[u] = u.arg
|
||||
elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
|
||||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const", Ops.BUFFER: "buf",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
|
||||
Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
prefix = {Ops.WMMA: "wmma", Ops.CONST: "const", Ops.BUFFER: "buf", Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.STACK: "cast",
|
||||
Ops.INDEX: "bidx", Ops.LOAD: "val"}.get(u.op, "alu")
|
||||
r[u] = f"{prefix}{c[prefix]}"
|
||||
|
||||
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \
|
||||
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
|
||||
(u.op is Ops.CAST and u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL)) or \
|
||||
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
r[u] = l
|
||||
else:
|
||||
if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG, Ops.BUFFER} and u.dtype != dtypes.void:
|
||||
if u.op not in {Ops.RANGE, Ops.STORE, Ops.BUFFER} and u.dtype != dtypes.void:
|
||||
l = f"{self.render_type(u)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
|
||||
kernel.append(" "*depth + l)
|
||||
if prefix: c[prefix] += 1 # if it was used, increment
|
||||
|
||||
@@ -115,7 +115,6 @@ def nidx(b:mesa.nir_builder, buf, off, space, itemsize, gate=None) -> mesa.nir_d
|
||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||
|
||||
class NIRRenderer(Renderer):
|
||||
new_style = True
|
||||
suffix = "NIR"
|
||||
nir_options: bytes
|
||||
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
|
||||
|
||||
@@ -80,7 +80,7 @@ spec_shared = PatternMatcher([
|
||||
# TODO: remove UNROLL here, it's for SPEC=2
|
||||
(UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True),
|
||||
|
||||
# TOOD: these should be buffer with different addrspace
|
||||
# TOOD: these should be buffer with different addrspace everywhere.
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True),
|
||||
|
||||
# AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER
|
||||
@@ -192,11 +192,14 @@ spec_tensor = PatternMatcher([
|
||||
|
||||
# these ops can exist in programs but not the tensor spec. example: LOAD
|
||||
spec_program = PatternMatcher([
|
||||
# no more of these in programs
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
|
||||
|
||||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
||||
# allow special SHRINK
|
||||
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
|
||||
(UPat(Ops.SHRINK, src=(UPat((Ops.PARAM, Ops.BUFFER, Ops.AFTER)), UPat(), UPat(Ops.CONST))), lambda: True),
|
||||
|
||||
# movement ops are not allowed in programs
|
||||
(UPat(GroupOp.Movement), lambda: False),
|
||||
@@ -213,7 +216,6 @@ spec_program = PatternMatcher([
|
||||
|
||||
# STACK/GEP in program. TODO: this should match Tensor
|
||||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 or len(x.src) == 0),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
||||
# if has a <gate, index_for_dedup>
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(dtype=dtypes.bool), UPat((Ops.CAST, Ops.INDEX, Ops.SHRINK)))), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user