Expand Operator (#327)

* replace broadcasting with expand

* Tensor, not self

* remove broadcasting from mlops

* delete useless A operator

* expand, not repeat

* remove A op

* expand on gpu

* binary_op doesn't broadcast anymore

* expand is still total junk, but the tests should pass
This commit is contained in:
George Hotz
2022-06-12 12:31:48 -07:00
committed by GitHub
parent 5cf7649eda
commit dcbca4fdf1
8 changed files with 127 additions and 102 deletions

View File

@@ -112,18 +112,18 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
### hlops (in tensor.py)
hlops are syntactic sugar around mlops.
hlops are syntactic sugar around mlops. They support most things torch does.
### mlops
mlops are mid-level ops; with Expand added there are 14 of them. They understand memory allocation and derivatives.
```
Relu, Log, Exp # unary ops
Sum, Max # reduce ops (with axis argument)
Add, Sub, Mul, Pow # binary ops (with broadcasting)
Reshape, Permute, Slice # movement ops
Conv2D(NCHW) # processing op (Matmul is also Conv2D)
Relu, Log, Exp # unary ops
Sum, Max # reduce ops (with axis argument)
Add, Sub, Mul, Pow # binary ops (no broadcasting, use expand)
Reshape, Permute, Slice, Expand # movement ops
Conv2D(NCHW) # processing op (Matmul is also Conv2D)
```
You no longer need to write mlops for a new accelerator
@@ -136,8 +136,8 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
Buffer # class of memory on this device
unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
binary_op (ADD, SUB, MUL, DIV, POW, A, CMPEQ) # A + B -> C (broadcasting supported)
movement_op (RESHAPE, PERMUTE, SLICE) # A -> B (different size)
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND) # A -> B (different size)
processing_op (CONV, CONVT, CONVDW) # A + B -> C
```

View File

@@ -168,6 +168,10 @@ class TestOps(unittest.TestCase):
# detach() is exercised with forward_only=True passed to helper_test_op
# (helper_test_op is defined elsewhere; presumably this skips the gradient comparison — verify).
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
# Expands the size-1 axis of a (4,3,1,6) input up to the target shape (4,3,2,6).
def test_expand(self):
arg = (4,3,2,6)
# the first lambda passes the shape positionally, the second passes it as the `shape=` keyword
# (helper_test_op is defined elsewhere; presumably first = torch reference, second = tinygrad — verify)
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
def test_simple_conv2d(self):
helper_test_op([(1,1,9,9), (1,1,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),

View File

@@ -6,27 +6,6 @@ def prod(x): return int(np.prod(x))
def reduce_shape(shape, axis):
  """Return `shape` as a list with every axis listed in `axis` set to 1.

  Args:
    shape: sequence of dimension sizes.
    axis: iterable of axis indices to reduce. Negative indices count from
      the end (-1 is the last axis); indices that fall outside the shape
      are ignored.

  Returns:
    A new list with the same length as `shape`.
  """
  ndim = len(shape)
  # normalise negative indices once so the membership test below is a set lookup
  reduced = {a + ndim if a < 0 else a for a in axis}
  return [1 if i in reduced else dim for i, dim in enumerate(shape)]
# Compute the broadcast shape of two operand shapes.
# Returns shape_ret (a tuple of ints) or, when extra=True, (shape_ret, dimlist, complist), where
# dimlist/complist describe the dimensions after merging adjacent ones with the same broadcast pattern.
# Raises Exception when a pair of dimensions differs and neither side is 1.
def binary_broadcast(x_shape, y_shape, extra=False):
n_dims = max(len(x_shape), len(y_shape))
# both shapes are padded to n_dims with 1s; the given dims fill the leading slots, so padding is trailing
shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
shape_x[:len(x_shape)] = np.array(x_shape, dtype=np.int32)
shape_y[:len(y_shape)] = np.array(y_shape, dtype=np.int32)
# every dimension pair must either match or have a 1 on at least one side
if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
raise Exception(f"binary op unbroadcastable shape mismatch: {x_shape} vs {y_shape}")
shape_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])
if extra:
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
# push merges a dim into the previous group when its (x > 1, y > 1) pattern is the same;
# dims where both sides are 1 (pattern (False, False)) are dropped entirely
def push(dim, comp):
if len(complist) > 0 and complist[-1] == comp:
dimlist[-1] *= dim
elif comp != (False, False):
dimlist.append(dim); complist.append(comp)
for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
push(np.int32(max(shape_x[i], shape_y[i])), (shape_x[i] > 1, shape_y[i] > 1))
return (shape_ret, dimlist, complist) if extra else shape_ret
def get_conv_args(x_shape, w_shape, stride, groups):
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
conv_args = namedtuple('conv_args',

View File

@@ -11,6 +11,7 @@ class CPUBuffer(np.ndarray):
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
def permute(x, order): return x.transpose(order)
def custompad(x, padding): return np.pad(x, padding)
# Broadcast x to new_shape; np.broadcast_to returns a read-only view, not a copy.
def expand(x, new_shape): return np.broadcast_to(x, new_shape)
@staticmethod
def fromCPU(x): return x
@@ -30,7 +31,6 @@ def binary_op(op, x, y, ret):
elif op == BinaryOps.MUL: ret[:] = x*y
elif op == BinaryOps.DIV: ret[:] = y/x
elif op == BinaryOps.POW: ret[:] = x**y
elif op == BinaryOps.A: ret[:] = x
elif op == BinaryOps.CMPEQ: ret[:] = 1.0*(x==y)
else: raise Exception(f"{op} isn't supported")
@@ -48,13 +48,14 @@ def reduce_op(op, inp, ret):
else: raise Exception(f"{op} isn't supported")
def movement_op(op, x, ret, arg=None):
if op == MovementOps.RESHAPE: ret[:] = x.reshape(ret.shape)
if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
elif op == MovementOps.SLICE:
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
x = x.custompad(padding)
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
ret[:] = x[tuple([slice(x[0], x[1], None) for x in slicee])]
elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
else: raise Exception(f"{op} isn't supported")
def get_tx(x, C):

View File

@@ -1,7 +1,7 @@
import functools
import numpy as np
import pyopencl as cl
from tinygrad.helpers import prod, binary_broadcast, get_conv_args
from tinygrad.helpers import prod
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
cl_ctx, cl_queue = None, None
@@ -66,42 +66,24 @@ def unary_op(op, x, ret):
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
return ret
# Build an OpenCL `binop` kernel for the expression `code` and the broadcast pattern `complist`;
# results are memoized by functools.lru_cache, so both arguments must be hashable.
# Returns (built cl.Program, flag that is True only when the output dtype is "float4").
@functools.lru_cache
def get_binop_prg(code, complist):
ndims = len(complist)
# kernel arguments: one int d{i} per dimension group, and one int p{i} for every group except the last
args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
# idx_ret{i} = (gid0 / p{i}) % d{i}; the last group divides by the literal 1 instead of a p argument
compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
idx_exprs = ["0", "0"] # [idx_x, idx_y]
# fold the per-group indices into one flat index per operand, using only the groups whose flag is set for it
for i in range(ndims):
for j in range(2):
if complist[i][j]:
idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
# all three dtypes are hard-coded to "float", so the float4 flag returned below is always False here
dtype = ["float", "float", "float"]
prg = """__kernel void binop(__global const """+dtype[0]+""" *x_g, __global const """+dtype[1]+""" *y_g, __global """+dtype[2]+""" *res_g"""+args+""") {
int gid0 = get_global_id(0);"""+compute_idx_rets+"""
"""+dtype[0]+""" a = x_g["""+idx_exprs[0]+"""];
"""+dtype[1]+""" b = y_g["""+idx_exprs[1]+"""];
res_g[gid0] = """+code+""";\n}"""
return cl.Program(cl_ctx, prg).build(), dtype[2] == "float4"
def binary_op(op, x, y, ret):
if op == BinaryOps.ADD: code = "a+b"
elif op == BinaryOps.SUB: code = "a-b"
elif op == BinaryOps.MUL: code = "a*b"
elif op == BinaryOps.DIV: code = "b/a"
elif op == BinaryOps.POW: code = "pow(a,b)"
elif op == BinaryOps.A: code = "a"
elif op == BinaryOps.CMPEQ: code = "1.0f*(a==b)"
elif op == BinaryOps.CMPEQ: code = "(float4)(1.0f*(a.x==b.x), 1.0f*(a.y==b.y), 1.0f*(a.z==b.z), 1.0f*(a.w==b.w))"
else: raise Exception(f"{op} isn't supported")
shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape, True)
assert tuple(shape_ret) == tuple(ret.shape)
prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
prg, is_float4 = get_binop_prg(code, tuple(complist))
kernel_size = ((roundup(prod_list[0])//4) if is_float4 else prod_list[0]) if len(dimlist) > 0 else 1
prg.binop(cl_queue, [kernel_size], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
assert x.shape == ret.shape and y.shape == ret.shape
binop = clbuild("binop", """
__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 a = a_g[gid];
float4 b = b_g[gid];
res_g[gid] = """+code+""";
}""")
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
return ret
def reduce_op(op, inp, ret):
if op == ReduceOps.SUM:
@@ -190,10 +172,39 @@ def inner_slice(x, arg, ret):
buffer_np(np.array(ret.shape, dtype=np.int32)),
buffer_np(np.array(shift, dtype=np.int32)))
# Broadcast x into ret with a generated OpenCL kernel: each output element reads the x element whose
# index is built only from the dimension groups in which x has size > 1. x and ret must have equal rank.
def expand(x, ret):
assert len(x.shape) == len(ret.shape)
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
# push merges a dim into the previous group when its (x > 1, ret > 1) pattern is the same;
# dims where both sides are 1 (pattern (False, False)) are dropped entirely
def push(dim, comp):
if len(complist) > 0 and complist[-1] == comp:
dimlist[-1] *= dim
elif comp != (False, False):
dimlist.append(dim); complist.append(comp)
for i,j in zip(x.shape, ret.shape): # group together any adjacent dimensions that we can to simplify broadcasting
push(np.int32(max(i,j)), (i > 1, j > 1))
# suffix products of dimlist: prod_list[k] is the product of dimlist[k:]
prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
ndims = len(complist)
# kernel arguments: one int d{i} per dimension group, and one int p{i} for every group except the last
args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
# idx_ret{i} = (gid0 / p{i}) % d{i}; the last group divides by the literal 1 instead of a p argument
compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
idx_exprs = ["0", "0"] # [idx_x, idx_y]
# NOTE(review): idx_exprs[1] (the ret-side index) is built too, but the kernel below only uses idx_exprs[0]
for i in range(ndims):
for j in range(2):
if complist[i][j]:
idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
expandop = clbuild("expandop", """__kernel void expandop(__global const float *x_g, __global float *res_g"""+args+""") {
int gid0 = get_global_id(0);"""+compute_idx_rets+"""
res_g[gid0] = x_g["""+idx_exprs[0]+"""];\n}""")
# global size is prod_list[0] (the product of all kept dims), or 1 when every dim was dropped;
# the d{i} arguments are dimlist and the p{i} arguments are prod_list[1:]
expandop([prod_list[0] if len(dimlist) > 0 else 1], None, x.cl, ret.cl, *dimlist, *(prod_list[1:]))
def movement_op(op, x, ret, arg=None):
  """Dispatch a movement op to its GPU implementation.

  Args:
    op: a MovementOps member selecting the operation.
    x: input buffer.
    ret: buffer handed to the implementation (presumably the output it
      fills in place — nothing is returned here; confirm against the helpers).
    arg: op-specific argument; forwarded to PERMUTE and SLICE, not used by
      RESHAPE or EXPAND in this dispatcher.

  Raises:
    Exception: if `op` is not one of the supported movement ops.
  """
  if op == MovementOps.RESHAPE: reshape(x, ret)
  elif op == MovementOps.PERMUTE: perm_axis(x, arg, ret)
  elif op == MovementOps.SLICE: inner_slice(x, arg, ret)
  elif op == MovementOps.EXPAND: expand(x, ret)
  # fail loudly instead of silently leaving ret untouched, matching the other op dispatchers
  else: raise Exception(f"{op} isn't supported")
def conv(x,w,ret,C):
# input = (bs, groups, cin, iy, ix)

View File

@@ -44,9 +44,7 @@ class Sum(Function):
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
# NOTE: the b Buffer isn't used, since this is just for broadcast
ret = ctx.buffer(shape_input)
return ctx.binary_op(BinaryOps.A, grad_output, ret)
return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input)
class Max(Function):
def forward(ctx, input, axis=None):
@@ -56,36 +54,35 @@ class Max(Function):
def backward(ctx, grad_output):
input, ret = ctx.saved_tensors
ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
div = ctx.reduce_op(ReduceOps.SUM, ret2, grad_output.shape)
ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = ctx.binary_op(BinaryOps.CMPEQ, input, ctx.movement_op(MovementOps.EXPAND, ret, input.shape))
# sum of locations, averaged
div = ctx.reduce_op(ReduceOps.SUM, max_is_1s, grad_output.shape)
div = ctx.movement_op(MovementOps.EXPAND, div, input.shape)
max_is_amount = ctx.binary_op(BinaryOps.DIV, div, max_is_1s)
grad_output_expanded = ctx.movement_op(MovementOps.EXPAND, grad_output, input.shape)
return ctx.binary_op(BinaryOps.MUL, max_is_amount, grad_output_expanded)
# ************* binary ops *************
# Sum-reduce `out` through ctx.reduce_op with ReduceOps.SUM, passing in_sh as its last argument
# (presumably the shape to reduce down to — verify against Ops.reduce_op).
def unbroadcast(ctx, out, in_sh):
return ctx.reduce_op(ReduceOps.SUM, out, in_sh)
class Add(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return ctx.binary_op(BinaryOps.ADD, x, y)
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(ctx, grad_output, shape_y) if ctx.needs_input_grad[1] else None
return grad_output if ctx.needs_input_grad[0] else None, \
grad_output if ctx.needs_input_grad[1] else None
class Sub(Function):
def forward(ctx, x, y):
ctx.save_for_backward(x.shape, y.shape)
return ctx.binary_op(BinaryOps.SUB, x, y)
def backward(ctx, grad_output):
shape_x, shape_y = ctx.saved_tensors
neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output)
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
return grad_output if ctx.needs_input_grad[0] else None, \
ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None
class Mul(Function):
def forward(ctx, x, y):
@@ -94,8 +91,8 @@ class Mul(Function):
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output), x.shape) if ctx.needs_input_grad[0] else None
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output), y.shape) if ctx.needs_input_grad[1] else None
grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None
grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
return grad_x, grad_y
class Pow(Function):
@@ -106,15 +103,28 @@ class Pow(Function):
def backward(ctx, grad_output):
x,y,powxy = ctx.saved_tensors
tmp = ctx.binary_op(BinaryOps.DIV, x, powxy) # pow(x,y)/x
tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
grad_x, grad_y = None, None
if ctx.needs_input_grad[0]:
tmp = ctx.binary_op(BinaryOps.DIV, x, powxy) # pow(x,y)/x
tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
grad_x = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
if ctx.needs_input_grad[1]:
tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
grad_y = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
return grad_x, grad_y
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
# forward: remember the input's shape, then run the EXPAND movement op with the requested `shape`
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
return ctx.movement_op(MovementOps.EXPAND, x, shape)
# backward: SUM-reduce the incoming gradient, passing the saved input shape to reduce_op
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
class Reshape(Function):
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)

View File

@@ -1,9 +1,9 @@
# TODO: move Device to here and proxy buffer call
from enum import Enum
UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND"])
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
import os
@@ -45,7 +45,6 @@ def log_op(op, ret, inp):
G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
G.nodes[nm(ret)]['style'] = 'filled'
from tinygrad.helpers import binary_broadcast
class Ops:
def unary_op(ctx, op:UnaryOps, x):
ret = ctx.buffer(x.shape)
@@ -60,13 +59,14 @@ class Ops:
return ret
def binary_op(ctx, op:ReduceOps, x, y):
ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
assert x.shape == y.shape
ret = ctx.buffer(x.shape)
ctx.op.binary_op(op, x, y, ret)
log_op(op, ret, [x] if op == BinaryOps.A else [x, y])
log_op(op, ret, [x, y])
return ret
def movement_op(ctx, op:MovementOps, x, arg=None):
if op == MovementOps.RESHAPE: new_shape = arg
if op in [MovementOps.RESHAPE, MovementOps.EXPAND]: new_shape = arg
if op == MovementOps.PERMUTE: new_shape = [x.shape[i] for i in arg]
if op == MovementOps.SLICE: new_shape = [y-x for x,y in arg]
ret = ctx.buffer(new_shape)

View File

@@ -349,10 +349,34 @@ class Tensor:
ret = x._conv2d(weight, stride=stride, groups=groups)
return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
# ***** broadcasted binary ops *****
@staticmethod
def broadcasted(fxn, x, y):
  """Apply the binary function `fxn` to `x` and `y` after making their shapes equal.

  Non-Tensor operands are wrapped in a one-element Tensor that takes its
  dtype and device from the Tensor operand and has requires_grad=False.
  The operand with fewer dimensions is reshaped with trailing 1s, and any
  operand whose shape differs from the result shape is expanded to it.

  Args:
    fxn: callable taking the two same-shaped Tensors.
    x, y: Tensors or plain numbers; at least one must be a Tensor.

  Returns:
    Whatever `fxn(x, y)` returns.

  Raises:
    Exception: if a pair of dimensions differs and neither side is 1.
  """
  tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0]  # this is the prototype tensor
  if not isinstance(x, Tensor): x = Tensor(np.array([x], dtype=tt.dtype), device=tt.device, requires_grad=False)
  if not isinstance(y, Tensor): y = Tensor(np.array([y], dtype=tt.dtype), device=tt.device, requires_grad=False)

  # pad the lower-rank operand with trailing 1s so both have the same number of dimensions
  n_dims = max(len(x.shape), len(y.shape))
  if len(x.shape) != n_dims: x = x.reshape(list(x.shape) + [1]*(n_dims-len(x.shape)))
  if len(y.shape) != n_dims: y = y.reshape(list(y.shape) + [1]*(n_dims-len(y.shape)))

  # expand can only grow size-1 dimensions, so reject incompatible shapes here
  # rather than handing them to the backend
  for dim_x, dim_y in zip(x.shape, y.shape):
    if dim_x != dim_y and dim_x != 1 and dim_y != 1:
      raise Exception(f"binary op unbroadcastable shape mismatch: {x.shape} vs {y.shape}")

  shape_ret = tuple([int(d) for d in np.maximum(x.shape, y.shape)])
  if x.shape != shape_ret: x = x.expand(shape_ret)
  if y.shape != shape_ret: y = y.expand(shape_ret)
  return fxn(x, y)
# TODO: are these the only ones that can take number arguments?
def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)
def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
# ***** functional nn ops *****
def reshape(self, shape):
return self._reshape(shape=shape)
# TODO: fix the kwargs problem
def reshape(self, shape): return self._reshape(shape=shape)
def expand(self, shape): return self._expand(shape=shape)
def linear(self, weight, bias):
shp = [1] * (len(self.shape)-1) + [-1]
@@ -391,13 +415,8 @@ class Function(Ops):
@classmethod
def apply(cls, *x, **kwargs):
tt = [arg for arg in x if isinstance(arg, Tensor)][0] # this is the prototype tensor
# create tensors from number arguments
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
assert all([tt.device == t.device for t in x]), "All tensors are not on the same device"
ctx = cls(tt.device, *x)
assert all([isinstance(arg, Tensor) for arg in x])
ctx = cls(x[0].device, *x)
with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
device=ctx.device, requires_grad=ctx.requires_grad)
@@ -419,6 +438,7 @@ for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), i
if name[0] != "_" and name != "Function" and not name.endswith("Ops"): register(name.lower(), cls)
# register the operators
# TODO: add div
def register_op(name, fxn):
setattr(Tensor, f"__{name}__", fxn)
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))