mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
GPU llops
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
Getting the core instruction set correct is the value of tinygrad
|
||||
|
||||
|
||||
Max size tensor is 6-D for the pool2d
|
||||
|
||||
Unary Ops
|
||||
===
|
||||
|
||||
These are the simplest to reason about, and have pointwise mem access.
|
||||
A and B are always the same size
|
||||
|
||||
Forward : A -> B
|
||||
Backward (binary): (B', A) -> A'
|
||||
|
||||
|
||||
|
||||
|
||||
Reduce Ops (with axis)
|
||||
===
|
||||
|
||||
|
||||
141
tinygrad/llops/gpu.py
Normal file
141
tinygrad/llops/gpu.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# llops don't know about derivatives
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
import pyopencl as cl
|
||||
from tinygrad.helpers import binary_broadcast
|
||||
|
||||
i32 = np.int32
|
||||
|
||||
cl_ctx, cl_queue = None, None
|
||||
def require_init_gpu():
|
||||
global cl_ctx, cl_queue
|
||||
if cl_ctx is None:
|
||||
devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
|
||||
if len(devices) == 0:
|
||||
devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU)
|
||||
cl_ctx = cl.Context(devices=devices)
|
||||
# this is an in-order command queue
|
||||
cl_queue = cl.CommandQueue(cl_ctx)
|
||||
|
||||
class GPUBuffer:
|
||||
def __init__(self, shape, hostbuf=None):
|
||||
require_init_gpu()
|
||||
self.shape, self.dtype = tuple(shape), np.float32
|
||||
self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else \
|
||||
cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0), 4*np.prod(shape),
|
||||
hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GPUBuffer with shape {self.shape!r}>"
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x):
|
||||
return GPUBuffer(x.shape, x.view(np.ndarray))
|
||||
|
||||
def toCPU(self):
|
||||
data = np.empty(self.shape, dtype=np.float32)
|
||||
cl_queue.finish()
|
||||
cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
|
||||
return data
|
||||
|
||||
def buffer_new(ctx, shape, zero=False):
|
||||
return GPUBuffer(shape, hostbuf=None if not zero else np.zeros(shape, dtype=np.float32))
|
||||
|
||||
def buffer_np(ctx, x):
|
||||
return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
|
||||
|
||||
def clbuffer(hostbuf, shape):
|
||||
return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0),
|
||||
4*np.prod(shape),
|
||||
hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)
|
||||
|
||||
@functools.lru_cache
|
||||
def clbuild(name, prg):
|
||||
clprg = cl.Program(cl_ctx, prg).build().__getattr__(name)
|
||||
def run(*args):
|
||||
clprg(cl_queue, *args)
|
||||
return run
|
||||
|
||||
# x -> ret
|
||||
def unary_op(ctx, code, x):
|
||||
ret = buffer_new(ctx, x.shape)
|
||||
unop = clbuild("unop", """
|
||||
__kernel void unop(__global const float *a_g, __global float *res_g) {
|
||||
int gid = get_global_id(0);
|
||||
float a = a_g[gid];
|
||||
res_g[gid] = """+code+""";
|
||||
}""")
|
||||
unop([np.prod(ret.shape)], None, x.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
@functools.lru_cache
|
||||
def get_binop_prg(cl_ctx, code, complist):
|
||||
ndims = len(complist)
|
||||
args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
|
||||
compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
|
||||
|
||||
idx_exprs = ["0", "0"] # [idx_x, idx_y]
|
||||
for i in range(ndims):
|
||||
for j in range(2):
|
||||
if complist[i][j]:
|
||||
idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
|
||||
|
||||
return cl.Program(cl_ctx, """__kernel void binop(__global const float *x_g, __global const float *y_g, __global float *res_g"""+args+""") {
|
||||
int gid0 = get_global_id(0);"""+compute_idx_rets+"""
|
||||
float a = x_g["""+idx_exprs[0]+"""];
|
||||
float b = y_g["""+idx_exprs[1]+"""];
|
||||
res_g[gid0] = """+code+""";\n}""").build()
|
||||
|
||||
def binary_op(ctx, code, x, y):
|
||||
shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape)
|
||||
prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
|
||||
|
||||
prg = get_binop_prg(cl_ctx, code, tuple(complist))
|
||||
ret = buffer_new(ctx, shape_ret, zero=True)
|
||||
prg.binop(cl_queue, [prod_list[0]] if len(dimlist) > 0 else [1], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
|
||||
return ret
|
||||
|
||||
def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
|
||||
if axis is None:
|
||||
# full reduce
|
||||
osize = [1]*len(inp.shape)
|
||||
else:
|
||||
osize = np.array(inp.shape)
|
||||
osize[list(axis)] = 1
|
||||
ret = buffer_new(ctx, osize)
|
||||
if axis is None:
|
||||
ret.shape = (1,)
|
||||
|
||||
# TODO: this is insanely slow
|
||||
reduce = clbuild("reduce", """
|
||||
__kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
|
||||
__global const int *shape_x, __global const int *shape_ret) {
|
||||
int gid = get_global_id(0);
|
||||
|
||||
float out = """+start+""";
|
||||
for (int x = 0; x < sz; x++) {
|
||||
int idx = 0; // compute index into a_g
|
||||
int tprod = prod;
|
||||
int tsz = sz;
|
||||
for (int dim = 0; dim < n_dims; dim++) {
|
||||
idx *= shape_x[dim];
|
||||
if (shape_x[dim] == shape_ret[dim]) { // dim from gid, don't reduce
|
||||
tprod /= shape_x[dim];
|
||||
idx += (gid / tprod) % shape_x[dim];
|
||||
} else { // dim from x
|
||||
tsz /= shape_x[dim];
|
||||
idx += (x / tsz) % shape_x[dim];
|
||||
}
|
||||
}
|
||||
float a = a_g[idx];
|
||||
"""+code+""";
|
||||
}
|
||||
res_g[gid] = """+code2+""";
|
||||
}""")
|
||||
reduce([np.prod(osize)], None, inp.cl,
|
||||
i32(np.prod(inp.shape)//np.prod(osize)), ret.cl,
|
||||
i32(np.prod(osize)), i32(len(osize)),
|
||||
buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
|
||||
buffer_np(ctx, np.array(osize, dtype=np.int32)))
|
||||
return ret
|
||||
@@ -1,50 +1,7 @@
|
||||
import functools
|
||||
import pyopencl as cl
|
||||
import numpy as np
|
||||
from tinygrad.helpers import binary_broadcast
|
||||
from ..tensor import Function
|
||||
|
||||
cl_ctx, cl_queue = None, None
|
||||
def require_init_gpu():
|
||||
global cl_ctx, cl_queue
|
||||
if cl_queue is None:
|
||||
devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU)
|
||||
if len(devices) == 0:
|
||||
devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.CPU)
|
||||
cl_ctx = cl.Context(devices=devices)
|
||||
# this is an in-order command queue
|
||||
cl_queue = cl.CommandQueue(cl_ctx)
|
||||
|
||||
class GPUBuffer:
|
||||
def __init__(self, shape, hostbuf=None):
|
||||
require_init_gpu()
|
||||
self.shape, self.dtype = tuple(shape), np.float32
|
||||
self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else \
|
||||
cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0), 4*np.prod(shape),
|
||||
hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GPUBuffer with shape {self.shape!r}>"
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x):
|
||||
return GPUBuffer(x.shape, x.view(np.ndarray))
|
||||
|
||||
def toCPU(self):
|
||||
data = np.empty(self.shape, dtype=np.float32)
|
||||
cl_queue.finish()
|
||||
cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
|
||||
return data
|
||||
|
||||
def buffer_new(ctx, shape, zero=False):
|
||||
return GPUBuffer(shape, hostbuf=None if not zero else np.zeros(shape, dtype=np.float32))
|
||||
|
||||
def buffer_np(ctx, x):
|
||||
return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
|
||||
|
||||
@functools.lru_cache
|
||||
def clbuild(cl_ctx, name, prg):
|
||||
return cl.Program(cl_ctx, prg).build().__getattr__(name)
|
||||
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, buffer_np, unary_op, binary_op, reduce_op
|
||||
|
||||
def uint2(x, y):
|
||||
return np.array((x,y), dtype=cl.cltypes.uint2)
|
||||
@@ -52,17 +9,6 @@ i32 = np.int32
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
def unary_op(ctx, code, x):
|
||||
ret = buffer_new(ctx, x.shape)
|
||||
unop = clbuild(cl_ctx, "unop", """
|
||||
__kernel void unop(__global const float *a_g, __global float *res_g) {
|
||||
int gid = get_global_id(0);
|
||||
float a = a_g[gid];
|
||||
res_g[gid] = """+code+""";
|
||||
}""")
|
||||
unop(cl_queue, [np.prod(ret.shape)], None, x.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
class ReLU(Function):
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
@@ -93,50 +39,6 @@ class Exp(Function):
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
|
||||
if axis is None:
|
||||
# full reduce
|
||||
osize = [1]*len(inp.shape)
|
||||
else:
|
||||
osize = np.array(inp.shape)
|
||||
osize[list(axis)] = 1
|
||||
ret = buffer_new(ctx, osize)
|
||||
if axis is None:
|
||||
ret.shape = (1,)
|
||||
|
||||
# TODO: this is insanely slow
|
||||
reduce = clbuild(cl_ctx, "reduce", """
|
||||
__kernel void reduce(__global const float *a_g, int sz, __global float *res_g, int prod, int n_dims,
|
||||
__global const int *shape_x, __global const int *shape_ret) {
|
||||
int gid = get_global_id(0);
|
||||
|
||||
float out = """+start+""";
|
||||
for (int x = 0; x < sz; x++) {
|
||||
int idx = 0; // compute index into a_g
|
||||
int tprod = prod;
|
||||
int tsz = sz;
|
||||
for (int dim = 0; dim < n_dims; dim++) {
|
||||
idx *= shape_x[dim];
|
||||
if (shape_x[dim] == shape_ret[dim]) { // dim from gid, don't reduce
|
||||
tprod /= shape_x[dim];
|
||||
idx += (gid / tprod) % shape_x[dim];
|
||||
} else { // dim from x
|
||||
tsz /= shape_x[dim];
|
||||
idx += (x / tsz) % shape_x[dim];
|
||||
}
|
||||
}
|
||||
float a = a_g[idx];
|
||||
"""+code+""";
|
||||
}
|
||||
res_g[gid] = """+code2+""";
|
||||
}""")
|
||||
reduce(cl_queue, [np.prod(osize)], None, inp.cl,
|
||||
i32(np.prod(inp.shape)//np.prod(osize)), ret.cl,
|
||||
i32(np.prod(osize)), i32(len(osize)),
|
||||
buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
|
||||
buffer_np(ctx, np.array(osize, dtype=np.int32)))
|
||||
return ret
|
||||
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
ctx.save_for_backward(input.shape)
|
||||
@@ -162,33 +64,6 @@ class Max(Function):
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
@functools.lru_cache
|
||||
def get_binop_prg(cl_ctx, code, complist):
|
||||
ndims = len(complist)
|
||||
args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
|
||||
compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
|
||||
|
||||
idx_exprs = ["0", "0"] # [idx_x, idx_y]
|
||||
for i in range(ndims):
|
||||
for j in range(2):
|
||||
if complist[i][j]:
|
||||
idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
|
||||
|
||||
return cl.Program(cl_ctx, """__kernel void binop(__global const float *x_g, __global const float *y_g, __global float *res_g"""+args+""") {
|
||||
int gid0 = get_global_id(0);"""+compute_idx_rets+"""
|
||||
float a = x_g["""+idx_exprs[0]+"""];
|
||||
float b = y_g["""+idx_exprs[1]+"""];
|
||||
res_g[gid0] = """+code+""";\n}""").build()
|
||||
|
||||
def binary_op(ctx, code, x, y):
|
||||
shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape)
|
||||
prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
|
||||
|
||||
prg = get_binop_prg(cl_ctx, code, tuple(complist))
|
||||
ret = buffer_new(ctx, shape_ret, zero=True)
|
||||
prg.binop(cl_queue, [prod_list[0]] if len(dimlist) > 0 else [1], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
|
||||
return ret
|
||||
|
||||
def unbroadcast(ctx, out, in_sh):
|
||||
sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1] if in_sh != (1,) else None
|
||||
return reduce_op(ctx, "out += a", "out", out, sum_axis)
|
||||
@@ -254,7 +129,7 @@ class Reshape(Function):
|
||||
def perm_axis(ctx, inp, order):
|
||||
osize = np.array(inp.shape)[list(order)]
|
||||
ret = buffer_new(ctx, osize)
|
||||
perm = clbuild(cl_ctx, "perm", """
|
||||
perm = clbuild("perm", """
|
||||
__kernel void perm(__global const float *a_g, __global float *res_g, int n_axis,
|
||||
__global const int *shape, __global const int *order) {
|
||||
int gid = get_global_id(0);
|
||||
@@ -268,7 +143,7 @@ def perm_axis(ctx, inp, order):
|
||||
}
|
||||
res_g[gid] = a_g[idx];
|
||||
}""")
|
||||
perm(cl_queue, [np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
|
||||
perm([np.prod(osize)], None, inp.cl, ret.cl, i32(len(osize)),
|
||||
buffer_np(ctx, np.array(inp.shape, dtype=np.int32)),
|
||||
buffer_np(ctx, np.array(order, dtype=np.int32)))
|
||||
return ret
|
||||
@@ -286,7 +161,7 @@ def inner_slice(ctx, x, arg):
|
||||
shift = [y[0] for y in arg]
|
||||
oshape = [y[1]-y[0] for y in arg]
|
||||
ret = buffer_new(ctx, oshape)
|
||||
gslice = clbuild(cl_ctx, "gslice", """
|
||||
gslice = clbuild("gslice", """
|
||||
__kernel void gslice(__global const float *input, __global float *output, int prod, int n_dims,
|
||||
__global const int *shape_x, __global const int *shape_ret,
|
||||
__global const int *shift) {
|
||||
@@ -301,7 +176,7 @@ def inner_slice(ctx, x, arg):
|
||||
}
|
||||
output[gid] = zero ? input[iptr] : 0.0;
|
||||
}""")
|
||||
gslice(cl_queue, [np.prod(ret.shape)], None,
|
||||
gslice([np.prod(ret.shape)], None,
|
||||
x.cl, ret.cl, i32(np.prod(ret.shape)), i32(len(ret.shape)),
|
||||
buffer_np(ctx, np.array(x.shape, dtype=np.int32)),
|
||||
buffer_np(ctx, np.array(ret.shape, dtype=np.int32)),
|
||||
@@ -327,7 +202,7 @@ class Matmul(Function):
|
||||
isize, msize, osize = i32(input.shape[-2]), i32(input.shape[-1]), i32(weight.shape[-1])
|
||||
ret = buffer_new(ctx, list(input.shape[0:-2])+[isize, osize])
|
||||
|
||||
matmul = clbuild(cl_ctx, "matmul", """
|
||||
matmul = clbuild("matmul", """
|
||||
__kernel void matmul(
|
||||
__global const float *input, __global const float *weight, __global float *res,
|
||||
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
|
||||
@@ -348,7 +223,7 @@ class Matmul(Function):
|
||||
ctx.save_for_backward(input, weight, matmul, cnt)
|
||||
|
||||
# (isize,msize) x (msize,osize) = (isize,osize)
|
||||
matmul(cl_queue, [isize, osize, cnt], None,
|
||||
matmul([isize, osize, cnt], None,
|
||||
input.cl, weight.cl, ret.cl, isize,
|
||||
msize, i32(1), msize, i32(1), osize, osize)
|
||||
return ret
|
||||
@@ -361,12 +236,12 @@ class Matmul(Function):
|
||||
grad_weight = buffer_new(ctx, weight.shape)
|
||||
|
||||
# (isize,osize) x (msize,osize) = (isize,msize)
|
||||
matmul(cl_queue, [isize, msize, cnt], None,
|
||||
matmul([isize, msize, cnt], None,
|
||||
grad_output.cl, weight.cl, grad_input.cl, isize,
|
||||
osize, i32(1), osize, osize, i32(1), msize)
|
||||
|
||||
# (isize,msize) x (isize,osize) = (msize,osize)
|
||||
matmul(cl_queue, [msize, osize, cnt], None,
|
||||
matmul([msize, osize, cnt], None,
|
||||
input.cl, grad_output.cl, grad_weight.cl, msize,
|
||||
i32(1), msize, isize, i32(1), osize, osize)
|
||||
|
||||
@@ -392,7 +267,7 @@ class Conv2D(Function):
|
||||
# weight = (groups, rcout, cin, H, W)
|
||||
# output = (bs, groups, rcout, oy, ox)
|
||||
|
||||
conv = clbuild(cl_ctx, "conv", """
|
||||
conv = clbuild("conv", """
|
||||
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
|
||||
|
||||
@@ -417,7 +292,7 @@ class Conv2D(Function):
|
||||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}""")
|
||||
|
||||
conv(cl_queue, [bs*groups*rcout, oy, ox], None,
|
||||
conv([bs*groups*rcout, oy, ox], None,
|
||||
x.cl, w.cl, ret.cl,
|
||||
i32(H), i32(W), i32(groups), i32(rcout), i32(cin),
|
||||
i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs)
|
||||
@@ -442,7 +317,7 @@ class Conv2D(Function):
|
||||
# tensw = (groups*rcout, cin, H, W)
|
||||
# ggg = (bs, groups*rout, oy, ox)
|
||||
|
||||
convw = clbuild(cl_ctx, "convw", """
|
||||
convw = clbuild("convw", """
|
||||
__kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
@@ -463,7 +338,7 @@ class Conv2D(Function):
|
||||
}
|
||||
dw[get_global_id(0)*H*W + y*W + x] = acc;
|
||||
}""")
|
||||
convx = clbuild(cl_ctx, "convx", """
|
||||
convx = clbuild("convx", """
|
||||
__kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
@@ -489,6 +364,6 @@ class Conv2D(Function):
|
||||
""")
|
||||
|
||||
conv_args = i32(H), i32(W), i32(ctx.groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs), i32(bs)
|
||||
convw(cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
|
||||
convx(cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
|
||||
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
|
||||
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
|
||||
return dx, dw
|
||||
|
||||
Reference in New Issue
Block a user