diff --git a/test/test_assign.py b/test/test_assign.py index ef9b53826e..248460908f 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -100,6 +100,14 @@ class TestAssign(unittest.TestCase): new = a + times_a np.testing.assert_allclose(new.numpy(), 5) + @unittest.expectedFailure + def test_assign_diamond_possible(self): + a = Tensor.ones(4).contiguous().realize() + times_a = a*3 + a.assign(Tensor.full((4,), 2.).contiguous()) + new = a + (times_a+1).contiguous() + np.testing.assert_allclose(new.numpy(), 6) + def test_assign_diamond_alt(self): a = Tensor.ones(4).contiguous().realize() a.assign(Tensor.full((4,), 2.).contiguous()) diff --git a/tinygrad/realize.py b/tinygrad/realize.py index 5b7a357b45..c6c733ceec 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -240,6 +240,10 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) assert len(realized_children) == 1 reduce_for_op[next(iter(realized_children.keys()))] = r + # preschedule all buffers in realizes + prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST} + + # breadth first ordering graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list) in_degree: DefaultDict[LazyBuffer,int] = defaultdict(int) queue: Deque[LazyBuffer] = deque() @@ -251,16 +255,17 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) in_degree[buf] += 1 if in_degree[buf] == 0: queue.append(buf) - sorted_realizes: List[LazyBuffer] = [] + schedule: List[ScheduleItem] = [] while queue: buf = queue.popleft() if buf in realizes and buf not in seen: - sorted_realizes.append(buf) + schedule.append(prescheduled[buf]) seen.add(buf) for x in graph[buf]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) - sched:List[ScheduleItem] = [] - for x in sorted_realizes: sched.append(_schedule_one(x, realizes, reduce_for_op)) - return sched + # confirm everything was scheduled + assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}" + return schedule +