From b0c0c5d0d62e40df97999c8df3c1809ecb7842db Mon Sep 17 00:00:00 2001 From: Ryan Neph Date: Sun, 8 Nov 2020 11:45:55 -0800 Subject: [PATCH] strided Pool funcs (#74) * *Pool2D GPU forward supports stride * kernel_size from ctx instead of saved_tensors * *Pool2D CPU forward supports stride * update ctx.stride properly --- test/test_ops.py | 27 ++++++++++++--------- tinygrad/ops.py | 58 ++++++++++++++++++++++++++-------------------- tinygrad/opsgpu.py | 47 +++++++++++++++++++++---------------- 3 files changed, 76 insertions(+), 56 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index e3c57dd16a..e372b4b12b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -87,18 +87,23 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w,stride=(2,1)).relu(), lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu) - def test_maxpool2x2(self): - helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, (2,2)), Tensor.max_pool2d, gpu=self.gpu, forward_only=self.gpu) + def test_maxpool2d(self): + for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: + for strd in [(1,1), (2,1), (2,2), (4,2)]: + # TODO Grad tolerance for CPU implementation needs to be slightly relaxed; why? + with self.subTest(kernel_size=ksz, stride=strd): + helper_test_op([(32,2,110,28)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, stride=strd), + lambda x: Tensor.max_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=self.gpu, grad_atol=1e-3) - def test_maxpool_sizes(self): - for sz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=sz), - lambda x: Tensor.max_pool2d(x, kernel_size=sz), gpu=self.gpu, forward_only=self.gpu) - - def test_avgpool2x2(self): - # TODO Grad tolerance needs to be slightly relaxed; why? - helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, (2,2)), Tensor.avg_pool2d, gpu=self.gpu, grad_atol=1e-5) + def test_avgpool2d(self): + for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: + for strd in [(1,1), (2,1), (2,2), (4,2)]: + # TODO Grad tolerance needs to be slightly relaxed; why? + with self.subTest(kernel_size=ksz, stride=strd): + helper_test_op([(32,2,111,28)], + lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, stride=strd), + lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, grad_atol=1e-5) if GPU: class TestOpsGPU(TestOps): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d9082144d9..6373adc5f9 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,3 +1,4 @@ +import warnings import numpy as np from .tensor import Function, register @@ -236,47 +237,54 @@ register('conv2d', Conv2D) # ************* pooling ops ************* -def stack_for_pool(x, py, px): - my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px - stack = [] - xup = x[:, :, :my, :mx] - for Y in range(py): - for X in range(px): - stack.append(xup[:, :, Y::py, X::px][None]) - return np.concatenate(stack, axis=0) +def stack_for_pool(x, kernel_size, stride, fill_value=0): + (ky, kx), (py, px) = kernel_size, stride + my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1 + stack = fill_value*np.ones((ky, kx, *x.shape[:2], my+ky, mx+kx), dtype=x.dtype) + for Y in range(ky): + for X in range(kx): + sl = x[..., Y:Y+my*py+ky:py, X:X+mx*px+kx:px] + stack[Y, X, ..., :sl.shape[2], :sl.shape[3]] = sl + return stack.reshape(-1, *stack.shape[2:]), (my, mx) -def unstack_for_pool(fxn, s, py, px): - my, mx = (s[2]//py)*py, (s[3]//px)*px - for Y in range(py): - for X in range(px): - ll = fxn(Y*px+X) +def unstack_for_pool(fxn, s, kernel_size, stride): + (ky, kx), (py, px) = kernel_size, stride + for Y in range(ky): + for X in range(kx): + ll = fxn(Y*kx+X) if X == 0 and Y == 0: - ret = np.zeros(s, dtype=ll.dtype) - ret[:, :, Y:my:py, X:mx:px] = ll - return ret + ret = np.zeros((*s[:2], s[2]+ky, s[3]+kx), dtype=ll.dtype) + ret[..., Y:Y+ll.shape[2]*py:py, X:X+ll.shape[3]*px:px] = ll + return ret[..., :s[2], :s[3]] class MaxPool2D(Function): @staticmethod - def forward(ctx, x, kernel_size=(2, 2)): - stack = stack_for_pool(x, *kernel_size) - idxs = np.argmax(stack, axis=0) + def forward(ctx, x, kernel_size=(2, 2), stride=None): + if not stride: + ctx.stride = stride = kernel_size + stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=-np.inf) + idxs = np.nanargmax(stack, axis=0)[..., :my, :mx] ctx.save_for_backward(idxs, x.shape) - return np.max(stack, axis=0) + return np.amax(stack, axis=0)[..., :my, :mx] @staticmethod def backward(ctx, grad_output): idxs,s = ctx.saved_tensors return unstack_for_pool( lambda idx: grad_output * (idxs == idx), - s, *ctx.kernel_size) + s, ctx.kernel_size, ctx.stride) register('max_pool2d', MaxPool2D) class AvgPool2D(Function): @staticmethod - def forward(ctx, x, kernel_size=(2, 2)): - stack = stack_for_pool(x, *kernel_size) + def forward(ctx, x, kernel_size=(2, 2), stride=None): + if not stride: + ctx.stride = stride = kernel_size + stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=np.nan) ctx.save_for_backward(x.shape) - return np.mean(stack, axis=0) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return np.nanmean(stack, axis=0)[...,:my, :mx] @staticmethod def backward(ctx, grad_output): @@ -284,6 +292,6 @@ class AvgPool2D(Function): py, px = ctx.kernel_size return unstack_for_pool( lambda idx: grad_output/py/px, - s, py, px) + s, ctx.kernel_size, ctx.stride) register('avg_pool2d', AvgPool2D) diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index cbf1bf9df0..b8896a0c2b 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -28,14 +28,15 @@ def clbuild(cl_ctx, prg): def cl_subsample_krnl_build(cl_ctx, iter_op, result_op, init_val=0): prg = """ __kernel void subsample( - __global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size, int nelem + __global float *output, __global const float *input, uint2 osize, uint2 isize, uint2 kernel_size, + uint2 stride, int nelem ) { int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0)); int oid = gid.x + osize.x*(gid.y + osize.y*gid.z); float group_res = """+str(init_val)+"""; for (uint j=0; j