ASSIGN is STORE+AFTER (try 3)

This commit is contained in:
George Hotz
2026-03-12 15:13:32 +08:00
parent bdd62fd484
commit a8b04efec7
3 changed files with 10 additions and 6 deletions

View File

@@ -110,7 +110,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
# add CONTIGUOUS to tagged UOps
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"), lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN, Ops.AFTER}, name="x"),
lambda x: x.rtag(None).contiguous(tag=x.tag) if x.tag else x.replace(tag=None)),
# remove extra CONTIGUOUS on ASSIGN (only when assign target is contiguous)
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c"),
lambda a,c: a.replace(tag=(a.tag or ())+(c.tag or ())) if a.src[0].has_buffer_identity() else None),
@@ -128,7 +129,7 @@ def untag_and_append(ctx:AllocCtx, x:UOp):
for t in x.tag:
original_uop: UOp = ctx.uop_list[t]
replace_uop = ret
while replace_uop.op is Ops.ASSIGN: replace_uop = replace_uop.src[0]
while replace_uop.op is Ops.AFTER: replace_uop = replace_uop.src[0]
ctx.buffer_map[original_uop] = replace_uop.shrink_to(original_uop.shape)
ctx.assigns.append(ret)
return ret
@@ -142,7 +143,7 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None)
pm_finalize_call = PatternMatcher([
(UPat(Ops.ASSIGN, name="x"), untag_and_append),
(UPat(Ops.AFTER, name="x"), untag_and_append),
(UPat(Ops.AFTER, name="x"), append_after),
(UPat(Ops.COPY, name="x"), lambda ctx,x: append_after(ctx,x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
# remove unique from const. TODO: this is copied in function.py

View File

@@ -301,7 +301,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
raise ValueError(f"invalid type for axis: {axis_arg}")
return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
if self.op is Ops.ASSIGN: return self.src[1]._shape
if self.op is Ops.STORE: return self.src[1]._shape
# elementwise ops keep the shape the same. all inputs with shape must match
if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
@@ -447,7 +447,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
def end(self, *src:UOp):
  # With no extra sources there is nothing to wrap: hand back self unchanged.
  if not src: return self
  # Otherwise build an END op whose sources are self followed by the given UOps.
  return UOp(Ops.END, src=(self, *src))
def after(self, *src:UOp, **kwargs):
  # With no extra sources, return self as-is (kwargs are not used in that case).
  if not src: return self
  # AFTER keeps self's dtype; its sources are self followed by the given UOps, kwargs go to the UOp constructor.
  return UOp(Ops.AFTER, self.dtype, (self, *src), **kwargs)
# NOTE(review): `assign` is defined twice in a row. This text looks like a diff whose +/- markers were
# lost, so presumably the first line is the removed version and the second its replacement — confirm.
# If both stayed in one class body, the later definition would be the one bound to the name.
# Earlier form: a dedicated ASSIGN op with dtype self.dtype and sources (self, x).
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
# Later form: store x into self, then return self wrapped in an AFTER op that has that store as a source.
def assign(self, x:UOp): return self.after(self.store(x))
def barrier(self, *src:UOp):
  # self always leads the source tuple; a BARRIER is built even when no extra sources are passed.
  srcs = (self, *src)
  return UOp(Ops.BARRIER, src=srcs)
def contract(self, *rngs:UOp):
assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"

View File

@@ -96,6 +96,9 @@ _tensor_spec = PatternMatcher([
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# STORE in tensor graph: store a value into a target
(UPat(Ops.STORE, dtypes.void, (UPat(), UPat())), lambda: True),
# MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
@@ -129,7 +132,7 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) >= 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
# AFTER if things were kernelized
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),
(UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER, Ops.PARAM)),), allow_any_len=True), lambda: True),
# allow CALL/PARAM/CUSTOM_FUNCTION
(UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),