Fix conv_transpose2d asymmetric padding (#840)

This commit is contained in:
Marcello Fuschi
2023-05-29 16:57:06 +02:00
committed by GitHub
parent 2fd2fb6380
commit 6ea5df19b2
2 changed files with 6 additions and 5 deletions

View File

@@ -441,9 +441,10 @@ class TestOps(unittest.TestCase):
lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), atol=1e-4, grad_rtol=1e-5)
def test_padded_conv_transpose2d(self):
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=1).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w,padding=1).relu(), atol=1e-4, grad_rtol=1e-5)
for padding in [(1,2), (2,1), 2, 1, 0]:
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=padding).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5)
def test_dilated_conv_transpose2d(self):
helper_test_op([(2,4,9,9), (4,4,3,3)],

View File

@@ -385,8 +385,8 @@ class Tensor:
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
# TODO: the make_pair on padding is wrong in the asymmetric padding case
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW)))))
padding = flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in reversed(list(zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW))))))
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]