mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
Fix conv_transpose2d asymmetric padding (#840)
This commit is contained in:
@@ -441,9 +441,10 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_padded_conv_transpose2d(self):
|
||||
helper_test_op([(2,4,9,9), (4,4,3,3)],
|
||||
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=1).relu(),
|
||||
lambda x,w: Tensor.conv_transpose2d(x,w,padding=1).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
for padding in [(1,2), (2,1), 2, 1, 0]:
|
||||
helper_test_op([(2,4,9,9), (4,4,3,3)],
|
||||
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=padding).relu(),
|
||||
lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_dilated_conv_transpose2d(self):
|
||||
helper_test_op([(2,4,9,9), (4,4,3,3)],
|
||||
|
||||
@@ -385,8 +385,8 @@ class Tensor:
|
||||
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
|
||||
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
|
||||
x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
||||
# TODO: the make_pair on padding is wrong in the asymmetric padding case
|
||||
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW)))))
|
||||
padding = flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in reversed(list(zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW))))))
|
||||
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
|
||||
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
|
||||
|
||||
Reference in New Issue
Block a user