mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
cstyle new style
This commit is contained in:
@@ -8,7 +8,6 @@ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, trunc
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
|
||||
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
# defines
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.max_numel()}];"),
|
||||
@@ -47,8 +46,9 @@ base_rewrite = PatternMatcher([
|
||||
# default const render
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
|
||||
|
||||
# movement ops
|
||||
(UPat.var("buf").index(UPat.var('idx')), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
|
||||
# SHRINK/INDEX
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,**kwargs: ctx.render_index(**kwargs)),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var('idx'), UPat.cvar())), lambda ctx,**kwargs: ctx.render_index(**kwargs)),
|
||||
(UPat(Ops.STACK, name="x"),
|
||||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
|
||||
@@ -106,13 +106,15 @@ pm_manual_bf16_cast = PatternMatcher([
|
||||
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16),
|
||||
])
|
||||
|
||||
# Returns dedup() of the dtype of every uop whose dtype is neither an ImageDType nor a PtrDType.
# NOTE(review): a second definition of uops_to_dtypes (returning (dtype, max_numel) pairs) follows
# directly after this one; if both lines are kept, the later definition shadows this one.
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
|
||||
def uops_to_dtypes(uops:list[UOp]) -> list[tuple[DType, int]]:
  """Return the dedup()'d (dtype, max_numel) pairs of the uops whose addrspace is AddrSpace.ANON."""
  anon_types = []
  for uop in uops:
    # only uops in the ANON address space contribute a pair
    if uop.addrspace is not AddrSpace.ANON: continue
    anon_types.append((uop.dtype, uop.max_numel()))
  return dedup(anon_types)
|
||||
|
||||
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
|
||||
def wmma_args(uops:list[UOp]):
  """Return the dedup()'d argument tuples of every WMMA uop in `uops`.

  Each tuple is (arg[0], arg[1], arg[2], scalar of the uop dtype, arg[4], ..., arg[7]);
  arg[3] is not taken from the uop's arg, the uop's own scalar dtype is used in that slot.
  """
  def _signature(wmma):
    a = wmma.arg
    return (a[0], a[1], a[2], wmma.dtype.scalar(), *a[4:8])
  return dedup(_signature(u) for u in uops if u.op is Ops.WMMA)
|
||||
|
||||
class CStyleLanguage(Renderer):
|
||||
new_style = True
|
||||
kernel_typedef: str = "void"
|
||||
buffer_prefix: str = ""
|
||||
buffer_suffix: str = ""
|
||||
@@ -146,7 +148,7 @@ class CStyleLanguage(Renderer):
|
||||
tmp = ""
|
||||
if any(isinstance(u.dtype, ImageDType) for _,(u,_) in bufs):
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
|
||||
buftypes = [(name, self.render_dtype(u.dtype, mutable)+self.buffer_suffix if isinstance(u.dtype, (ImageDType, PtrDType)) else
|
||||
buftypes = [(name, self.render_type(u, mutable)+self.buffer_suffix if u.addrspace == AddrSpace.GLOBAL else
|
||||
self.arg_int_prefix if u.dtype == dtypes.int else None) for name,(u,mutable) in bufs]
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
launch_bounds = prod([d.vmax for d in local_dims])
|
||||
@@ -155,6 +157,26 @@ class CStyleLanguage(Renderer):
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||
|
||||
def render_index(self, buf:UOp, idx:UOp):
  """Render `buf` indexed by `idx` as a source string.

  For a buffer in AddrSpace.ANON the index must be a CONST uop and the result is
  `name[const]`; for every other address space the result is `(name+index)`.
  """
  if buf.addrspace == AddrSpace.ANON:
    assert idx.op is Ops.CONST, f"{idx.op} must be CONST"
    return f"{self[buf]}[{idx.arg}]" # TODO: is this syntax okay?
  base = self[buf]
  # NOTE(review): this compares idx.arg (not idx.op) against Ops.ADD -- confirm that is intended,
  # the outer parens of the rendered index are only stripped when this comparison is true
  offset = strip_parens(self[idx]) if idx.arg == Ops.ADD else self[idx]
  return f"({base}+{offset})"
|
||||
|
||||
def render_type(self, u:UOp, mutable=True):
  """Render the type string for uop `u`.

  - ImageDType: `write_only image2d_t` when `mutable`, else `read_only image2d_t`.
  - LOCAL / GLOBAL addrspace: an optional prefix, the scalar type name, then `*`.
  - otherwise, when max_numel() > 1: the scalar type name with spaces turned into
    underscores, followed by the element count.
  - otherwise: the plain scalar type name.
  """
  # TODO: get this from the shape
  if isinstance(u.dtype, ImageDType):
    access = 'write_only' if mutable else 'read_only'
    return f"{access} image2d_t"

  def scalar_name():
    # type_map overrides the scalar dtype's own name when it has an entry for it
    base = u.dtype.scalar()
    return self.type_map.get(base, base.name)

  if u.addrspace in (AddrSpace.LOCAL, AddrSpace.GLOBAL):
    # pointer types: GLOBAL takes buffer_prefix, LOCAL takes smem_prefix only if smem_prefix_for_cast is set
    if u.addrspace == AddrSpace.GLOBAL: qualifier = self.buffer_prefix
    elif u.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: qualifier = self.smem_prefix
    else: qualifier = ""
    return qualifier + scalar_name() + "*"

  count = u.max_numel()
  if count > 1:
    # multi-element types are named <scalar><count>, e.g. a two word scalar name gets joined by "_"
    return scalar_name().replace(" ", "_") + str(count)
  return scalar_name()
|
||||
|
||||
def render_cast(self, u:UOp, val:str) -> str:
  """Render a C-style cast of the expression string `val` to the rendered dtype of `u`."""
  target = self.render_dtype(u.dtype)
  return f"({target})({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
|
||||
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
|
||||
@@ -208,7 +230,7 @@ class CStyleLanguage(Renderer):
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
|
||||
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.SHRINK, Ops.CUSTOMI} or \
|
||||
(u.op is Ops.LOAD and u.src[0].addrspace == AddrSpace.REG) or \
|
||||
(u.op is Ops.CAST and u.addrspace in (AddrSpace.GLOBAL, AddrSpace.LOCAL)) or \
|
||||
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
|
||||
@@ -253,7 +275,7 @@ class ClangRenderer(CStyleLanguage):
|
||||
|
||||
if sys.platform == 'win32':
|
||||
kernel_typedef = "__attribute__((ms_abi)) void"
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
def render_vector_prefix(self, dt:DType, sz:int) -> str:
  """Render the `typedef` line that declares the vector type `dt` via ext_vector_type.

  `sz` is accepted as part of the signature but is not read by this body.
  """
  if getenv("ALIGNED", 1) and not dtypes.is_bool(dt):
    # round (down) to power of two (this is actually the default clang behavior)
    align = 2**int(math.log2(dt.itemsize))
  else:
    # ALIGNED disabled, or a bool dtype: alignment of 1
    align = 1
  elem_t = self.render_dtype(dt.scalar())
  vec_t = self.render_dtype(dt)
  return f"typedef {elem_t} {vec_t} __attribute__((aligned({align}),ext_vector_type({dt.count})));"
|
||||
|
||||
Reference in New Issue
Block a user