add CUSTOM support to cstyle (#9020)

This commit is contained in:
George Hotz
2025-02-11 18:02:58 +08:00
committed by GitHub
parent fb698920f1
commit d0d58a6771
4 changed files with 15 additions and 3 deletions

View File

@@ -152,6 +152,7 @@ class Ops(FastEnum):
# device
DEVICE = auto()
MULTI = auto()
CUSTOM = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -593,7 +594,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v) -> UOp|None:
def divides(self, v:int) -> UOp|None:
if v==1: return self
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None

View File

@@ -52,6 +52,8 @@ base_rewrite = PatternMatcher([
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \
f".{'xyzwabcd'[x.arg[0]]}")),
# CUSTOM renders its arg as a str.format template, filled in with the rendered sources
(UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
])
extra_pm = PatternMatcher([
@@ -148,7 +150,7 @@ class CStyleLanguage(Renderer):
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
(u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
r[u] = l
else:

View File

@@ -9,11 +9,20 @@ from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
from tinygrad.ops import PatternMatcher, UPat
# DSP-specific UOp rewrites; DSPRenderer prepends these to ClangRenderer.extra_matcher.
dsp_pm = PatternMatcher([
# VECTORIZE(src matching y) * x  ->  CUSTOM(arg="{0}", src=(y,)) * x, with the CUSTOM typed as x.dtype.
# The cstyle CUSTOM rule formats arg with the rendered sources, so "{0}" emits the rendered y as-is.
# NOTE(review): presumably this targets a VECTORIZE whose lanes are all the same y (a broadcast) -- confirm UPat src semantics.
(UPat(Ops.VECTORIZE, src=UPat.var("y"))*UPat.var("x"), lambda x,y: UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")*x),
# uchar.vec(128) DEFINE_ACC whose first source is a VECTORIZE of CONST 0: swap that source for a CUSTOM
# whose arg is the literal call "__builtin_HEXAGON_V6_vd0_128B()"; the remaining sources (d.src[1:]) are kept.
# NOTE(review): the builtin presumably yields a zeroed 128-byte vector -- confirm against the Hexagon intrinsics docs.
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOM, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]))
])
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
extra_matcher = dsp_pm+ClangRenderer.extra_matcher
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",

View File

@@ -115,7 +115,7 @@ spec = PatternMatcher([
# NOTE: for testing, we let sinks be anything
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat(Ops.NOOP), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
# PTX LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),