From 4e65ddfad5b0eeffae7b2193eab139108e6800d0 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 2 Jun 2026 18:37:41 -0700 Subject: [PATCH] cstyle new style --- tinygrad/renderer/cstyle.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index d3e5a66922..3f2286b187 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -8,7 +8,6 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc from tinygrad.renderer import Renderer from tinygrad.codegen.late.devectorizer import no_vectorized_alu - base_rewrite = PatternMatcher([ # defines (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.max_numel()}];"), @@ -47,8 +46,9 @@ base_rewrite = PatternMatcher([ # default const render (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), - # movement ops - (UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"), + # SHRINK/INDEX + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,**kwargs: ctx.render_index(**kwargs)), + (UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), lambda ctx,**kwargs: ctx.render_index(**kwargs)), (UPat(Ops.STACK, name="x"), lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \ f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"), @@ -106,13 +106,15 @@ pm_manual_bf16_cast = PatternMatcher([ (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16), ]) -def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType))) +def uops_to_dtypes(uops:list[UOp]) -> list[tuple[DType, int]]: + return dedup((u.dtype, u.max_numel()) for u in uops if u.addrspace is AddrSpace.ANON) # (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes) def wmma_args(uops:list[UOp]): return dedup((uop.arg[0], uop.arg[1], uop.arg[2], uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA) class CStyleLanguage(Renderer): + new_style = True kernel_typedef: str = "void" buffer_prefix: str = "" buffer_suffix: str = "" @@ -146,7 +148,7 @@ class CStyleLanguage(Renderer): tmp = "" if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs): tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" - buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else + buftypes = [(name, self.render_type(u, mutable)+self.buffer_suffix if u.addrspace == AddrSpace.GLOBAL else self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs] local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"] launch_bounds = prod([d.vmax for d in local_dims]) @@ -155,6 +157,26 @@ class CStyleLanguage(Renderer): [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" + def render_index(self, buf:UOp, idx:UOp): + if buf.addrspace == AddrSpace.ANON: + assert idx.op is Ops.CONST, f"{idx.op} must be CONST" + return f"{self[buf]}[{idx.arg}]" # TODO: is this syntax okay? + else: + return f"({self[buf]}+{strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]})" + + def render_type(self, u:UOp, mutable=True): + # TODO: get this from the shape + if isinstance(u.dtype, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t" + if u.addrspace in (AddrSpace.LOCAL, AddrSpace.GLOBAL): + prefix = "" + if u.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: prefix = self.smem_prefix + if u.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix + return prefix + self.type_map.get(scalar:=u.dtype.scalar(), scalar.name) + "*" + elif (sz:=u.max_numel()) > 1: + # pointers + return self.type_map.get(scalar:=u.dtype.scalar(), scalar.name).replace(" ", "_") + str(sz) + return self.type_map.get(scalar:=u.dtype.scalar(), scalar.name) + def render_cast(self, u:UOp, val:str) -> str: return f"({self.render_dtype(u.dtype)})({val})" def render_dtype(self, dt:DType, mutable=True) -> str: if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t" @@ -208,7 +230,7 @@ class CStyleLanguage(Renderer): assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}" if u.op in {Ops.ENDIF, Ops.END}: depth -= 1 - if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ + if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \ (u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \ (u.op is Ops.CAST and u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL)) or \ (u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): @@ -253,7 +275,7 @@ class ClangRenderer(CStyleLanguage): if sys.platform == 'win32': kernel_typedef = "__attribute__((ms_abi)) void" - def render_vector_prefix(self, dt:DType) -> str: + def render_vector_prefix(self, dt:DType, sz:int) -> str: # round (down) to power of two (this is actually the default clang behavior) alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1 return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"