cstyle new style

This commit is contained in:
George Hotz
2026-06-02 18:37:41 -07:00
parent 7dcfd144b6
commit 4e65ddfad5

View File

@@ -8,7 +8,6 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
from tinygrad.renderer import Renderer
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([
# defines
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.max_numel()}];"),
@@ -47,8 +46,9 @@ base_rewrite = PatternMatcher([
# default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# movement ops
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
# SHRINK/INDEX
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,**kwargs: ctx.render_index(**kwargs)),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), lambda ctx,**kwargs: ctx.render_index(**kwargs)),
(UPat(Ops.STACK, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
@@ -106,13 +106,15 @@ pm_manual_bf16_cast = PatternMatcher([
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16),
])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]:
  """Return the dtypes of `uops`, passed through `dedup`, leaving out any ImageDType or PtrDType."""
  excluded = (ImageDType, PtrDType)
  return dedup(u.dtype for u in uops if not isinstance(u.dtype, excluded))
def uops_to_dtypes(uops:list[UOp]) -> list[tuple[DType, int]]:
  """Return (dtype, max_numel) pairs for the uops whose addrspace is AddrSpace.ANON, passed through `dedup`."""
  anon_uops = (u for u in uops if u.addrspace is AddrSpace.ANON)
  return dedup((u.dtype, u.max_numel()) for u in anon_uops)
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops:list[UOp]):
  """Collect the argument tuples of every uop in `uops` whose op is Ops.WMMA, passed through `dedup`.

  Each tuple is (arg[0], arg[1], arg[2], dtype.scalar(), *arg[4:8]): arg[3] is not read, the scalar of the
  uop's own dtype takes its position.
  """
  wmma_uops = (u for u in uops if u.op is Ops.WMMA)
  return dedup((u.arg[0], u.arg[1], u.arg[2], u.dtype.scalar(), *u.arg[4:8]) for u in wmma_uops)
class CStyleLanguage(Renderer):
new_style = True
kernel_typedef: str = "void"
buffer_prefix: str = ""
buffer_suffix: str = ""
@@ -146,7 +148,7 @@ class CStyleLanguage(Renderer):
tmp = ""
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
buftypes = [(name, self.render_type(u, mutable)+self.buffer_suffix if u.addrspace == AddrSpace.GLOBAL else
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
launch_bounds = prod([d.vmax for d in local_dims])
@@ -155,6 +157,26 @@ class CStyleLanguage(Renderer):
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
def render_index(self, buf:UOp, idx:UOp):
  """Render indexing `buf` by `idx` as a string.

  When buf.addrspace equals AddrSpace.ANON, `idx` must be a CONST and the result is `<buf>[<idx.arg>]`.
  Otherwise the result is the parenthesized sum `(<buf>+<idx>)`, built from self[buf] and self[idx].
  """
  if buf.addrspace == AddrSpace.ANON:
    assert idx.op is Ops.CONST, f"{idx.op} must be CONST"
    target = self[buf]
    return f"{target}[{idx.arg}]" # TODO: is this syntax okay?
  base = self[buf]
  # NOTE(review): this compares idx.arg (not idx.op) against Ops.ADD, while the assert above tests idx.op.
  # Kept as-is to preserve the rendered output; confirm which field is intended.
  offset = strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]
  return f"({base}+{offset})"
def render_type(self, u:UOp, mutable=True):
  """Render the type string for `u`.

  - ImageDType: `write_only image2d_t` if `mutable` else `read_only image2d_t`.
  - addrspace LOCAL or GLOBAL: `<prefix><scalar>*`, where the prefix is self.buffer_prefix for GLOBAL,
    self.smem_prefix for LOCAL when self.smem_prefix_for_cast is truthy, and empty otherwise.
  - otherwise, when u.max_numel() > 1: the scalar name with spaces replaced by underscores, followed by the count.
  - otherwise: the scalar name.

  The scalar name is self.type_map's entry for u.dtype.scalar(), falling back to that dtype's `.name`.
  """
  # TODO: get this from the shape
  if isinstance(u.dtype, ImageDType):
    access = 'write_only' if mutable else 'read_only'
    return f"{access} image2d_t"
  scalar = u.dtype.scalar()
  scalar_name = self.type_map.get(scalar, scalar.name)
  if u.addrspace in (AddrSpace.LOCAL, AddrSpace.GLOBAL):
    # pointers
    if u.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix
    elif self.smem_prefix_for_cast: prefix = self.smem_prefix
    else: prefix = ""
    return prefix + scalar_name + "*"
  sz = u.max_numel()
  if sz > 1:
    # multi-element value: <scalar><count>, with spaces in the scalar name turned into underscores
    return scalar_name.replace(" ", "_") + str(sz)
  return scalar_name
def render_cast(self, u:UOp, val:str) -> str:
  """Render a cast of the expression `val` to the rendered dtype of `u`, as `(<type>)(<val>)`."""
  target_type = self.render_dtype(u.dtype)
  return f"({target_type})({val})"
def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
@@ -208,7 +230,7 @@ class CStyleLanguage(Renderer):
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
@@ -253,7 +275,7 @@ class ClangRenderer(CStyleLanguage):
if sys.platform == 'win32':
kernel_typedef = "__attribute__((ms_abi)) void"
def render_vector_prefix(self, dt:DType) -> str:
def render_vector_prefix(self, dt:DType, sz:int) -> str:
  """Render the typedef line that declares the vector type for `dt`.

  The typedef aliases the rendered scalar type to the rendered `dt` name and carries an `aligned(...)` and an
  `ext_vector_type(dt.count)` attribute. `sz` is not read by the lines visible here.
  """
  # round (down) to power of two (this is actually the default clang behavior)
  if getenv("ALIGNED", 1) and not dtypes.is_bool(dt): alignment = 2**int(math.log2(dt.itemsize))
  else: alignment = 1
  scalar_name = self.render_dtype(dt.scalar())
  vector_name = self.render_dtype(dt)
  return f"typedef {scalar_name} {vector_name} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"