fix get reduce contraction with test (#9834)

This commit is contained in:
George Hotz
2025-04-10 22:24:21 +08:00
committed by GitHub
parent c3fa470852
commit f666dd14eb
3 changed files with 6 additions and 1 deletions

View File

@@ -340,6 +340,8 @@ jobs:
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
- name: Run unit tests
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
- name: Run targetted tests on NULL backend
run: PYTHONPATH="." NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
- name: Run GC tests
run: PYTHONPATH="." python test/external/external_uop_gc.py
- name: Repo line count < 12500 lines

View File

@@ -194,6 +194,9 @@ class TestGetContraction(unittest.TestCase):
r = get_contraction_with_reduce((16, 1, 1, 1, 1), (16, 1, 1, 1), (1,))
self.assertEqual(r, [[0], [1, 2], [3], [4]])
r = get_contraction_with_reduce((2, 512, 1, 1), (2, 1, 512), (1,))
self.assertIsNone(r)
def test_contraction(self):
r = get_contraction((1,2,3,4), (2,3,4))
self.assertEqual(r, [[0, 1], [2], [3]])

View File

@@ -27,7 +27,7 @@ def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint
while take_from < len(contraction) and len(contraction[take_from]) == 0:
assert new_shape[take_from] == 1
take_from += 1
if take_from == len(contraction): return None # nothing to take
if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
for j in range(take_from, i, -1):
assert len(contraction[j]) > 0
contraction[j-1] = contraction[j][:-1]