mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
multi custom kernel support (#13716)
* multi custom kernel support * custom kernel xfrom * works * no SPEC=2 on ck * panic * touchups
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -310,7 +310,7 @@ jobs:
|
||||
deps: testing_unit
|
||||
python-version: '3.14'
|
||||
- name: Test SPEC=2
|
||||
run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
|
||||
run: IGNORE_OOB=0 SPEC=2 PYTHONPATH="." pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
|
||||
|
||||
fuzzing:
|
||||
name: Fuzzing
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, UOp, Context
|
||||
from tinygrad import Tensor, UOp
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.uop.ops import KernelInfo, AxisType
|
||||
|
||||
@@ -117,6 +117,17 @@ class TestCustomKernel(unittest.TestCase):
|
||||
out = c.flatten().tolist()
|
||||
assert all(x == 2 for x in out), "all 2"
|
||||
|
||||
def test_simple_sharded(self):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
|
||||
a = Tensor.ones(16, 16).contiguous().shard(devs, axis=0)
|
||||
b = Tensor.ones(16, 16).contiguous().shard(devs, axis=0)
|
||||
# ugly construction to get a sharded empty tensor
|
||||
c = Tensor(Tensor.empty(8, 16, device=devs).uop.multi(0), device=devs)
|
||||
c = Tensor.custom_kernel(c,a,b, fxn=custom_elementwise_add_kernel)[0]
|
||||
out = c.flatten().tolist()
|
||||
assert all(x == 2 for x in out), "all 2"
|
||||
|
||||
def test_multioutput(self):
|
||||
a = Tensor.full((16, 16), 3.).contiguous()
|
||||
b = Tensor.full((16, 16), 3.).contiguous()
|
||||
@@ -184,7 +195,6 @@ class TestCustomKernel(unittest.TestCase):
|
||||
|
||||
def test_gemm_backward_custom(self): self.test_gemm_backward(True)
|
||||
# NOTE: grad_fxn doesn't work with pyrender
|
||||
@Context(SPEC=1)
|
||||
def test_gemm_backward(self, custom_backward_gemm=False):
|
||||
N = 4
|
||||
a_rand = Tensor.randn(N, 8)
|
||||
|
||||
@@ -42,7 +42,7 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
|
||||
# NOTE: this is only correct when the KERNEL has a single output
|
||||
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
||||
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||
])
|
||||
|
||||
@@ -218,6 +218,11 @@ multi_pm = PatternMatcher([
|
||||
lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
# multi supports custom kernels with CUSTOM_KERNEL + AFTER
|
||||
(UPat(Ops.CUSTOM_KERNEL, src=UPat(Ops.MULTI), name="ck"),
|
||||
lambda ck: ck.replace(src=tuple(m.src[0] for m in ck.src))),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CUSTOM_KERNEL)), name="a"),
|
||||
lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis))
|
||||
])+replace_allreduce
|
||||
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
|
||||
@@ -63,10 +63,17 @@ mop_cleanup = PatternMatcher([
|
||||
lambda x,x2: x.replace(src=(x2.src[0], x.src[1])) if x.tag is None and x2.tag is None else None),
|
||||
])
|
||||
|
||||
def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# resolve custom kernels
|
||||
(UPat(Ops.CUSTOM_KERNEL, name="ck"), resolve_custom_kernel),
|
||||
|
||||
# remove CONTIGUOUS if the BUFFER is already contiguous
|
||||
(UPat(Ops.BUFFER).f(Ops.RESHAPE, allow_any_len=True, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),
|
||||
|
||||
|
||||
@@ -75,6 +75,7 @@ class Ops(FastEnum):
|
||||
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto()
|
||||
CUSTOM_KERNEL = auto()
|
||||
|
||||
# local unique
|
||||
LUNIQUE = auto()
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
|
||||
@@ -218,7 +218,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT:
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL:
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
@@ -470,7 +470,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
|
||||
def is_contiguous(self):
|
||||
# TODO: this is is_realized
|
||||
if self.op is Ops.RESHAPE: return self.src[0].is_contiguous()
|
||||
if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].is_contiguous()
|
||||
return self.op is Ops.BUFFER
|
||||
|
||||
def contiguous(self, *args, **kwargs):
|
||||
@@ -840,9 +840,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.src[0].after(self.store(val).end(*argfix(end)))
|
||||
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(srcs)]
|
||||
contig_srcs = tuple(x.contiguous() for x in srcs)
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(x.base for x in contig_srcs), arg=Kernel(fxn(*placeholders), grad_fxn=grad_fxn))
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
return [s.after(kernel) for s in contig_srcs]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -855,6 +854,14 @@ class KernelInfo:
|
||||
@property
|
||||
def function_name(self): return to_function_name(self.name)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CustomKernel:
|
||||
fxn: Callable
|
||||
grad_fxn: Callable|None = None
|
||||
# sadly CustomKernel can't be pickled or reconstructed as a str
|
||||
def __reduce__(self): return (CustomKernel, (panic,))
|
||||
def __repr__(self): return "CustomKernel(panic)"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import math
|
||||
from typing import cast, Any
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, KernelInfo, pyrender, Kernel, CustomKernel
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata
|
||||
from tinygrad.helpers import DEBUG, Context, prod, SPEC, Metadata, panic
|
||||
from tinygrad.uop.validate import validate_index
|
||||
|
||||
# four specs:
|
||||
@@ -54,7 +54,10 @@ movement_ops = PatternMatcher([
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# AFTER on Movement Op
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI})),), allow_any_len=True), lambda: True),
|
||||
|
||||
# custom kernels allowed here
|
||||
(UPat(Ops.CUSTOM_KERNEL), lambda: True),
|
||||
])
|
||||
|
||||
_tensor_spec = PatternMatcher([
|
||||
@@ -274,8 +277,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.schedule.rangeify import BufferizeOpts
|
||||
glbls:dict[str, Any] = {"inf": math.inf, "nan": math.nan, "KernelInfo": KernelInfo, "Kernel": Kernel, "Metadata": Metadata,
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid,
|
||||
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace}
|
||||
"UOp": UOp, "dtypes": dtypes, "Ops": Ops, "AxisType": AxisType, "Invalid": Invalid, "CustomKernel": CustomKernel,
|
||||
"Opt": Opt, "OptOps": OptOps, "BufferizeOpts": BufferizeOpts, "AddrSpace": AddrSpace, "panic": panic}
|
||||
def eval_pyrender(code:str) -> UOp:
|
||||
lcls:dict[str, Any] = {}
|
||||
exec(code, glbls, lcls)
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.dtype import dtypes
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
||||
Reference in New Issue
Block a user