This commit is contained in:
George Hotz
2020-11-02 07:03:23 -08:00
parent af5a4e0f5a
commit 355402504e
2 changed files with 45 additions and 15 deletions

View File

@@ -181,12 +181,12 @@ class Conv2D(Function):
)
tw = w.reshape(ctx.groups, rcout, cin, H, W)
ctx.save_for_backward(tx, tw, x.shape)
#ret = np.einsum('igjYXyx,gkjyx -> igkYX', tx, tw).reshape(bs, cout, oy, ox)
ret = np.zeros((bs,ctx.groups,rcout,oy,ox),dtype=x.dtype)
ret = np.zeros((bs,ctx.groups,oy,ox,rcout),dtype=x.dtype)
for g in range(ctx.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
ret[:,g]+=np.moveaxis(np.tensordot(tx[:,g], tw[g],((1,4,5),(1,2,3))),3,1)
return ret.reshape(bs, cout, oy, ox)
ret[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
@staticmethod
@@ -200,17 +200,19 @@ class Conv2D(Function):
ggg = grad_output.reshape(bs,ctx.groups,rcout,oy,ox)
gdw = np.zeros((ctx.groups,rcout,cin,H,W), dtype=tx.dtype)
#gdw = np.einsum('igkYX,igjYXyx -> gkjyx',ggg,tx)
for g in range(ctx.groups):
#'ikYX,ijYXyx -> kjyx'
gdw[g] += np.tensordot(ggg[:,g],tx[:,g], ((0,2,3),(0,2,3)))
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
#needs to be optimized
# needs to be optimized
gdx = np.zeros((bs,ctx.groups,cin,OY,OX), dtype=tx.dtype)
for Y in range(grad_output.shape[2]):
for X in range(grad_output.shape[3]):
iY,iX = Y*ys, X*xs
gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx',ggg[:,:,:,Y,X], tw)
#gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw)
for g in range(ctx.groups):
tg = np.dot(ggg[:,g,:,Y,X].reshape(bs, -1), tw[g].reshape(rcout, -1))
gdx[:, g, :, iY:iY+H, iX:iX+W] += tg.reshape((bs, cin, H, W))
return gdx.reshape((bs, ctx.groups*cin, OY, OX)), gdw.reshape((ctx.groups*rcout, cin, H, W))
register('conv2d', Conv2D)

View File

@@ -23,7 +23,7 @@ class Add(Function):
res_g[gid] = a_g[gid] + b_g[gid];
}
""").build()
prg.add(ctx.cl_queue, [ret.size//4], None, x, y, ret)
prg.add(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
return ret
@staticmethod
@@ -43,7 +43,7 @@ class Mul(Function):
res_g[gid] = a_g[gid] * b_g[gid];
}
""").build()
prg.mul(ctx.cl_queue, [ret.size//4], None, x, y, ret)
prg.mul(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
ctx.save_for_backward(x, y, prg)
return ret
@@ -143,13 +143,41 @@ class Dot(Function):
input, grad_output, grad_weight,
one, msize, isize, one, isize, osize)
#prg.matmul(ctx.cl_queue, [msize, osize], None,
# input, grad_output, grad_weight,
return grad_input, grad_weight
register('dot', Dot, gpu=True)
register('matmul', Dot, gpu=True)
# *** these two are unfinished, but until we fix the optimizer, it's useless ***
class ReLU(Function):
    """Elementwise rectified linear unit, max(x, 0), computed by an OpenCL kernel."""

    @staticmethod
    def forward(ctx, x):
        """Run the ``relu`` kernel over ``x`` and return the output buffer.

        ``ctx`` supplies the OpenCL context (``ctx.cl_ctx``) used to build the
        program and the command queue (``ctx.cl_queue``) used to launch it.
        The kernel reads and writes ``float`` elements.
        """
        # Presumably allocates an output buffer matching x -- confirm against
        # buffer_like, which is defined outside this block.
        ret = buffer_like(ctx, x)
        prg = cl.Program(ctx.cl_ctx, """
        __kernel void relu(
            __global const float *a_g, __global float *res_g)
        {
          int gid = get_global_id(0);
          // ReLU clamps negatives to zero, so this must be max, not min.
          res_g[gid] = max(a_g[gid], (float)0.);
        }
        """).build()
        # Launch one work-item per element of the output.
        prg.relu(ctx.cl_queue, [np.prod(ret.shape)], None, x, ret)
        return ret

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unchanged.

        NOTE: unfinished. This is a pass-through, not the ReLU gradient; a
        complete version would zero the gradient where the forward input
        was not positive.
        """
        return grad_output


register('relu', ReLU, gpu=True)
class LogSoftmax(Function):
    """Placeholder log-softmax op: both passes return their argument unchanged."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` as-is; no log-softmax is computed here."""
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` as-is."""
        return grad_output


register('logsoftmax', LogSoftmax, gpu=True)