diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 7cb90bbb21..1f5c8d0d90 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -38,7 +38,7 @@ pm_index_is_shrink = PatternMatcher([ pm_remove_vec_dtypes = PatternMatcher([ # rewrite PARAM to non pointer (UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf: - buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) + buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \ if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None), # no LOADs on register dtypes (UPat(Ops.LOAD, name="x"), lambda x: x.src[0] if x.src[0].addrspace == AddrSpace.REG else None), diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index 9b44e72a5a..c048068c61 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -108,8 +108,8 @@ class WGSLRenderer(CStyleLanguage): prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast(bits); }\n" prg += "@group(0) @binding(0)\nvar INFINITY : f32;\n" prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" + - f"{'var' if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else 'var'}" + - f"{name}:{f'array<{self.buf_map(u)}>' if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else self.buf_map(u)};" for name,(u,_) in bufs]) + f"{'var' if u.addrspace == AddrSpace.GLOBAL else 'var'}" + + f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if u.addrspace == AddrSpace.GLOBAL else self.buf_map(u.dtype)};" for name,(u,_) in bufs]) prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3," return prg + "@builtin(local_invocation_id) lindex: vec3) {\n" + "\n".join(kernel) + "\n}" diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index cacba59c60..b686180511 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -278,9 +278,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.GETADDR: return () case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return () case Ops.BINARY: return (len(self.arg),) - case Ops.BUFFER: - if isinstance(self.arg, ParamArg): return self.src[0].as_shape - return (self.arg,) + case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,) case Ops.SLICE: # HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX if self.src[0].op is Ops.INDEX: return ()