mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
match torch api for pad2d
This commit is contained in:
@@ -63,7 +63,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
|
||||
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,1,1,1)), lambda x: x.pad2d(padding=(1,1,1,1)), gpu=self.gpu)
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
|
||||
|
||||
def test_conv2d(self):
|
||||
for bs in [1,8]:
|
||||
|
||||
@@ -102,12 +102,12 @@ class Pad2D(Function):
|
||||
ctx.save_for_backward(padding)
|
||||
return np.pad(x,
|
||||
((0,0), (0,0),
|
||||
(padding[0], padding[1]), (padding[2], padding[3])))
|
||||
(padding[2], padding[3]), (padding[0], padding[1])))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
padding, = ctx.saved_tensors
|
||||
return grad_output[..., padding[0]:-padding[1], padding[2]:-padding[3]]
|
||||
return grad_output[..., padding[2]:-padding[3], padding[0]:-padding[1]]
|
||||
register('pad2d', Pad2D)
|
||||
|
||||
class Reshape(Function):
|
||||
|
||||
@@ -295,7 +295,7 @@ class Pad2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, padding=None):
|
||||
bs,cin,iy,ix = x.shape
|
||||
oy,ox = iy+padding[0]+padding[1], ix+padding[2]+padding[3]
|
||||
oy,ox = iy+padding[2]+padding[3], ix+padding[0]+padding[1]
|
||||
ret = buffer_zeros(ctx, (bs, cin, oy, ox))
|
||||
|
||||
prg = clbuild(ctx.cl_ctx, """
|
||||
@@ -319,7 +319,7 @@ class Pad2D(Function):
|
||||
ctx.save_for_backward(padding)
|
||||
prg.pad2d(ctx.cl_queue, [bs, cin, iy], None,
|
||||
x, ret,
|
||||
np.int32(cin), np.int32(padding[0]), np.int32(padding[2]),
|
||||
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
|
||||
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
|
||||
)
|
||||
return ret
|
||||
@@ -328,7 +328,7 @@ class Pad2D(Function):
|
||||
def backward(ctx, grad_output):
|
||||
padding, = ctx.saved_tensors
|
||||
bs, cin, iy, ix = grad_output.shape
|
||||
oy, ox = iy - padding[0] - padding[1], ix - padding[2] - padding[3]
|
||||
oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
|
||||
ret = buffer_new(ctx, (bs, cin, oy, ox))
|
||||
prg = clbuild(ctx.cl_ctx, """
|
||||
__kernel void pad2d(
|
||||
@@ -350,7 +350,7 @@ class Pad2D(Function):
|
||||
""")
|
||||
prg.pad2d(ctx.cl_queue, [bs, cin, oy], None,
|
||||
grad_output, ret,
|
||||
np.int32(cin), np.int32(padding[0]), np.int32(padding[2]),
|
||||
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
|
||||
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
|
||||
)
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user