match torch api for pad2d

This commit is contained in:
George Hotz
2020-11-09 17:48:56 -08:00
parent daf073535f
commit 866b759d3b
3 changed files with 7 additions and 7 deletions

View File

@@ -63,7 +63,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,1,1,1)), lambda x: x.pad2d(padding=(1,1,1,1)), gpu=self.gpu)
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
def test_conv2d(self):
for bs in [1,8]:

View File

@@ -102,12 +102,12 @@ class Pad2D(Function):
ctx.save_for_backward(padding)
return np.pad(x,
((0,0), (0,0),
(padding[0], padding[1]), (padding[2], padding[3])))
(padding[2], padding[3]), (padding[0], padding[1])))
@staticmethod
def backward(ctx, grad_output):
padding, = ctx.saved_tensors
return grad_output[..., padding[0]:-padding[1], padding[2]:-padding[3]]
return grad_output[..., padding[2]:-padding[3], padding[0]:-padding[1]]
register('pad2d', Pad2D)
class Reshape(Function):

View File

@@ -295,7 +295,7 @@ class Pad2D(Function):
@staticmethod
def forward(ctx, x, padding=None):
bs,cin,iy,ix = x.shape
oy,ox = iy+padding[0]+padding[1], ix+padding[2]+padding[3]
oy,ox = iy+padding[2]+padding[3], ix+padding[0]+padding[1]
ret = buffer_zeros(ctx, (bs, cin, oy, ox))
prg = clbuild(ctx.cl_ctx, """
@@ -319,7 +319,7 @@ class Pad2D(Function):
ctx.save_for_backward(padding)
prg.pad2d(ctx.cl_queue, [bs, cin, iy], None,
x, ret,
np.int32(cin), np.int32(padding[0]), np.int32(padding[2]),
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
)
return ret
@@ -328,7 +328,7 @@ class Pad2D(Function):
def backward(ctx, grad_output):
padding, = ctx.saved_tensors
bs, cin, iy, ix = grad_output.shape
oy, ox = iy - padding[0] - padding[1], ix - padding[2] - padding[3]
oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
ret = buffer_new(ctx, (bs, cin, oy, ox))
prg = clbuild(ctx.cl_ctx, """
__kernel void pad2d(
@@ -350,7 +350,7 @@ class Pad2D(Function):
""")
prg.pad2d(ctx.cl_queue, [bs, cin, oy], None,
grad_output, ret,
np.int32(cin), np.int32(padding[0]), np.int32(padding[2]),
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
)
return ret