Tensor UOps can become a buffer or const after scheduling (#8698)

* spec

* work

* update test_viewed_consts_do_not_realize

* remove
This commit is contained in:
qazal
2025-01-21 05:33:19 -05:00
committed by GitHub
parent e2008c98c3
commit f0d424ecdf
3 changed files with 60 additions and 5 deletions

View File

@@ -2366,5 +2366,56 @@ class TestContiguous(unittest.TestCase):
b = a.expand((4, 4)).contiguous().contiguous()
check_schedule(b, 1)
class TestUOpBecome(unittest.TestCase):
# the simplest case, if we create a new BUFFER for this UOp
def test_new_buffer(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = a+b
check_schedule(add, 1)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
def test_new_buffer_view(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = (a+b).reshape(8, 2)
check_schedule(add, 1)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(add.lazydata.base, {})
# VIEW is preserverd after the becomes rewrite.
self.assertEqual(add.lazydata.shape, (8, 2))
assert add.lazydata is not add.lazydata.base
def test_become_existing_buffer(self):
a = Tensor.empty(4, 4)
b = a*1
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER))).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
self.assertIs(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_become_const_in_base(self):
a = Tensor.empty(4)
b = a*0
assert UPat(Ops.MUL).match(b.lazydata, {}) # before scheduling it's a mul
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata.base, {}) # scheduling replaces the tensor lazydata with a VIEW(BUFFER)
def test_become_const_in_view(self):
# if we shrink the base down to a size 0, only the VIEW becomes CONST, base is unchanged.
add = Tensor.empty(2, 2)+Tensor.empty(2, 2)
b = add.shrink(((0, 1), (0, 0)))
check_schedule(b, 0)
assert UPat(Ops.CONST, arg=0).match(b.lazydata, {})
self.assertEqual(b.shape, (1, 0))
# the base is untouched.
assert UPat(Ops.ADD).match(add.lazydata, {})
def test_become_const_from_const(self):
const_add = Tensor(1)+Tensor(2)
assert UPat(Ops.ADD).match(const_add.lazydata, {})
check_schedule(const_add, 0)
assert UPat(Ops.CONST, arg=3).match(const_add.lazydata.base, {})
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -71,9 +71,9 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.lazydata)
pre_realize = a.lazydata
a.realize()
assert a.lazydata is pre_realize
is_pattern(a, const_pattern)
self.assertEqual(a.lazydata.shape, (10, 10))
# currently, CONSTs have a "fake" BUFFER. this should be fixed
# current:

View File

@@ -514,10 +514,14 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# tensors can become an existing buffer, no ScheduleItem needed
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
for k,v in tensor_map.items():
# NOTE: we only add base tensors to becomes_map
if k is not v and v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: ctx.becomes_map[k] = v
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}