mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
add CUSTOM support to cstyle (#9020)
This commit is contained in:
@@ -152,6 +152,7 @@ class Ops(FastEnum):
|
||||
# device
|
||||
DEVICE = auto()
|
||||
MULTI = auto()
|
||||
CUSTOM = auto()
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
||||
@@ -593,7 +594,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
||||
return 1
|
||||
def divides(self, v) -> UOp|None:
|
||||
def divides(self, v:int) -> UOp|None:
|
||||
if v==1: return self
|
||||
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
||||
|
||||
@@ -52,6 +52,8 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \
|
||||
f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
])
|
||||
|
||||
extra_pm = PatternMatcher([
|
||||
@@ -148,7 +150,7 @@ class CStyleLanguage(Renderer):
|
||||
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
|
||||
|
||||
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
|
||||
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
|
||||
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
|
||||
(u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
|
||||
r[u] = l
|
||||
else:
|
||||
|
||||
@@ -9,11 +9,20 @@ from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.runtime.autogen import libc, qcom_dsp
|
||||
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
dsp_pm = PatternMatcher([
|
||||
(UPat(Ops.VECTORIZE, src=UPat.var("y"))*UPat.var("x"), lambda x,y: UOp(Ops.CUSTOM, x.dtype, (y,), arg="{0}")*x),
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
|
||||
lambda d: d.replace(src=(UOp(Ops.CUSTOM, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:]))
|
||||
])
|
||||
|
||||
class DSPRenderer(ClangRenderer):
|
||||
device = "DSP"
|
||||
supports_float4 = True
|
||||
buffer_suffix = " restrict __attribute__((align_value(128)))"
|
||||
kernel_prefix = "__attribute__((noinline)) "
|
||||
extra_matcher = dsp_pm+ClangRenderer.extra_matcher
|
||||
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
||||
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
|
||||
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
|
||||
|
||||
@@ -115,7 +115,7 @@ spec = PatternMatcher([
|
||||
# NOTE: for testing, we let sinks be anything
|
||||
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
|
||||
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
||||
(UPat(Ops.NOOP), lambda: True),
|
||||
(UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
|
||||
|
||||
# PTX LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user