From e5a66ace80cf8a365243eb34087a8d450c81e4aa Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:36:30 -0400 Subject: [PATCH] multi custom kernel support (#13716) * multi custom kernel support * custom kernel xfrom * works * no SPEC=2 on ck * panic * touchups --- .github/workflows/test.yml | 2 +- test/test_custom_kernel.py | 14 ++++++++++++-- tinygrad/gradient.py | 2 +- tinygrad/schedule/multi.py | 5 +++++ tinygrad/schedule/rangeify.py | 7 +++++++ tinygrad/uop/__init__.py | 1 + tinygrad/uop/ops.py | 17 ++++++++++++----- tinygrad/uop/spec.py | 13 ++++++++----- tinygrad/viz/serve.py | 2 +- 9 files changed, 48 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9250366dee..27d97fd484 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -310,7 +310,7 @@ jobs: deps: testing_unit python-version: '3.14' - name: Test SPEC=2 - run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }} + run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }} fuzzing: name: Fuzzing diff --git a/test/test_custom_kernel.py b/test/test_custom_kernel.py index c5929d5e26..40a6d2a71c 100644 --- a/test/test_custom_kernel.py +++ b/test/test_custom_kernel.py @@ -1,5 +1,5 @@ import unittest -from tinygrad import Tensor, UOp, Context +from tinygrad import Tensor, UOp from tinygrad.dtype import AddrSpace from tinygrad.uop.ops import KernelInfo, AxisType @@ -117,6 +117,17 @@ class TestCustomKernel(unittest.TestCase): out = c.flatten().tolist() assert all(x == 2 for x in out), "all 2" + def test_simple_sharded(self): + devs = ("CPU:0", "CPU:1") + + a = Tensor.ones(16, 16).contiguous().shard(devs, axis=0) + b = Tensor.ones(16, 16).contiguous().shard(devs, axis=0) + # ugly construction to get a sharded empty tensor + c = Tensor(Tensor.empty(8, 16, device=devs).uop.multi(0), device=devs) + c = Tensor.custom_kernel(c,a,b, fxn=custom_elementwise_add_kernel)[0] + out = c.flatten().tolist() + assert all(x == 2 for x in out), "all 2" + def test_multioutput(self): a = Tensor.full((16, 16), 3.).contiguous() b = Tensor.full((16, 16), 3.).contiguous() @@ -184,7 +195,6 @@ class TestCustomKernel(unittest.TestCase): def test_gemm_backward_custom(self): self.test_gemm_backward(True) # NOTE: grad_fxn doesn't work with pyrender - @Context(SPEC=1) def test_gemm_backward(self, custom_backward_gemm=False): N = 4 a_rand = Tensor.randn(N, 8) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 0bcd63d4ee..b4576cdf27 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -42,7 +42,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src), # NOTE: this is only correct when the KERNEL has a single output (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)), - (UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)), + (UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)), # there's no gradient for bitcast (UPat(Ops.BITCAST), lambda: (None,)), ]) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 769e4784ab..dbd9c7b098 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -218,6 +218,11 @@ multi_pm = PatternMatcher([ lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)), (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), + # multi supports custom kernels with CUSTOM_KERNEL + AFTER + (UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"), + lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))), + (UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"), + lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)) ])+replace_allreduce def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 69b74f0446..a15de5f020 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -63,10 +63,17 @@ mop_cleanup = PatternMatcher([ lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None), ]) +def resolve_custom_kernel(ck:UOp) -> UOp: + placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)] + return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders))) + earliest_rewrites = mop_cleanup+PatternMatcher([ # just removing it works... (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), + # resolve custom kernels + (UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel), + # remove CONTIGUOUS if the BUFFER is already contiguous (UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)), diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 201117b13f..39a9427a87 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -75,6 +75,7 @@ class Ops(FastEnum): # tensor graph ops UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto() + CUSTOM_KERNEL = auto() # local unique LUNIQUE = auto() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index f241c81b13..3829483c1c 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -7,7 +7,7 @@ from tinygrad.uop import Ops, GroupOp from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI -from tinygrad.helpers import strip_parens, colored, ansilen, printable +from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic if TYPE_CHECKING: from tinygrad.device import Buffer, MultiBuffer @@ -218,7 +218,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT: + Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL: return None case Ops.INDEX: @@ -470,7 +470,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def is_contiguous(self): # TODO: this is is_realized - if self.op is Ops.RESHAPE: return self.src[0].is_contiguous() + if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].is_contiguous() return self.op is Ops.BUFFER def contiguous(self, *args, **kwargs): @@ -840,9 +840,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return self.src[0].after(self.store(val).end(*argfix(end))) def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: - placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)] contig_srcs = tuple(x.contiguous() for x in srcs) - kernel = UOp(Ops.KERNEL, src=tuple(x.base for x in contig_srcs), arg=Kernel(fxn(*placeholders), grad_fxn=grad_fxn)) + kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn)) return [s.after(kernel) for s in contig_srcs] @dataclass(frozen=True) @@ -855,6 +854,14 @@ class KernelInfo: @property def function_name(self): return to_function_name(self.name) +@dataclass(frozen=True) +class CustomKernel: + fxn: Callable + grad_fxn: Callable|None = None + # sadly CustomKernel can't be pickled or reconstructed as a str + def __reduce__(self): return (CustomKernel, (panic,)) + def __repr__(self): return "CustomKernel(panic)" + @dataclass(frozen=True) class Kernel: ast: UOp diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 1ed84f9542..2c9312b49d 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -1,8 +1,8 @@ import math from typing import cast, Any -from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel +from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid -from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata +from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic from tinygrad.uop.validate import validate_index # four specs: @@ -54,7 +54,10 @@ movement_ops = PatternMatcher([ (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True), # AFTER on Movement Op - (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True), + (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI})),), allow_any_len=True), lambda: True), + + # custom kernels allowed here + (UPat(Ops.CUSTOM_KERNEL), lambda: True), ]) _tensor_spec = PatternMatcher([ @@ -274,8 +277,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher): from tinygrad.codegen.opt import Opt, OptOps from tinygrad.schedule.rangeify import BufferizeOpts glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata, - "UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, - "Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace} + "UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel, + "Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic} def eval_pyrender(code:str) -> UOp: lcls:dict[str, Any] = {} exec(code, glbls, lcls) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index befad3cade..f29c25d02d 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -16,7 +16,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", + Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6", Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",