fix image grouping

This commit is contained in:
George Hotz
2023-02-28 16:50:46 -08:00
parent 17c55f051d
commit fde6c2d62b

View File

@@ -38,6 +38,7 @@ def image_conv2d_decorator(normal_conv):
# expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
@@ -48,13 +49,12 @@ def image_conv2d_decorator(normal_conv):
# prepare input
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
oy, ox = x.shape[4:6]
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, groups, 1, 1, rcin_hi, rcin_lo, H, W)
x = x.expand(bs, oy, ox, groups, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1, rcin_hi, rcin_lo, H, W)
x = x.reshape(bs, oy, ox, cout//4, 4, rcin_hi, rcin_lo, H, W)
x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
# prepare weights
w = w.permute(0,4,2,5,1,3)
w = w.reshape((1, 1, 1, cout//4, 4, rcin_hi, rcin_lo, H, W)) # needed or this is broadcasting?
w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
# the conv!
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)