diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 36307dd2e9..020168811c 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -110,7 +110,8 @@ pm_early_transform_tensor_graph = PatternMatcher([ (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view), # add CONTIGUOUS to tagged UOps - (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)), + (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER}, name="x"), + lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)), # remove extra CONTIGUOUS on ASSIGN (only when assign target is contiguous) (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"), lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None), @@ -128,7 +129,7 @@ def untag_and_append(ctx:AllocCtx, x:UOp): for t in x.tag: original_uop: UOp = ctx.uop_list[t] replace_uop = ret - while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0] + while replace_uop.op is Ops.AFTER: replace_uop = replace_uop.src[0] ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape) ctx.assigns.append(ret) return ret @@ -142,7 +143,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp): b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None) pm_finalize_call = PatternMatcher([ - (UPat(Ops.ASSIGN, name="x"), untag_and_append), + (UPat(Ops.AFTER, name="x"), untag_and_append), (UPat(Ops.AFTER, name="x"), append_after), (UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None), # remove unique from const. TODO: this is copied in function.py diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 17bfcc3654..933b22b776 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -301,7 +301,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): raise ValueError(f"invalid type for axis: {axis_arg}") return tuple(1 if i in axis_arg else s for i,s in enumerate(ps)) - if self.op is Ops.ASSIGN: return self.src[1]._shape + if self.op is Ops.STORE: return self.src[1]._shape # elementwise ops keep the shape the same. all inputs with shape must match if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}): @@ -447,7 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs) def end(self, *src:UOp): return UOp(Ops.END, src=(self,)+src) if len(src) else self def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) if len(src) else self - def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) + def assign(self, x:UOp): return self.after(self.store(x)) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) def contract(self, *rngs:UOp): assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast" diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 62bcf056a1..672e2b4287 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([ # ASSIGN has a target and a value. It can also optionally depend on other assigns (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])), + # STORE in tensor graph: store a value into a target + (UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True), + # MSELECT chooses one of the multi buffers (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)), @@ -129,7 +132,7 @@ _tensor_spec = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}), # AFTER if things were kernelized - (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), + (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER, Ops.PARAM)),), allow_any_len=True), lambda: True), # allow CALL/PARAM/CUSTOM_FUNCTION (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),