multi custom kernel support (#13716)

* multi custom kernel support

* custom kernel xform

* works

* no SPEC=2 on ck

* panic

* touchups
This commit is contained in:
George Hotz
2025-12-16 11:36:30 -04:00
committed by GitHub
parent 5778722979
commit e5a66ace80
9 changed files with 48 additions and 15 deletions

View File

@@ -310,7 +310,7 @@ jobs:
deps: testing_unit
python-version: '3.14'
- name: Test SPEC=2
run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
fuzzing:
name: Fuzzing

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, UOp, Context
from tinygrad import Tensor, UOp
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import KernelInfo, AxisType
@@ -117,6 +117,17 @@ class TestCustomKernel(unittest.TestCase):
out = c.flatten().tolist()
assert all(x == 2 for x in out), "all 2"
def test_simple_sharded(self):
  """Check the elementwise-add custom kernel on inputs sharded across two CPU devices.

  Both inputs are 16x16 tensors of ones sharded on axis 0, so every element
  of the flattened result is expected to equal 2.
  """
  shard_devices = ("CPU:0", "CPU:1")
  lhs = Tensor.ones(16, 16).contiguous().shard(shard_devices, axis=0)
  rhs = Tensor.ones(16, 16).contiguous().shard(shard_devices, axis=0)
  # ugly construction to get a sharded empty tensor
  # NOTE(review): presumably (8, 16) is the per-device piece of the 16x16 output -- confirm
  empty_piece = Tensor.empty(8, 16, device=shard_devices)
  dest = Tensor(empty_piece.uop.multi(0), device=shard_devices)
  # custom_kernel returns one tensor per source; the first corresponds to the output buffer passed in
  summed = Tensor.custom_kernel(dest, lhs, rhs, fxn=custom_elementwise_add_kernel)[0]
  values = summed.flatten().tolist()
  assert all(v == 2 for v in values), "all 2"
def test_multioutput(self):
a = Tensor.full((16, 16), 3.).contiguous()
b = Tensor.full((16, 16), 3.).contiguous()
@@ -184,7 +195,6 @@ class TestCustomKernel(unittest.TestCase):
def test_gemm_backward_custom(self): self.test_gemm_backward(True)
# NOTE: grad_fxn doesn't work with pyrender
@Context(SPEC=1)
def test_gemm_backward(self, custom_backward_gemm=False):
N = 4
a_rand = Tensor.randn(N, 8)

View File

@@ -42,7 +42,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)),
])

View File

@@ -218,6 +218,11 @@ multi_pm = PatternMatcher([
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
(UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"),
lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))),
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
])+replace_allreduce
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:

View File

@@ -63,10 +63,17 @@ mop_cleanup = PatternMatcher([
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
])
def resolve_custom_kernel(ck:UOp) -> UOp:
  """Turn a CUSTOM_KERNEL UOp into a KERNEL UOp.

  The user function stored in `ck.arg.fxn` is called with one placeholder per
  source of `ck`; whatever it returns is wrapped in a `Kernel` and becomes the
  arg of the new KERNEL UOp, which keeps the same sources as `ck`.
  """
  # build one placeholder per source, using the source's position as its slot
  stand_ins = []
  for slot, source in enumerate(ck.src):
    stand_ins.append(UOp.placeholder_like(source, slot=slot))
  kernel_ast = ck.arg.fxn(*stand_ins)
  return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(kernel_ast))
earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve custom kernels
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
# remove CONTIGUOUS if the BUFFER is already contiguous
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),

View File

@@ -75,6 +75,7 @@ class Ops(FastEnum):
# tensor graph ops
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto()
CUSTOM_KERNEL = auto()
# local unique
LUNIQUE = auto()

View File

@@ -7,7 +7,7 @@ from tinygrad.uop import Ops, GroupOp
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
from tinygrad.helpers import strip_parens, colored, ansilen, printable
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
if TYPE_CHECKING:
from tinygrad.device import Buffer, MultiBuffer
@@ -218,7 +218,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT:
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
return None
case Ops.INDEX:
@@ -470,7 +470,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def is_contiguous(self):
# TODO: this is is_realized
if self.op is Ops.RESHAPE: return self.src[0].is_contiguous()
if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].is_contiguous()
return self.op is Ops.BUFFER
def contiguous(self, *args, **kwargs):
@@ -840,9 +840,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0].after(self.store(val).end(*argfix(end)))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)]
contig_srcs = tuple(x.contiguous() for x in srcs)
kernel = UOp(Ops.KERNEL, src=tuple(x.base for x in contig_srcs), arg=Kernel(fxn(*placeholders), grad_fxn=grad_fxn))
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
return [s.after(kernel) for s in contig_srcs]
@dataclass(frozen=True)
@@ -855,6 +854,14 @@ class KernelInfo:
@property
def function_name(self): return to_function_name(self.name)
@dataclass(frozen=True)
class CustomKernel:
fxn: Callable
grad_fxn: Callable|None = None
# sadly CustomKernel can't be pickled or reconstructed as a str
def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return "CustomKernel(panic)"
@dataclass(frozen=True)
class Kernel:
ast: UOp

View File

@@ -1,8 +1,8 @@
import math
from typing import cast, Any
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic
from tinygrad.uop.validate import validate_index
# four specs:
@@ -54,7 +54,10 @@ movement_ops = PatternMatcher([
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
# AFTER on Movement Op
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True),
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI})),), allow_any_len=True), lambda: True),
# custom kernels allowed here
(UPat(Ops.CUSTOM_KERNEL), lambda: True),
])
_tensor_spec = PatternMatcher([
@@ -274,8 +277,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.schedule.rangeify import BufferizeOpts
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata,
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace}
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel,
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic}
def eval_pyrender(code:str) -> UOp:
lcls:dict[str, Any] = {}
exec(code, glbls, lcls)

View File

@@ -16,7 +16,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",