diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 6ab922e3bb..597726b160 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -799,7 +799,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg # everything else inherits shape else: - assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}" + assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, UOps.ASSIGN, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}" st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0] if not all_same(shapes:=[x.shape for x in src_sts]): if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 5f70891adc..44f87732cf 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -346,6 +346,8 @@ constant_folder = PatternMatcher([ (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)), # self ASSIGN is just self (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x), + # ASSIGN to global is just self + (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x), # VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal (UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None), diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 490dbb5a02..92f0fa9650 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -110,7 +110,7 @@ reduceop_fusor = PatternMatcher([ # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes) (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes) - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) @@ -157,9 +157,10 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. # elementwise ops pass shapetracker in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs) - if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}: + if buf.op is MetaOps.CONTIGUOUS: assert buf in outputs, f"{buf.op} must be writable" return in_uops[0] + if buf.op is MetaOps.ASSIGN: return cache.setdefault((buf, st), UOp(UOps.ASSIGN, dtype, (in_uops[1].src[0], in_uops[0]))) if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops)) if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b0290ab811..f3a7b9e403 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -618,7 +618,7 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), - (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True), + (UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True), (UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True), # all WMMA has 3 args,