invalid clone is anonymous buffer [PR] (#16613)

This commit is contained in:
chenyu
2026-06-15 20:14:26 -04:00
committed by GitHub
parent 4a0488ae97
commit efd03d7153
8 changed files with 36 additions and 29 deletions

View File

@@ -522,6 +522,15 @@ class TestFunctionTuple(unittest.TestCase):
np.testing.assert_allclose(g(a).numpy(), 14.0)
def test_custom_kernel_inplace_output_is_implicit(self):
# a custom_kernel output the kernel also READS (in-place add) is not write-only, so it must be captured as an input
def inplace_add(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)
return C[i].store(C[i].load() + A[i]).end(i).sink(arg=KernelInfo(name="inplace_add"))
@function(precompile=True, allow_implicit=False)
def f(a:Tensor): return Tensor.custom_kernel(Tensor.empty(*a.shape, dtype=a.dtype, device=a.device), a, fxn=inplace_add)[0]
with self.assertRaisesRegex(RuntimeError, "implicit buffer"): f(Tensor([1., 2., 3., 4.]).contiguous().realize())
def test_custom_kernel_precompile_further_compute(self):
def my_kernel(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType
from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType, Invalid
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
from tinygrad.helpers import VIZ, pluralize, all_int
@@ -106,6 +106,9 @@ def _precompiled_output_redirect(s:UOp, t:UOp) -> UOp|None:
if s.op in {Ops.BUFFER, Ops.MULTI} and s.has_buffer_identity(): return t
return None
def _is_invalid_init_store(base:UOp, dep:UOp) -> bool:
return dep.op is Ops.STORE and dep.src[0].buf_uop is base.buf_uop and dep.src[1].base.arg is Invalid
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}"
@@ -121,11 +124,14 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
items:list[UOp] = []
for s, t in zip(srcs, targets):
after_deps:list[UOp] = []
init_afters:list[UOp] = []
while s.op is Ops.AFTER:
after_deps.extend(s.src[1:])
if all(_is_invalid_init_store(s.src[0], x) for x in s.src[1:]): init_afters.append(s)
else: after_deps.extend(s.src[1:])
s = s.src[0]
if (placed := _precompiled_output_redirect(s, t)) is not None and s not in subs:
subs[s] = placed
subs.update((x, placed) for x in init_afters)
items.append(s.after(*after_deps) if after_deps else s)
else:
items.append(t.after(t.store(s), *after_deps))

View File

@@ -1,26 +1,29 @@
import functools, itertools, time
import functools, time
from typing import Generic, TypeVar, Callable, cast, overload
from tinygrad.helpers import Context, dedup, getenv, DEBUG
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
from tinygrad.uop.ops import UOp, Ops, ProgramInfo, graph_rewrite, PatternMatcher, UPat
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
def add_to_ctx(ctx, x:UOp):
if x.buf_uop in ctx[1]: return None
ret = x.param_like(len(ctx[0]))
ctx[0].append(x)
return ret
pm_transform_unique_const = PatternMatcher([
# transform unique consts to LUNIQUE
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
])
pm_ctx = PatternMatcher([
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
])+pm_transform_unique_const
])
def write_only_outputs(uret:UOp) -> set[UOp]:
ret: set[UOp] = set()
for call in uret.backward_slice_with_self:
if call.op is Ops.CALL and call.src[0].op is Ops.SINK:
info = ProgramInfo.from_sink(call.src[0])
ret.update(call.src[1+i].buf_uop for i in set(info.outs)-set(info.ins))
return ret
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
@@ -65,7 +68,7 @@ class _function(Generic[ReturnType]):
# the BUFFERs that are left are the implicit inputs
num_explicit = len(call_uops)
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
uret = graph_rewrite(uret, pm_ctx, (call_uops, write_only_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
if not self.allow_implicit:
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]

View File

@@ -1,6 +1,6 @@
from typing import cast
import math, dataclasses, itertools
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
import math, dataclasses
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
from tinygrad.dtype import sum_acc_dtype
@@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
# TODO: is this okay here?
from tinygrad.function import pm_transform_unique_const
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))

View File

@@ -564,11 +564,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
return UOp.const(dt, Invalid, shape=shape).clone(device=device)
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
@@ -1153,7 +1151,7 @@ class ProgramInfo:
if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if (buf:=idx.src[0].buf_uop).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
if u.op is Ops.SPECIAL:
if u.arg[0] == 'i': local_size = None
special_size = local_size if u.arg[0] == 'l' else global_size

View File

@@ -1,5 +1,5 @@
from typing import cast
from tinygrad.dtype import dtypes, Invalid
from tinygrad.dtype import dtypes
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
from tinygrad.helpers import strip_parens, all_same
@@ -79,8 +79,6 @@ def render_marg(ctx,x:UOp):
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),

View File

@@ -124,9 +124,6 @@ spec_tensor = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
# CONST with a UNIQUE and DEVICE
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
# BUFFER
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),

View File

@@ -120,7 +120,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)