fix pad None will value (#2308)

This commit is contained in:
chenyu
2023-11-14 23:57:05 -05:00
committed by GitHub
parent 01f8781c26
commit 175cdbe815
2 changed files with 6 additions and 2 deletions

View File

@@ -665,11 +665,14 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad2d(padding=(-1,2,-3,4)))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad2d(padding=(1,2,3,4),value=5))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad2d(padding=(-1,2,-3,4),value=5))
def test_pad(self):
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2))))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("inf")), lambda x: x.pad(((3,4), (1,2)), value=float("inf")))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("-inf")), lambda x: x.pad(((3,4), (1,2)), value=float("-inf")))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1))
helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1))
def test_transpose(self):
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))

View File

@@ -255,8 +255,9 @@ class Tensor:
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg:Tuple[Optional[Tuple[int, int]], ...], value:float=0.0) -> Tensor:
ret = mlops.Pad.apply(self, arg=tuple(x if x is not None else (0,0) for x in arg)) if any(x is not None and x != (0,0) for x in arg) else self
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
if all(x is None or x == (0,0) for x in arg): return self
ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
# ***** movement hlops *****