From f666dd14eb686dc17873be077e68b866264bc7b7 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 10 Apr 2025 22:24:21 +0800 Subject: [PATCH] fix get reduce contraction with test (#9834) --- .github/workflows/test.yml | 2 ++ test/unit/test_helpers.py | 3 +++ tinygrad/codegen/lowerer.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c375a8cb6b..7a54cceb12 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -340,6 +340,8 @@ jobs: run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py - name: Run unit tests run: PYTHONPATH="." python -m pytest -n=auto test/unit/ + - name: Run targetted tests on NULL backend + run: PYTHONPATH="." NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step - name: Run GC tests run: PYTHONPATH="." python test/external/external_uop_gc.py - name: Repo line count < 12500 lines diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 5c5c0ef69c..99ea27f52e 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -194,6 +194,9 @@ class TestGetContraction(unittest.TestCase): r = get_contraction_with_reduce((16, 1, 1, 1, 1), (16, 1, 1, 1), (1,)) self.assertEqual(r, [[0], [1, 2], [3], [4]]) + r = get_contraction_with_reduce((2, 512, 1, 1), (2, 1, 512), (1,)) + self.assertIsNone(r) + def test_contraction(self): r = get_contraction((1,2,3,4), (2,3,4)) self.assertEqual(r, [[0, 1], [2], [3]]) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 1faca4c2b7..51c108bd0c 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -27,7 +27,7 @@ def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint while take_from < len(contraction) and len(contraction[take_from]) == 0: assert new_shape[take_from] == 1 take_from += 1 - if take_from == len(contraction): return None # nothing to take + if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take for j in range(take_from, i, -1): assert len(contraction[j]) > 0 contraction[j-1] = contraction[j][:-1]