mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
ops work
This commit is contained in:
@@ -181,12 +181,12 @@ class Conv2D(Function):
|
||||
)
|
||||
tw = w.reshape(ctx.groups, rcout, cin, H, W)
|
||||
ctx.save_for_backward(tx, tw, x.shape)
|
||||
#ret = np.einsum('igjYXyx,gkjyx -> igkYX', tx, tw).reshape(bs, cout, oy, ox)
|
||||
ret = np.zeros((bs,ctx.groups,rcout,oy,ox),dtype=x.dtype)
|
||||
|
||||
ret = np.zeros((bs,ctx.groups,oy,ox,rcout),dtype=x.dtype)
|
||||
for g in range(ctx.groups):
|
||||
#ijYXyx,kjyx -> iYXk ->ikYX
|
||||
ret[:,g]+=np.moveaxis(np.tensordot(tx[:,g], tw[g],((1,4,5),(1,2,3))),3,1)
|
||||
return ret.reshape(bs, cout, oy, ox)
|
||||
ret[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
|
||||
return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -200,17 +200,19 @@ class Conv2D(Function):
|
||||
ggg = grad_output.reshape(bs,ctx.groups,rcout,oy,ox)
|
||||
|
||||
gdw = np.zeros((ctx.groups,rcout,cin,H,W), dtype=tx.dtype)
|
||||
#gdw = np.einsum('igkYX,igjYXyx -> gkjyx',ggg,tx)
|
||||
for g in range(ctx.groups):
|
||||
#'ikYX,ijYXyx -> kjyx'
|
||||
gdw[g] += np.tensordot(ggg[:,g],tx[:,g], ((0,2,3),(0,2,3)))
|
||||
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
|
||||
|
||||
#needs to be optimized
|
||||
# needs to be optimized
|
||||
gdx = np.zeros((bs,ctx.groups,cin,OY,OX), dtype=tx.dtype)
|
||||
for Y in range(grad_output.shape[2]):
|
||||
for X in range(grad_output.shape[3]):
|
||||
iY,iX = Y*ys, X*xs
|
||||
gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx',ggg[:,:,:,Y,X], tw)
|
||||
#gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw)
|
||||
for g in range(ctx.groups):
|
||||
tg = np.dot(ggg[:,g,:,Y,X].reshape(bs, -1), tw[g].reshape(rcout, -1))
|
||||
gdx[:, g, :, iY:iY+H, iX:iX+W] += tg.reshape((bs, cin, H, W))
|
||||
|
||||
return gdx.reshape((bs, ctx.groups*cin, OY, OX)), gdw.reshape((ctx.groups*rcout, cin, H, W))
|
||||
register('conv2d', Conv2D)
|
||||
|
||||
@@ -23,7 +23,7 @@ class Add(Function):
|
||||
res_g[gid] = a_g[gid] + b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.add(ctx.cl_queue, [ret.size//4], None, x, y, ret)
|
||||
prg.add(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
@@ -43,7 +43,7 @@ class Mul(Function):
|
||||
res_g[gid] = a_g[gid] * b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.mul(ctx.cl_queue, [ret.size//4], None, x, y, ret)
|
||||
prg.mul(ctx.cl_queue, [np.prod(ret.shape)], None, x, y, ret)
|
||||
ctx.save_for_backward(x, y, prg)
|
||||
return ret
|
||||
|
||||
@@ -143,13 +143,41 @@ class Dot(Function):
|
||||
input, grad_output, grad_weight,
|
||||
one, msize, isize, one, isize, osize)
|
||||
|
||||
|
||||
#prg.matmul(ctx.cl_queue, [msize, osize], None,
|
||||
# input, grad_output, grad_weight,
|
||||
|
||||
|
||||
return grad_input, grad_weight
|
||||
register('dot', Dot, gpu=True)
|
||||
register('matmul', Dot, gpu=True)
|
||||
|
||||
|
||||
# *** these two ops (ReLU, LogSoftmax) are unfinished; completing them is useless until the optimizer is fixed ***
|
||||
|
||||
class ReLU(Function):
  """GPU ReLU op: elementwise max(x, 0) computed by an OpenCL kernel.

  Still unfinished: `backward` passes the incoming gradient through
  unchanged instead of masking it by the sign of the forward input.
  """

  @staticmethod
  def forward(ctx, x):
    """Run the `relu` kernel over every element of `x`.

    Args:
      ctx: op context; this method reads `ctx.cl_ctx` and `ctx.cl_queue`.
      x: input buffer, handed to the kernel as `a_g`.
        NOTE(review): the kernel treats it as float32 -- confirm at callers.

    Returns:
      The buffer obtained from `buffer_like(ctx, x)`, written by the kernel.
    """
    ret = buffer_like(ctx, x)
    # ReLU keeps positive values and clamps negatives to zero, so the kernel
    # must take max(x, 0); min(x, 0) would zero out every positive input.
    prg = cl.Program(ctx.cl_ctx, """
    __kernel void relu(
        __global const float *a_g, __global float *res_g)
    {
      int gid = get_global_id(0);
      res_g[gid] = max(a_g[gid], (float)0.);
    }
    """).build()
    # Global work size is the element count, i.e. one work-item per output.
    prg.relu(ctx.cl_queue, [np.prod(ret.shape)], None, x, ret)
    return ret

  @staticmethod
  def backward(ctx, grad_output):
    """Return `grad_output` unchanged.

    NOTE: this is a placeholder. A complete ReLU backward would zero the
    gradient wherever the forward input was not positive; nothing is saved
    in `forward` to do that yet.
    """
    return grad_output

register('relu', ReLU, gpu=True)
|
||||
|
||||
class LogSoftmax(Function):
  """Placeholder op registered under 'logsoftmax' with gpu=True.

  Neither pass does any computation yet: both hand their argument straight
  back to the caller.
  """

  @staticmethod
  def forward(ctx, input):
    """Return `input` untouched; `ctx` is not used."""
    return input

  @staticmethod
  def backward(ctx, grad_output):
    """Return `grad_output` untouched; `ctx` is not used."""
    return grad_output

register('logsoftmax', LogSoftmax, gpu=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user