mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
Merge branch 'master' into cstyle_new_style
This commit is contained in:
@@ -38,7 +38,7 @@ pm_index_is_shrink = PatternMatcher([
|
||||
pm_remove_vec_dtypes = PatternMatcher([
|
||||
# rewrite PARAM to non pointer
|
||||
(UPat((Ops.PARAM, Ops.BUFFER, Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), lambda buf:
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),))
|
||||
buf.replace(dtype=buf.dtype.base, src=(UOp.const(dtypes.int, buf.ptrdtype.size),)) \
|
||||
if isinstance(buf.dtype, PtrDType) and not isinstance(buf.dtype, ImageDType) else None),
|
||||
# no LOADs on register dtypes
|
||||
(UPat(Ops.LOAD, name="x"), lambda x: x.src[0] if x.src[0].addrspace == AddrSpace.REG else None),
|
||||
|
||||
@@ -108,8 +108,8 @@ class WGSLRenderer(CStyleLanguage):
|
||||
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u)}>' if u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL) else self.buf_map(u)};" for name,(u,_) in bufs])
|
||||
f"{'var<storage,read_write>' if u.addrspace == AddrSpace.GLOBAL else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if u.addrspace == AddrSpace.GLOBAL else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
||||
@@ -278,9 +278,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.GETADDR: return ()
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
|
||||
case Ops.BINARY: return (len(self.arg),)
|
||||
case Ops.BUFFER:
|
||||
if isinstance(self.arg, ParamArg): return self.src[0].as_shape
|
||||
return (self.arg,)
|
||||
case Ops.BUFFER: return self.src[0].as_shape if isinstance(self.arg, ParamArg) else (self.arg,)
|
||||
case Ops.SLICE:
|
||||
# HACK: SLICE is used inside kernels, so we set the shape to () if it's on an INDEX
|
||||
if self.src[0].op is Ops.INDEX: return ()
|
||||
|
||||
Reference in New Issue
Block a user