mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
unit test merge_dim [pr] (#8195)
looking for better ways to write this. first adding some tests
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.view import View, merge_dims
|
||||
|
||||
class TestView(unittest.TestCase):
|
||||
def test_canonicalize_empty_mask(self):
|
||||
@@ -38,5 +38,25 @@ class TestView(unittest.TestCase):
|
||||
v = View.create(shape=(2,3,4), mask=((0,2),(0,3),(0,4)))
|
||||
assert v.contiguous
|
||||
|
||||
class TestMergeDims(unittest.TestCase):
|
||||
def test_contiguous(self):
|
||||
shape = (2, 3, 4)
|
||||
strides = (12, 4, 1) #=strides_for_shape(shape)
|
||||
m = merge_dims(shape, strides)
|
||||
self.assertEqual(m, ((24, 1, 24),))
|
||||
|
||||
def test_0_in_strides(self):
|
||||
shape = (2, 3, 4)
|
||||
self.assertEqual(merge_dims(shape, (0, 4, 1)), ((2, 0, 0), (12, 1, 12)))
|
||||
self.assertEqual(merge_dims(shape, (0, 0, 1)), ((6, 0, 0), (4, 1, 4)))
|
||||
self.assertEqual(merge_dims(shape, (3, 1, 0)), ((6, 1, 6), (4, 0, 0)))
|
||||
self.assertEqual(merge_dims(shape, (0, 0, 0)), ((24, 0, 0),))
|
||||
|
||||
def test_pad_reshape(self):
|
||||
# st = ShapeTracker.from_shape((1, 2)).pad(((1, 0), (0, 1))).reshape((3, 2))
|
||||
self.assertEqual(merge_dims((2, 3), (0, 1), ((1, 2), (0, 2))), ((6, 1, 3),))
|
||||
# shift mask on stride 0
|
||||
self.assertEqual(merge_dims((2, 3), (0, 1), ((0, 1), (0, 2))), ((6, 1, 3),))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -17,7 +17,7 @@ def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
||||
return canonicalize_strides(shape, strides)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
|
||||
def merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
|
||||
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
|
||||
if not shape: return ()
|
||||
assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
|
||||
@@ -233,7 +233,7 @@ class View:
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def minify(self):
|
||||
min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
|
||||
min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
|
||||
return nv if (nv := self.reshape(min_shape)) else self
|
||||
|
||||
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
|
||||
@@ -320,7 +320,7 @@ class View:
|
||||
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
|
||||
|
||||
strides, r_new_shape = [], reversed(new_shape)
|
||||
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
|
||||
for merged_dim, new_stride, real_dim in reversed(merge_dims(self.shape, self.strides, self.mask)):
|
||||
acc = 1
|
||||
# TODO: third resolve shouldn't be needed
|
||||
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
|
||||
|
||||
Reference in New Issue
Block a user