mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
fix get reduce contraction with test (#9834)
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -340,6 +340,8 @@ jobs:
|
||||
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
|
||||
- name: Run unit tests
|
||||
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
|
||||
- name: Run targetted tests on NULL backend
|
||||
run: PYTHONPATH="." NULL=1 python3 test/test_multitensor.py TestMultiTensor.test_data_parallel_resnet_train_step
|
||||
- name: Run GC tests
|
||||
run: PYTHONPATH="." python test/external/external_uop_gc.py
|
||||
- name: Repo line count < 12500 lines
|
||||
|
||||
@@ -194,6 +194,9 @@ class TestGetContraction(unittest.TestCase):
|
||||
r = get_contraction_with_reduce((16, 1, 1, 1, 1), (16, 1, 1, 1), (1,))
|
||||
self.assertEqual(r, [[0], [1, 2], [3], [4]])
|
||||
|
||||
r = get_contraction_with_reduce((2, 512, 1, 1), (2, 1, 512), (1,))
|
||||
self.assertIsNone(r)
|
||||
|
||||
def test_contraction(self):
|
||||
r = get_contraction((1,2,3,4), (2,3,4))
|
||||
self.assertEqual(r, [[0, 1], [2], [3]])
|
||||
|
||||
@@ -27,7 +27,7 @@ def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint
|
||||
while take_from < len(contraction) and len(contraction[take_from]) == 0:
|
||||
assert new_shape[take_from] == 1
|
||||
take_from += 1
|
||||
if take_from == len(contraction): return None # nothing to take
|
||||
if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
|
||||
for j in range(take_from, i, -1):
|
||||
assert len(contraction[j]) > 0
|
||||
contraction[j-1] = contraction[j][:-1]
|
||||
|
||||
Reference in New Issue
Block a user