From 175cdbe815a9fcc06348588790c77e204563e130 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 14 Nov 2023 23:57:05 -0500 Subject: [PATCH] fix pad None will value (#2308) --- test/test_ops.py | 3 +++ tinygrad/tensor.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 868019d455..c9874dd729 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -665,11 +665,14 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad2d(padding=(-1,2,-3,4))) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad2d(padding=(1,2,3,4),value=5)) helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad2d(padding=(-1,2,-3,4),value=5)) + def test_pad(self): helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2)))) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("inf")), lambda x: x.pad(((3,4), (1,2)), value=float("inf"))) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("-inf")), lambda x: x.pad(((3,4), (1,2)), value=float("-inf"))) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1)) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1)) def test_transpose(self): helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 50aa5eb48b..ffce6e9a65 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -255,8 +255,9 @@ class Tensor: def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)]) def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self def pad(self, arg:Tuple[Optional[Tuple[int, int]], ...], value:float=0.0) -> Tensor: - ret = mlops.Pad.apply(self, arg=tuple(x if x is not None else (0,0) for x in arg)) if any(x is not None and x != (0,0) for x in arg) else self - return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value) + if all(x is None or x == (0,0) for x in arg): return self + ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg))) + return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value) # ***** movement hlops *****