remove new_buffer

This commit is contained in:
George Hotz
2022-06-06 07:57:39 -07:00
parent 30f55eaaba
commit 613f0ca6e5
2 changed files with 35 additions and 35 deletions

View File

@@ -272,6 +272,10 @@ def convdx(w,grad_output,dx,conv_args):
int g = get_global_id(1);
int ci = get_global_id(2);
for (int Y = 0; Y < iy; Y++) { for (int X = 0; X < ix; X++) {
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + Y*ix + X] = 0.0;
} }
for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) {
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
float acc = 0.0;

View File

@@ -1,24 +1,20 @@
import pyopencl as cl
import numpy as np
from tinygrad.helpers import binary_broadcast
from ..tensor import Function
from ..llops.opencl import GPUBuffer
from ..llops.opencl import GPUBuffer as Buffer
from ..llops.opencl import unary_op, binary_op, reduce_op, perm_axis, inner_slice
from ..llops.opencl import matmul, conv, convdw, convdx
def buffer_new(shape, zero=False):
return GPUBuffer(shape, hostbuf=None if not zero else np.zeros(shape, dtype=np.float32))
# ************* unary ops *************
class UnaryOp(Function):
def forward(ctx, input):
ctx.save_for_backward(input)
return unary_op(ctx.fop, input, buffer_new(input.shape))
return unary_op(ctx.fop, input, Buffer(input.shape))
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return binary_op(ctx.bop, input, grad_output, buffer_new(input.shape))
return binary_op(ctx.bop, input, grad_output, Buffer(input.shape))
class ReLU(UnaryOp):
fop = 'max(a, (float)0.)'
@@ -42,36 +38,36 @@ def reduce_shape(shape, axis):
class Sum(Function):
def forward(ctx, input, axis=None):
ctx.save_for_backward(input.shape)
return reduce_op("out += a", input, buffer_new(reduce_shape(input.shape, axis)))
return reduce_op("out += a", input, Buffer(reduce_shape(input.shape, axis)))
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
# NOTE: the b buffer_new isn't used, since this is just for broadcast
ret = buffer_new(shape_input)
# NOTE: the b Buffer isn't used, since this is just for broadcast
ret = Buffer(shape_input)
return binary_op('a', grad_output, ret, ret)
class Max(Function):
def forward(ctx, input, axis=None):
ret = reduce_op("out = max(a,out)", input, buffer_new(reduce_shape(input.shape, axis)), start="-INFINITY")
ret = reduce_op("out = max(a,out)", input, Buffer(reduce_shape(input.shape, axis)), start="-INFINITY")
ctx.save_for_backward(input, axis, ret)
return ret
def backward(ctx, grad_output):
input, axis, ret = ctx.saved_tensors
ret2 = binary_op("1.0*(a==b)", input, ret, buffer_new(input.shape))
div = reduce_op("out += a", ret2, buffer_new(reduce_shape(ret2.shape, axis)), start="1e-10")
ret2 = binary_op("1.0*(a==b)", input, ret, Buffer(input.shape))
div = reduce_op("out += a", ret2, Buffer(reduce_shape(ret2.shape, axis)), start="1e-10")
binary_op("a/b", ret2, div, ret2)
return binary_op('a*b', ret2, grad_output, ret2)
# ************* binary ops *************
def unbroadcast(out, in_sh):
return reduce_op("out += a", out, buffer_new(in_sh))
return reduce_op("out += a", out, Buffer(in_sh))
class Add(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return binary_op('a+b', x, y, buffer_new(binary_broadcast(x.shape, y.shape)))
return binary_op('a+b', x, y, Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
@@ -80,33 +76,33 @@ class Add(Function):
class Sub(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return binary_op('a-b', x, y, buffer_new(binary_broadcast(x.shape, y.shape)))
return binary_op('a-b', x, y, Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
grad_x, grad_y = grad_output, unary_op('-a', grad_output, buffer_new(grad_output.shape))
grad_x, grad_y = grad_output, unary_op('-a', grad_output, Buffer(grad_output.shape))
return unbroadcast(grad_x, shape_x), unbroadcast(grad_y, shape_y)
class Mul(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return binary_op('a*b', x, y, buffer_new(binary_broadcast(x.shape, y.shape)))
return binary_op('a*b', x, y, Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
grad_x = binary_op('a*b', y, grad_output, buffer_new(grad_output.shape))
grad_y = binary_op('a*b', x, grad_output, buffer_new(grad_output.shape))
grad_x = binary_op('a*b', y, grad_output, Buffer(grad_output.shape))
grad_y = binary_op('a*b', x, grad_output, Buffer(grad_output.shape))
return unbroadcast(grad_x, x.shape), unbroadcast(grad_y, y.shape)
class Pow(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return binary_op('pow(a,b)', x, y, buffer_new(binary_broadcast(x.shape, y.shape)))
return binary_op('pow(a,b)', x, y, Buffer(binary_broadcast(x.shape, y.shape)))
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
grad_x_inter = binary_op('b * (pow((float)a, (float)(b-1.0)))', x, y, buffer_new(grad_output.shape))
grad_y_inter = binary_op('pow(a, (float)b) * log(a);', x, y, buffer_new(grad_output.shape))
grad_x_inter = binary_op('b * (pow((float)a, (float)(b-1.0)))', x, y, Buffer(grad_output.shape))
grad_y_inter = binary_op('pow(a, (float)b) * log(a);', x, y, Buffer(grad_output.shape))
return unbroadcast(binary_op('a*b', grad_output, grad_x_inter, grad_x_inter), x.shape), \
unbroadcast(binary_op('a*b', grad_output, grad_y_inter, grad_y_inter), y.shape)
@@ -116,35 +112,35 @@ class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
shape = tuple(-np.prod(x.shape) // np.prod(shape) if s == -1 else s for s in shape)
r = GPUBuffer(shape, hostbuf=x) # NOTE: this is not a copy
r = Buffer(shape, hostbuf=x) # NOTE: this is not a copy
assert np.prod(x.shape) == np.prod(r.shape)
return r
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return GPUBuffer(in_shape, hostbuf=grad_output)
return Buffer(in_shape, hostbuf=grad_output)
class Transpose(Function):
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
ret = buffer_new(np.array(x.shape)[list(order)])
ret = Buffer(np.array(x.shape)[list(order)])
return perm_axis(x, order, ret)
def backward(ctx, grad_output):
norder = np.argsort(ctx.order)
ret = buffer_new(np.array(grad_output.shape)[list(norder)])
ret = Buffer(np.array(grad_output.shape)[list(norder)])
return perm_axis(grad_output, norder, ret)
class Slice(Function):
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape)
ret = buffer_new([y[1]-y[0] for y in arg])
ret = Buffer([y[1]-y[0] for y in arg])
return inner_slice(x, arg, ret)
def backward(ctx, grad_output):
shape, = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(ctx.arg)]
ret = buffer_new([y[1]-y[0] for y in narg])
ret = Buffer([y[1]-y[0] for y in narg])
return inner_slice(grad_output, narg, ret)
# ************* processing ops *************
@@ -152,14 +148,14 @@ class Slice(Function):
class Matmul(Function):
def forward(ctx, input, weight):
assert input.shape[-1] == weight.shape[-2]
ret = buffer_new(list(input.shape[0:-1])+[weight.shape[-1]])
ret = Buffer(list(input.shape[0:-1])+[weight.shape[-1]])
ctx.save_for_backward(input, weight)
return matmul(input, weight, ret)
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None
grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None
grad_input = matmul(grad_output, weight, Buffer(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None
grad_weight = matmul(input, grad_output, Buffer(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None
return grad_input, grad_weight
class Conv2D(Function):
@@ -177,7 +173,7 @@ class Conv2D(Function):
# output buffer
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
return conv(x, w, buffer_new((bs, cout, oy, ox)), conv_args)
return conv(x, w, Buffer((bs, cout, oy, ox)), conv_args)
def backward(ctx, grad_output):
bs,_,oy,ox = grad_output.shape
@@ -191,6 +187,6 @@ class Conv2D(Function):
rcout = cout//ctx.groups
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args) if ctx.needs_input_grad[0] else None
dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
dx = convdx(w, grad_output, Buffer((bs, cin_, iy, ix)), conv_args) if ctx.needs_input_grad[0] else None
dw = convdw(x, grad_output, Buffer((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None
return dx, dw