diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 88101847dd..031fd419e8 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -176,7 +176,6 @@ class LLVMRenderer(Renderer): args.append((r[u], u.dtype)) elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG): r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}" - #assert isinstance(u.dtype, PtrDType) size = u.max_numel() if u.op is Ops.DEFINE_REG: kernel.append(f" {r[u]} = alloca [{size} x {ldt(u.dtype)}]") diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index fbc1ba6165..4aa1e6fa42 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -769,7 +769,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP}: return self.src[0].addrspace if self.op in GroupOp.Movement: return self.src[0].addrspace - if self.op is Ops.STACK or self.op in GroupOp.Elementwise: + if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise: ad = [x.addrspace for x in self.src if x.addrspace is not None] if not len(ad) or not all_same(ad): return None return ad[0] @@ -1116,7 +1116,7 @@ class ProgramInfo: if u.op is Ops.DEFINE_VAR: _vars.append(u) if u.op is Ops.PARAM: _globals.append(u.arg.slot) if u.op in (Ops.STORE, Ops.LOAD): - if (idx:=u.src[0]).op is Ops.INDEX or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX): + if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK): if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot) if u.op is Ops.SPECIAL: if u.arg[0] == 'i': local_size = None diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 4e4f7440fc..7124fb9e41 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -81,8 +81,6 @@ spec_shared = PatternMatcher([ (UPat(Ops.GROUP, dtypes.void, src=UPat((Ops.GROUP, Ops.STORE, Ops.NOOP, Ops.UNROLL, Ops.INS))), lambda: True), # TOOD: these should be buffer with different addrspace - #(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL), - #(UPat(Ops.DEFINE_REG, src=()), lambda: True), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), lambda: True), # AFTER on Movement Op, PARAM, BUFFER, CONTIGUOUS, or another AFTER