From d0d58a6771ed7cd4cf653b544debd16e2896ede7 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 11 Feb 2025 18:02:58 +0800
Subject: [PATCH] add CUSTOM support to cstyle (#9020)

---
 tinygrad/ops.py             | 3 ++-
 tinygrad/renderer/cstyle.py | 4 +++-
 tinygrad/runtime/ops_dsp.py | 9 +++++++++
 tinygrad/spec.py            | 2 +-
 4 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b96eb814f2..1c29d839e7 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -152,6 +152,7 @@ class Ops(FastEnum):
   # device
   DEVICE = auto()
   MULTI = auto()
+  CUSTOM = auto()
 
 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -593,7 +594,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
     if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
     return 1
-  def divides(self, v) -> UOp|None:
+  def divides(self, v:int) -> UOp|None:
     if v==1: return self
     if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
     if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 2a21e86e46..efd8dc8d7d 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -52,6 +52,8 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
     (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \
      f".{'xyzwabcd'[x.arg[0]]}")),
+  # custom passes through with format
+  (UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
 ])
 
 extra_pm = PatternMatcher([
@@ -148,7 +150,7 @@ class CStyleLanguage(Renderer):
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
 
       if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
-      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
         (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py
index 0b2a53f7db..0b00ec7232 100644
--- a/tinygrad/runtime/ops_dsp.py
+++ b/tinygrad/runtime/ops_dsp.py
@@ -9,11 +9,20 @@ from tinygrad.renderer.cstyle import ClangRenderer
 from tinygrad.runtime.autogen import libc, qcom_dsp
 if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
 
+from tinygrad.ops import PatternMatcher, UPat
+
+dsp_pm = PatternMatcher([
+  (UPat(Ops.VECTORIZE, src=UPat.var("y"))*UPat.var("x"), lambda x,y: UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")*x),
+  (UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
+    lambda d: d.replace(src=(UOp(Ops.CUSTOM, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]))
+])
+
 class DSPRenderer(ClangRenderer):
   device = "DSP"
   supports_float4 = True
   buffer_suffix = " restrict __attribute__((align_value(128)))"
   kernel_prefix = "__attribute__((noinline)) "
+  extra_matcher = dsp_pm+ClangRenderer.extra_matcher
   type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
   code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
                  Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
diff --git a/tinygrad/spec.py b/tinygrad/spec.py
index 41ce1341e2..567875d664 100644
--- a/tinygrad/spec.py
+++ b/tinygrad/spec.py
@@ -115,7 +115,7 @@ spec = PatternMatcher([
   # NOTE: for testing, we let sinks be anything
   #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
   (UPat(Ops.SINK, dtypes.void), lambda: True),
-  (UPat(Ops.NOOP), lambda: True),
+  (UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
 
   # PTX LOAD/STORE
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),