mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
ranges on store (#11334)
* ranges on store * fix store spec * fix that * fix gates * fix tests * fix ptx
This commit is contained in:
@@ -267,7 +267,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
stores = [u for u in uops if u.op is Ops.STORE]
|
||||
assert len(accs) == 0 # it's removed now
|
||||
assert len(stores) == 1
|
||||
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
|
||||
assert stores[0].src[1].dtype == dtypes.float.vec(4)
|
||||
|
||||
# NOTE: can reenable, it does work. it just makes BEAM slow
|
||||
@unittest.expectedFailure
|
||||
@@ -294,10 +294,10 @@ class TestLinearizer(unittest.TestCase):
|
||||
stores = [u for u in program.uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
|
||||
|
||||
# the first store is to lds and can be upcasted
|
||||
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
|
||||
assert stores[0].src[1].dtype == dtypes.float.vec(4)
|
||||
assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort())
|
||||
# the second store is to gds with no upcasts
|
||||
assert stores[1].src[-1].dtype == dtypes.float
|
||||
assert stores[1].src[1].dtype == dtypes.float
|
||||
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort())
|
||||
|
||||
def test_zero_fold(self):
|
||||
@@ -648,7 +648,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
k = helper_linearizer_opt(out)[-1]
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
# check that the float4 cast collapses
|
||||
store_vals = [u.src[-1] for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
|
||||
store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
|
||||
for val in store_vals:
|
||||
assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE
|
||||
|
||||
@@ -671,7 +671,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
x = Tensor.randn((4,3,6,6)).realize()
|
||||
out = x.flip((0,1)).contiguous()
|
||||
k = helper_linearizer_opt(out)[-1]
|
||||
store_val = [u.src[-1] for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0]
|
||||
store_val = [u.src[1] for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0]
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -690,7 +690,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
barrier = [u for u in uops if u.op is Ops.BARRIER][0]
|
||||
# check that the float4 cast collapses for all stores
|
||||
for store in local_stores+global_stores:
|
||||
assert store.src[-1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE
|
||||
assert store.src[1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE
|
||||
# # check the children's vins
|
||||
# TODO: src ALU are not the same, should it?
|
||||
# assert barrier.src == tuple(local_stores)
|
||||
@@ -707,11 +707,11 @@ class TestLinearizer(unittest.TestCase):
|
||||
stores = [u for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
|
||||
|
||||
# the float4 value stores directly in lds and we skip upcast
|
||||
self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4))
|
||||
self.assertEqual(stores[0].src[1].dtype, dtypes.float.vec(4))
|
||||
#assert stores[0].src[-1].op is not Ops.VECTORIZE
|
||||
|
||||
# the global store doesn't change
|
||||
assert stores[1].src[-1].dtype == dtypes.float
|
||||
assert stores[1].src[1].dtype == dtypes.float
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
@@ -730,7 +730,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
|
||||
out = [u for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0]
|
||||
assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)
|
||||
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
|
||||
@@ -748,18 +748,18 @@ class TestLinearizer(unittest.TestCase):
|
||||
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
|
||||
k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
|
||||
out = [u for u in get_program(k.get_optimized_ast(), k.opts).uops if u.op is Ops.STORE][0]
|
||||
assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype.count != 1
|
||||
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
|
||||
class TestFloat4(unittest.TestCase):
|
||||
@staticmethod
|
||||
def count_float4(uops: list[UOp], n=4):
|
||||
return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]),
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.float.vec(n)]))
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.float.vec(n)]))
|
||||
@staticmethod
|
||||
def count_half4(uops: list[UOp]):
|
||||
return (len([uop for uop in uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
|
||||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)]))
|
||||
|
||||
def test_float4_basic(self):
|
||||
a = Tensor.empty(2, 8).realize()
|
||||
|
||||
@@ -14,7 +14,7 @@ def render(self) -> tuple[str, ConstType, ConstType]:
|
||||
# NOTE: we need STORE so the ALU op has children
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), self)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
|
||||
|
||||
def uconst(val): return UOp.const(dtypes.int, val)
|
||||
@@ -642,7 +642,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
# TODO: copied from render, render does not support cast
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
|
||||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
|
||||
|
||||
@@ -47,10 +47,10 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
|
||||
return buf.index(idx, new_valid)
|
||||
|
||||
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
|
||||
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
|
||||
if store_gate not in [gate.src[0] for gate in val.toposort() if gate.op is Ops.IF]: return None
|
||||
# remove the gate from the index
|
||||
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
|
||||
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])
|
||||
|
||||
load_store_indexing = PatternMatcher([
|
||||
# simplify valid
|
||||
@@ -61,7 +61,7 @@ load_store_indexing = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
|
||||
# delete_redundant_gates (after expand)
|
||||
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
|
||||
UPat.var("val"))), delete_redundant_gates),
|
||||
UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
|
||||
])
|
||||
|
||||
# ***** load/store grouping *****
|
||||
@@ -311,7 +311,7 @@ pm_render = PatternMatcher([
|
||||
lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
|
||||
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),)) if \
|
||||
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
|
||||
len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
|
||||
])
|
||||
|
||||
@@ -340,7 +340,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
lst = [acc.load()] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.store(ret).load() if len(reduce_range) != 0 else ret
|
||||
return acc.store(ret, *reduce_range).load() if len(reduce_range) != 0 else ret
|
||||
|
||||
def no_vectorized_reduce(inp:UOp, red:UOp):
|
||||
if inp.dtype != red.dtype:
|
||||
|
||||
@@ -57,7 +57,7 @@ def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
return buf.index(idx, valid).store(x.src[1])
|
||||
return buf.index(idx, valid).store(x.src[1], *[x for x in UOp.sink(idx, valid).toposort() if x.op is Ops.RANGE])
|
||||
|
||||
def lower_const(ctx:IndexContext, view:UOp, c:UOp):
|
||||
if all(x.mask is None for x in view.arg.views): return c
|
||||
|
||||
@@ -54,7 +54,7 @@ ptx_matcher = PatternMatcher([
|
||||
# move mask from INDEX to the load/store to enable pointer arithmetic
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("gate"))), UPat.var("alt"))),
|
||||
lambda buf,idx,gate,alt: UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate"))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate")), allow_any_len=True),
|
||||
lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)),
|
||||
# ptx shr and shl instructions require y to be uint
|
||||
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
|
||||
@@ -451,8 +451,9 @@ sym = symbolic_flat+PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
||||
# ** load/store folding **
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
|
||||
lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
|
||||
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"),
|
||||
UPat.load(UPat(Ops.INDEX, name="index"))), allow_any_len=True, name="store"),
|
||||
lambda index, gate, alt, store: UOp.store(index.src[0].index(index.src[1], gate), alt, *store.src[2:])),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat(), UPat.const(dtypes.bool, False)).or_casted(),), allow_any_len=True, name="x"),
|
||||
|
||||
Reference in New Issue
Block a user