import numpy as np
from functools import lru_cache


# these are matlab functions used to speed up convs
# write them fast and the convs will be fast?

# The explicit call form keeps the default bound of 128 entries and, unlike
# a bare ``@lru_cache``, is also accepted by Python versions before 3.8.
@lru_cache(maxsize=128)
def get_im2col_indexes(oy, ox, cin, H, W):
    """Return cached gather indexes for an im2col over a (bs, cin, y, x) array.

    Args:
        oy: number of output rows (input height - (H - 1)).
        ox: number of output columns (input width - (W - 1)).
        cin: number of input channels.
        H: kernel height.
        W: kernel width.

    Returns:
        A tuple ``(idxc, idxy, idxx)`` of flat integer arrays, each of length
        ``oy * ox * cin * H * W``.  Position ``i`` in the three arrays gives
        the channel, row and column to read; positions are laid out in
        row-major ``(oy, ox, cin, H, W)`` order.  The arrays are read-only
        because the same objects are handed to every caller via the cache.
    """
    patch = cin * H * W  # number of values gathered per output position
    idxc = np.tile(np.arange(cin).repeat(H * W), oy * ox)
    # row = offset inside the kernel window + row of the output position
    idxy = (np.tile(np.arange(H).repeat(W), oy * ox * cin)
            + np.arange(oy).repeat(ox * patch))
    # column = offset inside the kernel window + column of the output position
    idxx = (np.tile(np.arange(W), oy * ox * cin * H)
            + np.tile(np.arange(ox), oy).repeat(patch))
    # Cached results are shared; freeze them so an accidental in-place write
    # by one caller cannot corrupt every later lookup.
    for idx in (idxc, idxy, idxx):
        idx.setflags(write=False)
    return idxc, idxy, idxx


def im2col(x, H, W):
    """Unfold every HxW window of ``x`` into one row of a 2-D matrix.

    Args:
        x: array indexed as (batch, channel, row, column).
        H: kernel height.
        W: kernel width.

    Returns:
        Array of shape ``(bs * oy * ox, cin * H * W)`` where
        ``oy = x.shape[2] - (H - 1)`` and ``ox = x.shape[3] - (W - 1)``.
        Rows are ordered (batch, output row, output column); within a row the
        values are ordered (channel, kernel row, kernel column).

    Raises:
        ValueError: if the kernel is so much larger than the input that the
            output size would be negative.
    """
    cin = x.shape[1]
    oy, ox = x.shape[2] - (H - 1), x.shape[3] - (W - 1)
    if oy < 0 or ox < 0:
        raise ValueError(
            "kernel (%d, %d) is too large for input of spatial size (%d, %d)"
            % (H, W, x.shape[2], x.shape[3]))
    ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
    # One fancy-indexing gather replaces the per-window Python loop.
    tx = x[:, ic, iy, ix]
    return tx.reshape(-1, cin * W * H)