Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-06-13 08:28:55 +08:00)
fix image grouping
This commit is contained in:
@@ -38,6 +38,7 @@ def image_conv2d_decorator(normal_conv):
|
||||
|
||||
# expand out
|
||||
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
|
||||
+ cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
|
||||
x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
|
||||
w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
|
||||
|
||||
@@ -48,13 +49,12 @@ def image_conv2d_decorator(normal_conv):
|
||||
# prepare input
|
||||
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
||||
oy, ox = x.shape[4:6]
|
||||
- x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, groups, 1, 1, rcin_hi, rcin_lo, H, W)
|
||||
- x = x.expand(bs, oy, ox, groups, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1, rcin_hi, rcin_lo, H, W)
|
||||
- x = x.reshape(bs, oy, ox, cout//4, 4, rcin_hi, rcin_lo, H, W)
|
||||
+ x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, oy, ox, *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
|
||||
+ x = x.expand(bs, oy, ox, *cout_expand, rcin_hi, rcin_lo, H, W)
|
||||
|
||||
# prepare weights
|
||||
w = w.permute(0,4,2,5,1,3)
|
||||
- w = w.reshape((1, 1, 1, cout//4, 4, rcin_hi, rcin_lo, H, W)) # needed or this is broadcasting?
|
||||
+ w = w.reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
||||
|
||||
# the conv!
|
||||
ret = (x*w).sum((-4, -3, -2, -1)).reshape(bs*oy, ox*cout//4, 4)
|
||||
|
||||
Reference in New Issue
Block a user