diff --git a/test/unit/test_winograd.py b/test/unit/test_winograd.py index d8909f7620..7f419b838c 100644 --- a/test/unit/test_winograd.py +++ b/test/unit/test_winograd.py @@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase): out = Tensor.conv2d(x,w, padding=1) out.mean().backward() backward_schedule = Tensor.schedule(x.grad, w.grad) - self.assertEqual(len(backward_schedule), 5) + self.assertEqual(len(backward_schedule), 4) def test_counters(self): IC, OC, X, Y = 4,4,9,9 diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 4c057a3cf1..2482175961 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -128,8 +128,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO axes_out.append(combined_axes % s) combined_axes //= s # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code - rngs = graph_rewrite(graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape"), - pm_drop_and_clauses, name="reshape drop ands").src + rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape").src case _: raise RuntimeError(f"{op} is not a MovementOp") return rngs