Merge branch 'master' into cstyle_new_style

This commit is contained in:
George Hotz
2026-06-07 12:25:51 -07:00
committed by GitHub
3 changed files with 4 additions and 6 deletions

View File

@@ -38,7 +38,7 @@ pm_index_is_shrink = PatternMatcher([
pm_remove_vec_dtypes = PatternMatcher([
# rewrite PARAM to non pointer
(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),))
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
# no LOADs on register dtypes
(UPat(Ops.LOAD, name="x"), lambda x: x.src[0] if x.src[0].addrspace == AddrSpace.REG else None),

View File

@@ -108,8 +108,8 @@ class WGSLRenderer(CStyleLanguage):
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u)}>' if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else self.buf_map(u)};" for name,(u,_) in bufs])
f"{'var<storage,read_write>' if u.addrspace == AddrSpace.GLOBAL else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if u.addrspace == AddrSpace.GLOBAL else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

View File

@@ -278,9 +278,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.GETADDR: return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
case Ops.BINARY: return (len(self.arg),)
case Ops.BUFFER:
if isinstance(self.arg, ParamArg): return self.src[0].as_shape
return (self.arg,)
case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,)
case Ops.SLICE:
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
if self.src[0].op is Ops.INDEX: return ()