From ce41e6572d607fd9bd39a9fd47ebdbcec0e20231 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 12 Dec 2024 17:55:52 -0500 Subject: [PATCH] unit test merge_dim [pr] (#8195) looking for better ways to write this. first adding some tests --- test/unit/test_view.py | 22 +++++++++++++++++++++- tinygrad/shape/view.py | 6 +++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/test/unit/test_view.py b/test/unit/test_view.py index 68f3037441..474ddd80b5 100644 --- a/test/unit/test_view.py +++ b/test/unit/test_view.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import unittest -from tinygrad.shape.view import View +from tinygrad.shape.view import View, merge_dims class TestView(unittest.TestCase): def test_canonicalize_empty_mask(self): @@ -38,5 +38,25 @@ class TestView(unittest.TestCase): v = View.create(shape=(2,3,4), mask=((0,2),(0,3),(0,4))) assert v.contiguous +class TestMergeDims(unittest.TestCase): + def test_contiguous(self): + shape = (2, 3, 4) + strides = (12, 4, 1) #=strides_for_shape(shape) + m = merge_dims(shape, strides) + self.assertEqual(m, ((24, 1, 24),)) + + def test_0_in_strides(self): + shape = (2, 3, 4) + self.assertEqual(merge_dims(shape, (0, 4, 1)), ((2, 0, 0), (12, 1, 12))) + self.assertEqual(merge_dims(shape, (0, 0, 1)), ((6, 0, 0), (4, 1, 4))) + self.assertEqual(merge_dims(shape, (3, 1, 0)), ((6, 1, 6), (4, 0, 0))) + self.assertEqual(merge_dims(shape, (0, 0, 0)), ((24, 0, 0),)) + + def test_pad_reshape(self): + # st = ShapeTracker.from_shape((1, 2)).pad(((1, 0), (0, 1))).reshape((3, 2)) + self.assertEqual(merge_dims((2, 3), (0, 1), ((1, 2), (0, 2))), ((6, 1, 3),)) + # shift mask on stride 0 + self.assertEqual(merge_dims((2, 3), (0, 1), ((0, 1), (0, 2))), ((6, 1, 3),)) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 1a93c867fa..92d544e43b 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -17,7 +17,7 @@ def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]: return canonicalize_strides(shape, strides) @functools.lru_cache(maxsize=None) -def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]: +def merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]: # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...] if not shape: return () assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask)) @@ -233,7 +233,7 @@ class View: @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def minify(self): - min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask)) + min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask)) return nv if (nv := self.reshape(min_shape)) else self def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View: @@ -320,7 +320,7 @@ class View: return View(new_shape, self.strides, self.offset, self.mask, self.contiguous) strides, r_new_shape = [], reversed(new_shape) - for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)): + for merged_dim, new_stride, real_dim in reversed(merge_dims(self.shape, self.strides, self.mask)): acc = 1 # TODO: third resolve shouldn't be needed while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):