lil cleanup

This commit is contained in:
George Hotz
2026-06-11 17:44:22 -07:00
parent 32ad5e8b96
commit 785f09aac4

View File

@@ -84,7 +84,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
lambda b, space, addr, val: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize, num_components=lambda u:u.max_numel(),
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
**({"ALIGN_MUL":u.dtype.itemsize*u.max_numel()} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
@@ -149,8 +149,7 @@ class NIRRenderer(Renderer):
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val],
val.dtype)),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True), UPat.var("alt"),
UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
@@ -206,7 +205,7 @@ class NIRRenderer(Renderer):
self.b.shader.contents.info.shared_size += u.max_numel()*u.dtype.itemsize
elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype))
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
@@ -215,7 +214,7 @@ class NIRRenderer(Renderer):
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype),
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i),
mesa.nir_pop_loop(self.b, None)
else:
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
@@ -297,7 +296,7 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
def prerender(self, uops:list[UOp]):
super().prerender(uops)
self.texs:set[UOp] = set()
self.uops, self.ibo_idx, self.img_idx = uops, 0, 0
self.img_idx = 0
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
def postrender(self, uops:list[UOp]):