mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
lil cleanup
This commit is contained in:
@@ -84,7 +84,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
||||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
|
||||
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
lambda b, space, addr, val: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize, num_components=lambda u:u.max_numel(),
|
||||
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
|
||||
**({"ALIGN_MUL":u.dtype.itemsize*u.max_numel()} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
|
||||
@@ -149,8 +149,7 @@ class NIRRenderer(Renderer):
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val],
|
||||
val.dtype)),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True), UPat.var("alt"),
|
||||
UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
@@ -206,7 +205,7 @@ class NIRRenderer(Renderer):
|
||||
self.b.shader.contents.info.shared_size += u.max_numel()*u.dtype.itemsize
|
||||
elif u.op == Ops.RANGE:
|
||||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype))
|
||||
mesa.nir_push_loop(self.b)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
|
||||
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||
@@ -215,7 +214,7 @@ class NIRRenderer(Renderer):
|
||||
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
|
||||
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
|
||||
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype),
|
||||
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i),
|
||||
mesa.nir_pop_loop(self.b, None)
|
||||
else:
|
||||
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
|
||||
@@ -297,7 +296,7 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
||||
def prerender(self, uops:list[UOp]):
|
||||
super().prerender(uops)
|
||||
self.texs:set[UOp] = set()
|
||||
self.uops, self.ibo_idx, self.img_idx = uops, 0, 0
|
||||
self.img_idx = 0
|
||||
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||
|
||||
def postrender(self, uops:list[UOp]):
|
||||
|
||||
Reference in New Issue
Block a user