diff --git a/test/unit/test_function.py b/test/unit/test_function.py index f6a6f99cfc..c3dfb0a410 100644 --- a/test/unit/test_function.py +++ b/test/unit/test_function.py @@ -522,6 +522,15 @@ class TestFunctionTuple(unittest.TestCase): np.testing.assert_allclose(g(a).numpy(), 14.0) + def test_custom_kernel_inplace_output_is_implicit(self): + # a custom_kernel output the kernel also READS (in-place add) is not write-only, so it must be captured as an input + def inplace_add(C:UOp, A:UOp) -> UOp: + i = UOp.range(A.shape[0], 0) + return C[i].store(C[i].load() + A[i]).end(i).sink(arg=KernelInfo(name="inplace_add")) + @function(precompile=True, allow_implicit=False) + def f(a:Tensor): return Tensor.custom_kernel(Tensor.empty(*a.shape, dtype=a.dtype, device=a.device), a, fxn=inplace_add)[0] + with self.assertRaisesRegex(RuntimeError, "implicit buffer"): f(Tensor([1., 2., 3., 4.]).contiguous().realize()) + def test_custom_kernel_precompile_further_compute(self): def my_kernel(C:UOp, A:UOp) -> UOp: i = UOp.range(A.shape[0], 0) diff --git a/tinygrad/callify.py b/tinygrad/callify.py index 1029b6cd8b..e8d1f67bba 100644 --- a/tinygrad/callify.py +++ b/tinygrad/callify.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType +from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType, Invalid from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites from tinygrad.helpers import VIZ, pluralize, all_int @@ -106,6 +106,9 @@ def _precompiled_output_redirect(s:UOp, t:UOp) -> UOp|None: if s.op in {Ops.BUFFER, Ops.MULTI} and s.has_buffer_identity(): return t return None +def _is_invalid_init_store(base:UOp, dep:UOp) -> bool: + return dep.op is Ops.STORE and dep.src[0].buf_uop is base.buf_uop and dep.src[1].base.arg is Invalid + def transform_precompiled_call(c:UOp) -> UOp|None: if not c.arg.precompile: return None assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}" @@ -121,11 +124,14 @@ def transform_precompiled_call(c:UOp) -> UOp|None: items:list[UOp] = [] for s, t in zip(srcs, targets): after_deps:list[UOp] = [] + init_afters:list[UOp] = [] while s.op is Ops.AFTER: - after_deps.extend(s.src[1:]) + if all(_is_invalid_init_store(s.src[0], x) for x in s.src[1:]): init_afters.append(s) + else: after_deps.extend(s.src[1:]) s = s.src[0] if (placed := _precompiled_output_redirect(s, t)) is not None and s not in subs: subs[s] = placed + subs.update((x, placed) for x in init_afters) items.append(s.after(*after_deps) if after_deps else s) else: items.append(t.after(t.store(s), *after_deps)) diff --git a/tinygrad/function.py b/tinygrad/function.py index db5a491544..f510f2c13b 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -1,26 +1,29 @@ -import functools, itertools, time +import functools, time from typing import Generic, TypeVar, Callable, cast, overload from tinygrad.helpers import Context, dedup, getenv, DEBUG -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat +from tinygrad.uop.ops import UOp, Ops, ProgramInfo, graph_rewrite, PatternMatcher, UPat from tinygrad.tensor import Tensor from tinygrad.nn.state import get_state_dict def add_to_ctx(ctx, x:UOp): + if x.buf_uop in ctx[1]: return None ret = x.param_like(len(ctx[0])) ctx[0].append(x) return ret -pm_transform_unique_const = PatternMatcher([ - # transform unique consts to LUNIQUE - (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"), - lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))), -]) - pm_ctx = PatternMatcher([ (UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx), (UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"), lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None), -])+pm_transform_unique_const +]) + +def write_only_outputs(uret:UOp) -> set[UOp]: + ret: set[UOp] = set() + for call in uret.backward_slice_with_self: + if call.op is Ops.CALL and call.src[0].op is Ops.SINK: + info = ProgramInfo.from_sink(call.src[0]) + ret.update(call.src[1+i].buf_uop for i in set(info.outs)-set(info.ins)) + return ret ReturnType = TypeVar('ReturnType') class _function(Generic[ReturnType]): @@ -65,7 +68,7 @@ class _function(Generic[ReturnType]): # the BUFFERs that are left are the implicit inputs num_explicit = len(call_uops) - uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs") + uret = graph_rewrite(uret, pm_ctx, (call_uops, write_only_outputs(uret)), bottom_up=True, name="get_implicit_inputs") name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__ if not self.allow_implicit: implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER] diff --git a/tinygrad/mixin/gradient.py b/tinygrad/mixin/gradient.py index c971f7725c..52a623a146 100644 --- a/tinygrad/mixin/gradient.py +++ b/tinygrad/mixin/gradient.py @@ -1,6 +1,6 @@ from typing import cast -import math, dataclasses, itertools -from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite +import math, dataclasses +from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata from tinygrad.helpers import argsort from tinygrad.dtype import sum_acc_dtype @@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]: grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads] bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True) bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs)) - # TODO: is this okay here? - from tinygrad.function import pm_transform_unique_const - bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0))) bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward) gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)} return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args))) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index c82ef5709e..e7d2f7baa3 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -564,11 +564,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass): ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=()) return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret @staticmethod - def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp: + def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp: dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid) - ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid), - src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device)))) - return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret + return UOp.const(dt, Invalid, shape=shape).clone(device=device) @staticmethod def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs) @@ -1153,7 +1151,7 @@ class ProgramInfo: if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot) if u.op in (Ops.STORE, Ops.LOAD): if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX): - if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot) + if (buf:=idx.src[0].buf_uop).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot) if u.op is Ops.SPECIAL: if u.arg[0] == 'i': local_size = None special_size = local_size if u.arg[0] == 'l' else global_size diff --git a/tinygrad/uop/render.py b/tinygrad/uop/render.py index 71075220ad..ea93182a75 100644 --- a/tinygrad/uop/render.py +++ b/tinygrad/uop/render.py @@ -1,5 +1,5 @@ from typing import cast -from tinygrad.dtype import dtypes, Invalid +from tinygrad.dtype import dtypes from tinygrad.uop import Ops, GroupOp from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort from tinygrad.helpers import strip_parens, all_same @@ -79,8 +79,6 @@ def render_marg(ctx,x:UOp): sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY, Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH} pm_pyrender_extra = PatternMatcher([ - (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"), - lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"), (UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"), (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"), diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 0abde12168..3208014ec7 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -124,9 +124,6 @@ spec_tensor = PatternMatcher([ (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True), (UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True), - # CONST with a UNIQUE and DEVICE - (UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid), - # BUFFER (UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"), lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index e68f22cadc..772e87ca8e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -120,7 +120,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]: for u in (toposort:=x.toposort()): # always exclude DEVICE/CONST/UNIQUE if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u) - if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u) if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u) # exclude RESHAPE/EXPAND that only serve to broadcast a CONST if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)