mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-16 09:55:33 +08:00
invalid clone is anonymous buffer [PR] (#16613)
This commit is contained in:
@@ -522,6 +522,15 @@ class TestFunctionTuple(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(g(a).numpy(), 14.0)
|
||||
|
||||
def test_custom_kernel_inplace_output_is_implicit(self):
  """A custom_kernel output that the kernel loads from as well as stores to must be reported as an implicit buffer.

  The function is built with allow_implicit=False, so calling it is expected to raise a
  RuntimeError whose message matches "implicit buffer".
  """
  # the kernel computes C[idx] = C[idx] + A[idx]: C is read before it is written, so it is not write-only
  def inplace_add(C:UOp, A:UOp) -> UOp:
    idx = UOp.range(A.shape[0], 0)
    return C[idx].store(C[idx].load() + A[idx]).end(idx).sink(arg=KernelInfo(name="inplace_add"))

  # the output buffer is created inside the function body, it is never passed in explicitly
  @function(precompile=True, allow_implicit=False)
  def f(a:Tensor):
    return Tensor.custom_kernel(Tensor.empty(*a.shape, dtype=a.dtype, device=a.device), a, fxn=inplace_add)[0]

  with self.assertRaisesRegex(RuntimeError, "implicit buffer"):
    f(Tensor([1., 2., 3., 4.]).contiguous().realize())
|
||||
|
||||
def test_custom_kernel_precompile_further_compute(self):
|
||||
def my_kernel(C:UOp, A:UOp) -> UOp:
|
||||
i = UOp.range(A.shape[0], 0)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType
|
||||
from tinygrad.dtype import dtypes, AddrSpace, PtrDType, ImageDType, Invalid
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import VIZ, pluralize, all_int
|
||||
|
||||
@@ -106,6 +106,9 @@ def _precompiled_output_redirect(s:UOp, t:UOp) -> UOp|None:
|
||||
if s.op in {Ops.BUFFER, Ops.MULTI} and s.has_buffer_identity(): return t
|
||||
return None
|
||||
|
||||
def _is_invalid_init_store(base:UOp, dep:UOp) -> bool:
  """Return True when `dep` is a STORE into `base`'s buffer whose stored value is the Invalid constant.

  Args:
    base: the UOp whose `buf_uop` identifies the buffer being checked.
    dep: a candidate dependency; only an Ops.STORE can qualify.
  """
  # anything that is not a STORE cannot be an init store
  if dep.op is not Ops.STORE: return False
  # src[0] is the store target: it has to resolve to the very same buffer UOp as `base`
  if dep.src[0].buf_uop is not base.buf_uop: return False
  # src[1] is the stored value: look at its base so the Invalid arg is still seen through wrappers
  # NOTE(review): `.base` presumably strips movement ops such as reshape/expand -- confirm
  return dep.src[1].base.arg is Invalid
|
||||
|
||||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if not c.arg.precompile: return None
|
||||
assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled FUNCTION, got {c.src[0].op}"
|
||||
@@ -121,11 +124,14 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
items:list[UOp] = []
|
||||
for s, t in zip(srcs, targets):
|
||||
after_deps:list[UOp] = []
|
||||
init_afters:list[UOp] = []
|
||||
while s.op is Ops.AFTER:
|
||||
after_deps.extend(s.src[1:])
|
||||
if all(_is_invalid_init_store(s.src[0], x) for x in s.src[1:]): init_afters.append(s)
|
||||
else: after_deps.extend(s.src[1:])
|
||||
s = s.src[0]
|
||||
if (placed := _precompiled_output_redirect(s, t)) is not None and s not in subs:
|
||||
subs[s] = placed
|
||||
subs.update((x, placed) for x in init_afters)
|
||||
items.append(s.after(*after_deps) if after_deps else s)
|
||||
else:
|
||||
items.append(t.after(t.store(s), *after_deps))
|
||||
|
||||
@@ -1,26 +1,29 @@
|
||||
import functools, itertools, time
|
||||
import functools, time
|
||||
from typing import Generic, TypeVar, Callable, cast, overload
|
||||
from tinygrad.helpers import Context, dedup, getenv, DEBUG
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
|
||||
from tinygrad.uop.ops import UOp, Ops, ProgramInfo, graph_rewrite, PatternMatcher, UPat
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
|
||||
def add_to_ctx(ctx, x:UOp):
  """Record `x` as a captured input and return the param-like UOp built for its slot.

  Args:
    ctx: indexable pair. ctx[0] is the list of captured UOps (appended to in place);
      ctx[1] is a container of buffer UOps that must not be captured.
    x: the UOp being considered for capture.

  Returns:
    `x.param_like(slot)`, where slot is the position `x` takes in ctx[0], or None when
    `x.buf_uop` is found in ctx[1] (nothing is appended in that case).
  """
  # buffers listed in ctx[1] are deliberately left alone
  if x.buf_uop in ctx[1]: return None
  captured = ctx[0]
  # the slot index is the list length *before* appending, so the first captured input gets slot 0
  param = x.param_like(len(captured))
  captured.append(x)
  return param
|
||||
|
||||
pm_transform_unique_const = PatternMatcher([
|
||||
# transform unique consts to LUNIQUE
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
|
||||
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
|
||||
])
|
||||
|
||||
pm_ctx = PatternMatcher([
|
||||
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
|
||||
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
|
||||
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
|
||||
])+pm_transform_unique_const
|
||||
])
|
||||
|
||||
def write_only_outputs(uret:UOp) -> set[UOp]:
  """Collect the buffer UOps that kernel calls reachable from `uret` store to but never load from.

  Every Ops.CALL in `uret.backward_slice_with_self` whose body (src[0]) is an Ops.SINK is inspected:
  the slots listed in ProgramInfo.outs but not in ProgramInfo.ins are taken, and the `buf_uop` of
  the matching call argument is added to the result.

  Returns:
    The set of `buf_uop`s found this way; empty when no such call exists.
  """
  written_only: set[UOp] = set()
  for u in uret.backward_slice_with_self:
    # only CALLs that wrap a SINK body are inspected
    if u.op is not Ops.CALL or u.src[0].op is not Ops.SINK: continue
    info = ProgramInfo.from_sink(u.src[0])
    # a slot that is stored to and never loaded from is write-only
    for slot in set(info.outs).difference(info.ins):
      # src[0] is the body, so the argument for `slot` sits at src[1+slot]
      written_only.add(u.src[1+slot].buf_uop)
  return written_only
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class _function(Generic[ReturnType]):
|
||||
@@ -65,7 +68,7 @@ class _function(Generic[ReturnType]):
|
||||
|
||||
# the BUFFERs that are left are the implicit inputs
|
||||
num_explicit = len(call_uops)
|
||||
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
|
||||
uret = graph_rewrite(uret, pm_ctx, (call_uops, write_only_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
|
||||
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
|
||||
if not self.allow_implicit:
|
||||
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import cast
|
||||
import math, dataclasses, itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
|
||||
import math, dataclasses
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import sum_acc_dtype
|
||||
|
||||
@@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
||||
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
|
||||
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
|
||||
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
|
||||
# TODO: is this okay here?
|
||||
from tinygrad.function import pm_transform_unique_const
|
||||
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
|
||||
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
|
||||
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
|
||||
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
|
||||
|
||||
@@ -564,11 +564,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
|
||||
@staticmethod
|
||||
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
|
||||
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
|
||||
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
|
||||
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
|
||||
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
|
||||
return UOp.const(dt, Invalid, shape=shape).clone(device=device)
|
||||
@staticmethod
|
||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
||||
@@ -1153,7 +1151,7 @@ class ProgramInfo:
|
||||
if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
if (buf:=idx.src[0].buf_uop).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
if u.op is Ops.SPECIAL:
|
||||
if u.arg[0] == 'i': local_size = None
|
||||
special_size = local_size if u.arg[0] == 'l' else global_size
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
|
||||
from tinygrad.helpers import strip_parens, all_same
|
||||
@@ -79,8 +79,6 @@ def render_marg(ctx,x:UOp):
|
||||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
|
||||
pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
|
||||
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
|
||||
|
||||
@@ -124,9 +124,6 @@ spec_tensor = PatternMatcher([
|
||||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||
|
||||
# CONST with a UNIQUE and DEVICE
|
||||
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
|
||||
|
||||
# BUFFER
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||
|
||||
@@ -120,7 +120,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||
for u in (toposort:=x.toposort()):
|
||||
# always exclude DEVICE/CONST/UNIQUE/LUNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
|
||||
Reference in New Issue
Block a user