From fde6c2d62b89c366a0df2fd58dd459eafc63099d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 28 Feb 2023 16:50:46 -0800 Subject: [PATCH] fix image grouping --- tinygrad/image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tinygrad/image.py b/tinygrad/image.py index 1b6fee5b4c..1fd05c08ee 100644 --- a/tinygrad/image.py +++ b/tinygrad/image.py @@ -38,6 +38,7 @@ def image_conv2d_decorator(normal_conv): # expand out rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1 + cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1] x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo) w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo) @@ -48,13 +49,12 @@ def image_conv2d_decorator(normal_conv): # prepare input x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) oy, ox = x.shape[4:6] - x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, groups, 1, 1, rcin_hi, rcin_lo, H, W) - x = x.expand(bs, oy, ox, groups, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1, rcin_hi, rcin_lo, H, W) - x = x.reshape(bs, oy, ox, cout//4, 4, rcin_hi, rcin_lo, H, W) + x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W) + x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W) # prepare weights w = w.permute(0,4,2,5,1,3) - w = w.reshape((1, 1, 1, cout//4, 4, rcin_hi, rcin_lo, H, W)) # needed or this is broadcasting? + w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W)) # the conv! ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)