diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4606698a43..9a3190c290 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,7 @@ jobs: - name: Fuzz Test shapetracker run: | PYTHONPATH="." python test/external/fuzz_shapetracker.py - FUZZ=invert PYTHONPATH="." python test/external/fuzz_shapetracker_math.py + PYTHONPATH="." python test/external/fuzz_shapetracker_math.py - name: Test shapetracker to_movement_ops run: PYTHONPATH="." python extra/to_movement_ops.py - name: Use as an external package diff --git a/test/external/fuzz_shapetracker_math.py b/test/external/fuzz_shapetracker_math.py index 13f67dbcdd..ae7c2dc2e7 100644 --- a/test/external/fuzz_shapetracker_math.py +++ b/test/external/fuzz_shapetracker_math.py @@ -1,21 +1,10 @@ import random -from typing import List from tqdm import trange from tinygrad.helpers import getenv, DEBUG, colored from tinygrad.shape.shapetracker import ShapeTracker from test.external.fuzz_shapetracker import shapetracker_ops from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad - -class MultiShapeTracker: - def __init__(self, sts:List[ShapeTracker]): self.sts = sts - @property - def shape(self): return self.sts[0].shape - def reshape(self, arg): self.sts = [x.reshape(arg) for x in self.sts] - def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts] - def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts] - def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts] - def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts] - def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts] +from test.unit.test_shapetracker_math import st_equal, MultiShapeTracker def fuzz_plus(): m = MultiShapeTracker([ShapeTracker.from_shape((random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)))]) @@ -34,20 +23,18 @@ def fuzz_invert(): m = MultiShapeTracker([start]) for _ in range(8): random.choice(invertible_shapetracker_ops)(m) inv = m.sts[0].invert(start.shape) - st_sum = (ShapeTracker.from_shape(m.sts[0].shape) + inv) if inv else None + st_sum = (m.sts[0] + inv) if inv else None return start, st_sum if __name__ == "__main__": # random.seed(42) total = getenv("CNT", 1000) for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]: - good = 0 - for _ in trange(total): + for _ in trange(total, desc=f"{fuzz}"): st1, st2 = fuzz() - if st1 == st2: good += 1 - if st1 != st2 or DEBUG >= 1: + eq = st_equal(st1, st2) + if DEBUG >= 1: print(f"EXP: {st1}") print(f"GOT: {st2}") - print(colored("****", "red" if st1 != st2 else "green")) - print(f"hit {good}/{total}") - assert good == total + print(colored("****", "green" if eq else "red")) + if not eq: exit(0) diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py index a7a885550a..7b99c633ce 100644 --- a/test/unit/test_shapetracker_math.py +++ b/test/unit/test_shapetracker_math.py @@ -1,5 +1,38 @@ import unittest +from typing import List +from tinygrad.helpers import prod +from tinygrad.shape.view import View from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.shape.symbolic import Variable, sym_infer + +class MultiShapeTracker: + def __init__(self, sts:List[ShapeTracker]): self.sts = sts + @property + def shape(self): return self.sts[0].shape + def reshape(self, arg): self.sts = [x.reshape(arg) for x in self.sts] + def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts] + def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts] + def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts] + def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts] + def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts] + +def st_equal(st1, st2) -> bool: + if st1.shape != st2.shape: return False + if st1 == st2: return True + idx = Variable("idx", 0, prod(st1.shape)-1) + st1_idx, st1_valid = st1.expr_node(idx) + st2_idx, st2_valid = st2.expr_node(idx) + for i in range(idx.min, idx.max): + st1_off = sym_infer(st1_idx, {idx: i}) + st2_off = sym_infer(st2_idx, {idx: i}) + st1_v = sym_infer(st1_valid, {idx: i}) + st2_v = sym_infer(st2_valid, {idx: i}) + if st1_v != st2_v or (st1_off != st2_off and st1_v): + print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}") + print(st1) + print(st2) + return False + return True class TestShapeTrackerBasics(unittest.TestCase): def test_pad_shrink_removes_mask(self): @@ -22,6 +55,11 @@ class TestShapeTrackerBasics(unittest.TestCase): x1 = x1.reshape( (2, 2, 5) ) assert x == x1.simplify() + def test_simplify_is_correct(self): + multiv = ShapeTracker(views=(View(shape=(15, 3), strides=(9, 1), offset=6, mask=None, contiguous=False), + View(shape=(4, 3), strides=(12, 4), offset=0, mask=None, contiguous=False))) + assert st_equal(multiv, multiv.simplify()) + class TestShapeTrackerAdd(unittest.TestCase): def test_simple_add_reshape(self): a = ShapeTracker.from_shape((10, 10)) @@ -36,6 +74,16 @@ class TestShapeTrackerAdd(unittest.TestCase): b = b.permute((1,0)) assert a+b == ShapeTracker.from_shape((10, 10)) + def test_plus_real1(self): + st = MultiShapeTracker([ShapeTracker.from_shape((15, 9))]) + st.shrink( ((0, 15), (6, 9)) ) + backup = st.sts[0] + st.sts.append(ShapeTracker.from_shape(backup.shape)) + st.reshape( (45,) ) + st.stride( (4,) ) + st.reshape( (4, 3) ) + assert st_equal(backup + st.sts[1], st.sts[0]) + class TestShapeTrackerInvert(unittest.TestCase): def test_invert_reshape(self): a = ShapeTracker.from_shape((10, 10)) @@ -46,13 +94,20 @@ class TestShapeTrackerInvert(unittest.TestCase): def test_invert_permute(self): a = ShapeTracker.from_shape((5, 20)) x = a.permute((1,0)) - ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape) + ap = x + x.invert(a.shape) assert ap == a, f"{ap} != {a}" def test_invert_permute_3(self): a = ShapeTracker.from_shape((8, 4, 5)) x = a.permute((1,2,0)) - ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape) + ap = x + x.invert(a.shape) + assert ap == a, f"{ap} != {a}" + + def test_invert_real1(self): + a = ShapeTracker.from_shape((3, 6, 10)) + x = a.reshape( (3, 3, 2, 10) ) + x = x.permute( (2, 1, 3, 0) ) + ap = x + x.invert(a.shape) assert ap == a, f"{ap} != {a}" def test_cant_invert_expand(self): @@ -66,10 +121,17 @@ class TestShapeTrackerInvert(unittest.TestCase): assert x.invert(a.shape) is None def test_can_invert_flip(self): - a = ShapeTracker.from_shape((10, 10)) + a = ShapeTracker.from_shape((20, 10)) x = a.stride((-1,1)) - ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape) - assert ap == a, f"{ap} != {a}" + ap = x + x.invert(a.shape) + assert st_equal(ap, a) + + def test_can_invert_flip_permute(self): + a = ShapeTracker.from_shape((20, 10)) + x = a.permute((1,0)) + x = x.stride((-1,1)) + ap = x + x.invert(a.shape) + assert st_equal(ap, a) def test_cant_invert_stride(self): a = ShapeTracker.from_shape((10, 10)) @@ -81,8 +143,8 @@ class TestShapeTrackerInvert(unittest.TestCase): x = a.pad( ((2, 0), (0, 0)) ) x = x.reshape( (2, 2, 5) ) x = x.reshape( (4, 5) ) - ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape) - assert ap == a, f"{ap} != {a}" + ap = x + x.invert(a.shape) + assert st_equal(ap, a) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index d677cef417..29d3db9c8f 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -57,11 +57,14 @@ class ShapeTracker: def __post_init__(self): assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views" - def __add__(self, st:ShapeTracker) -> ShapeTracker: return ShapeTracker(self.views + st.views).simplify() + def __add__(self, st:ShapeTracker) -> ShapeTracker: + base = ShapeTracker(self.views) + for v in st.views: base = ShapeTracker(base.views + (v,)).simplify() # one view at a time = better simplification + return base def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]: ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape])) - return ShapeTracker(cast(Tuple[View, ...], ret)) if all(x is not None for x in ret) else None + return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None @staticmethod def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),)) @@ -115,13 +118,14 @@ class ShapeTracker: idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] idx, valid = self.expr_idxs(idxs) ret: List[Optional[sint]] = [None] * len(self.views[-1].shape) + bad_idx_vars: Set[Variable] = set() for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]): idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1) try: ret[idxs.index(idx_maybe)] = stride_maybe - except ValueError: pass + except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars()) idx_vars, valid_vars = idx.vars(), valid.vars() for i,tidx in enumerate(idxs): - if tidx in valid_vars and not ignore_valid: ret[i] = None + if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None elif tidx not in idx_vars: ret[i] = 0 return tuple(ret) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 74950d5d01..f7ec6f0863 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -102,11 +102,10 @@ class View: @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]: - ret = self.shrink(self.mask) if self.mask else self - if prod(ret.shape) != prod(out_shape): return None # don't support shrink, expand, or stride != (-1, 1) - ret = cast(View, ret.reshape(tuple(s for s in ret.shape if s != 1))) # removing ones will never be an issue - ret = ret.stride(tuple(-1 if x < 0 else 1 for x in ret.strides)) - return ret.permute(argsort(tuple(-x for x in ret.strides))).reshape(out_shape) + ret = View.create(self.shape) + if self.mask: ret = ret.shrink(self.mask) + ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides))) + return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1) # MovementOps live here now