diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 8ee9eef396..b4c2f3c309 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -148,7 +148,7 @@ class NIRRenderer(Renderer): (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))), (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))), lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)), - (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"), + (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))), UPat.var("alt"), UPat.var("gate")), allow_any_len=True, name="x"), lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])), (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),