mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 01:15:49 +08:00
Expand Operator (#327)
* replace broadcasting with expand * Tensor, not self * remove broadcasting from mlops * delete useless A operator * expand, not repeat * remove A op * expand on gpu * binary_op doesn't broadcast anymore * expand is still total junk, but the tests should pass
This commit is contained in:
16
README.md
16
README.md
@@ -112,18 +112,18 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
|
||||
|
||||
### hlops (in tensor.py)
|
||||
|
||||
hlops are syntactic sugar around mlops.
|
||||
hlops are syntactic sugar around mlops. They support most things torch does.
|
||||
|
||||
### mlops
|
||||
|
||||
mlops are mid level ops, there's 13 of them. They understand memory allocation and derivatives
|
||||
|
||||
```
|
||||
Relu, Log, Exp # unary ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Add, Sub, Mul, Pow # binary ops (with broadcasting)
|
||||
Reshape, Permute, Slice # movement ops
|
||||
Conv2D(NCHW) # processing op (Matmul is also Conv2D)
|
||||
Relu, Log, Exp # unary ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Add, Sub, Mul, Pow # binary ops (no broadcasting, use expand)
|
||||
Reshape, Permute, Slice, Expand # movement ops
|
||||
Conv2D(NCHW) # processing op (Matmul is also Conv2D)
|
||||
```
|
||||
|
||||
You no longer need to write mlops for a new accelerator
|
||||
@@ -136,8 +136,8 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
|
||||
Buffer # class of memory on this device
|
||||
unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
|
||||
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, A, CMPEQ) # A + B -> C (broadcasting supported)
|
||||
movement_op (RESHAPE, PERMUTE, SLICE) # A -> B (different size)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
|
||||
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND) # A -> B (different size)
|
||||
processing_op (CONV, CONVT, CONVDW) # A + B -> C
|
||||
```
|
||||
|
||||
|
||||
@@ -168,6 +168,10 @@ class TestOps(unittest.TestCase):
|
||||
def test_detach(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
|
||||
|
||||
def test_expand(self):
|
||||
arg = (4,3,2,6)
|
||||
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
|
||||
|
||||
def test_simple_conv2d(self):
|
||||
helper_test_op([(1,1,9,9), (1,1,3,3)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
|
||||
|
||||
@@ -6,27 +6,6 @@ def prod(x): return int(np.prod(x))
|
||||
def reduce_shape(shape, axis):
|
||||
return [1 if i in axis else shape[i] for i in range(len(shape))]
|
||||
|
||||
def binary_broadcast(x_shape, y_shape, extra=False):
|
||||
n_dims = max(len(x_shape), len(y_shape))
|
||||
shape_x, shape_y = np.ones(n_dims, dtype=np.int32), np.ones(n_dims, dtype=np.int32)
|
||||
shape_x[:len(x_shape)] = np.array(x_shape, dtype=np.int32)
|
||||
shape_y[:len(y_shape)] = np.array(y_shape, dtype=np.int32)
|
||||
if not np.all((shape_x == 1) | (shape_y == 1) | (shape_x == shape_y)):
|
||||
raise Exception(f"binary op unbroadcastable shape mismatch: {x_shape} vs {y_shape}")
|
||||
shape_ret = tuple([int(x) for x in np.maximum(shape_x, shape_y)])
|
||||
|
||||
if extra:
|
||||
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
|
||||
def push(dim, comp):
|
||||
if len(complist) > 0 and complist[-1] == comp:
|
||||
dimlist[-1] *= dim
|
||||
elif comp != (False, False):
|
||||
dimlist.append(dim); complist.append(comp)
|
||||
for i in range(n_dims): # group together any adjacent dimensions that we can to simplify broadcasting
|
||||
push(np.int32(max(shape_x[i], shape_y[i])), (shape_x[i] > 1, shape_y[i] > 1))
|
||||
|
||||
return (shape_ret, dimlist, complist) if extra else shape_ret
|
||||
|
||||
def get_conv_args(x_shape, w_shape, stride, groups):
|
||||
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
|
||||
conv_args = namedtuple('conv_args',
|
||||
|
||||
@@ -11,6 +11,7 @@ class CPUBuffer(np.ndarray):
|
||||
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
|
||||
def permute(x, order): return x.transpose(order)
|
||||
def custompad(x, padding): return np.pad(x, padding)
|
||||
def expand(x, new_shape): return np.broadcast_to(x, new_shape)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x): return x
|
||||
@@ -30,7 +31,6 @@ def binary_op(op, x, y, ret):
|
||||
elif op == BinaryOps.MUL: ret[:] = x*y
|
||||
elif op == BinaryOps.DIV: ret[:] = y/x
|
||||
elif op == BinaryOps.POW: ret[:] = x**y
|
||||
elif op == BinaryOps.A: ret[:] = x
|
||||
elif op == BinaryOps.CMPEQ: ret[:] = 1.0*(x==y)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
@@ -48,13 +48,14 @@ def reduce_op(op, inp, ret):
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def movement_op(op, x, ret, arg=None):
|
||||
if op == MovementOps.RESHAPE: ret[:] = x.reshape(ret.shape)
|
||||
if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
|
||||
elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
|
||||
elif op == MovementOps.SLICE:
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
x = x.custompad(padding)
|
||||
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
|
||||
ret[:] = x[tuple([slice(x[0], x[1], None) for x in slicee])]
|
||||
elif op == MovementOps.EXPAND: ret[:] = x.expand(arg)
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
def get_tx(x, C):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import functools
|
||||
import numpy as np
|
||||
import pyopencl as cl
|
||||
from tinygrad.helpers import prod, binary_broadcast, get_conv_args
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
|
||||
|
||||
cl_ctx, cl_queue = None, None
|
||||
@@ -66,42 +66,24 @@ def unary_op(op, x, ret):
|
||||
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
@functools.lru_cache
|
||||
def get_binop_prg(code, complist):
|
||||
ndims = len(complist)
|
||||
args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
|
||||
compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
|
||||
|
||||
idx_exprs = ["0", "0"] # [idx_x, idx_y]
|
||||
for i in range(ndims):
|
||||
for j in range(2):
|
||||
if complist[i][j]:
|
||||
idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
|
||||
|
||||
dtype = ["float", "float", "float"]
|
||||
prg = """__kernel void binop(__global const """+dtype[0]+""" *x_g, __global const """+dtype[1]+""" *y_g, __global """+dtype[2]+""" *res_g"""+args+""") {
|
||||
int gid0 = get_global_id(0);"""+compute_idx_rets+"""
|
||||
"""+dtype[0]+""" a = x_g["""+idx_exprs[0]+"""];
|
||||
"""+dtype[1]+""" b = y_g["""+idx_exprs[1]+"""];
|
||||
res_g[gid0] = """+code+""";\n}"""
|
||||
return cl.Program(cl_ctx, prg).build(), dtype[2] == "float4"
|
||||
|
||||
def binary_op(op, x, y, ret):
|
||||
if op == BinaryOps.ADD: code = "a+b"
|
||||
elif op == BinaryOps.SUB: code = "a-b"
|
||||
elif op == BinaryOps.MUL: code = "a*b"
|
||||
elif op == BinaryOps.DIV: code = "b/a"
|
||||
elif op == BinaryOps.POW: code = "pow(a,b)"
|
||||
elif op == BinaryOps.A: code = "a"
|
||||
elif op == BinaryOps.CMPEQ: code = "1.0f*(a==b)"
|
||||
elif op == BinaryOps.CMPEQ: code = "(float4)(1.0f*(a.x==b.x), 1.0f*(a.y==b.y), 1.0f*(a.z==b.z), 1.0f*(a.w==b.w))"
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
shape_ret, dimlist, complist = binary_broadcast(x.shape, y.shape, True)
|
||||
assert tuple(shape_ret) == tuple(ret.shape)
|
||||
prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
|
||||
prg, is_float4 = get_binop_prg(code, tuple(complist))
|
||||
kernel_size = ((roundup(prod_list[0])//4) if is_float4 else prod_list[0]) if len(dimlist) > 0 else 1
|
||||
prg.binop(cl_queue, [kernel_size], None, x.cl, y.cl, ret.cl, *dimlist, *(prod_list[1:]))
|
||||
assert x.shape == ret.shape and y.shape == ret.shape
|
||||
binop = clbuild("binop", """
|
||||
__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
|
||||
int gid = get_global_id(0);
|
||||
float4 a = a_g[gid];
|
||||
float4 b = b_g[gid];
|
||||
res_g[gid] = """+code+""";
|
||||
}""")
|
||||
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
|
||||
return ret
|
||||
|
||||
def reduce_op(op, inp, ret):
|
||||
if op == ReduceOps.SUM:
|
||||
@@ -190,10 +172,39 @@ def inner_slice(x, arg, ret):
|
||||
buffer_np(np.array(ret.shape, dtype=np.int32)),
|
||||
buffer_np(np.array(shift, dtype=np.int32)))
|
||||
|
||||
def expand(x, ret):
|
||||
assert len(x.shape) == len(ret.shape)
|
||||
|
||||
dimlist, complist = [], [] # note: len(dimlist) may be less than n_dims
|
||||
def push(dim, comp):
|
||||
if len(complist) > 0 and complist[-1] == comp:
|
||||
dimlist[-1] *= dim
|
||||
elif comp != (False, False):
|
||||
dimlist.append(dim); complist.append(comp)
|
||||
for i,j in zip(x.shape, ret.shape): # group together any adjacent dimensions that we can to simplify broadcasting
|
||||
push(np.int32(max(i,j)), (i > 1, j > 1))
|
||||
prod_list = np.array(dimlist, dtype=i32)[-1::-1].cumprod(dtype=i32)[-1::-1] # take cumprod from back to front
|
||||
|
||||
ndims = len(complist)
|
||||
args = "".join([f", int d{i}" for i in range(ndims)] + [f", int p{i}" for i in range(ndims-1)])
|
||||
compute_idx_rets = "".join([f"\n int idx_ret{i} = (gid0 / {f'p{i}' if i < ndims-1 else '1'}) % d{i};" for i in range(ndims)])
|
||||
|
||||
idx_exprs = ["0", "0"] # [idx_x, idx_y]
|
||||
for i in range(ndims):
|
||||
for j in range(2):
|
||||
if complist[i][j]:
|
||||
idx_exprs[j] = "idx_ret%d + d%d*(%s)" % (i, i, idx_exprs[j])
|
||||
|
||||
expandop = clbuild("expandop", """__kernel void expandop(__global const float *x_g, __global float *res_g"""+args+""") {
|
||||
int gid0 = get_global_id(0);"""+compute_idx_rets+"""
|
||||
res_g[gid0] = x_g["""+idx_exprs[0]+"""];\n}""")
|
||||
expandop([prod_list[0] if len(dimlist) > 0 else 1], None, x.cl, ret.cl, *dimlist, *(prod_list[1:]))
|
||||
|
||||
def movement_op(op, x, ret, arg=None):
|
||||
if op == MovementOps.RESHAPE: reshape(x, ret)
|
||||
elif op == MovementOps.PERMUTE: perm_axis(x, arg, ret)
|
||||
elif op == MovementOps.SLICE: inner_slice(x, arg, ret)
|
||||
elif op == MovementOps.EXPAND: expand(x, ret)
|
||||
|
||||
def conv(x,w,ret,C):
|
||||
# input = (bs, groups, cin, iy, ix)
|
||||
|
||||
@@ -44,9 +44,7 @@ class Sum(Function):
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
shape_input, = ctx.saved_tensors
|
||||
# NOTE: the b Buffer isn't used, since this is just for broadcast
|
||||
ret = ctx.buffer(shape_input)
|
||||
return ctx.binary_op(BinaryOps.A, grad_output, ret)
|
||||
return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input)
|
||||
|
||||
class Max(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
@@ -56,36 +54,35 @@ class Max(Function):
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, ret = ctx.saved_tensors
|
||||
ret2 = ctx.binary_op(BinaryOps.CMPEQ, input, ret)
|
||||
div = ctx.reduce_op(ReduceOps.SUM, ret2, grad_output.shape)
|
||||
ret2 = ctx.binary_op(BinaryOps.DIV, div, ret2)
|
||||
return ctx.binary_op(BinaryOps.MUL, ret2, grad_output)
|
||||
|
||||
# 1s in locations where the max was chosen (can be two locations)
|
||||
max_is_1s = ctx.binary_op(BinaryOps.CMPEQ, input, ctx.movement_op(MovementOps.EXPAND, ret, input.shape))
|
||||
|
||||
# sum of locations, averaged
|
||||
div = ctx.reduce_op(ReduceOps.SUM, max_is_1s, grad_output.shape)
|
||||
div = ctx.movement_op(MovementOps.EXPAND, div, input.shape)
|
||||
max_is_amount = ctx.binary_op(BinaryOps.DIV, div, max_is_1s)
|
||||
|
||||
grad_output_expanded = ctx.movement_op(MovementOps.EXPAND, grad_output, input.shape)
|
||||
return ctx.binary_op(BinaryOps.MUL, max_is_amount, grad_output_expanded)
|
||||
|
||||
# ************* binary ops *************
|
||||
|
||||
def unbroadcast(ctx, out, in_sh):
|
||||
return ctx.reduce_op(ReduceOps.SUM, out, in_sh)
|
||||
|
||||
class Add(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x.shape, y.shape)
|
||||
return ctx.binary_op(BinaryOps.ADD, x, y)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
shape_x, shape_y = ctx.saved_tensors
|
||||
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
|
||||
unbroadcast(ctx, grad_output, shape_y) if ctx.needs_input_grad[1] else None
|
||||
return grad_output if ctx.needs_input_grad[0] else None, \
|
||||
grad_output if ctx.needs_input_grad[1] else None
|
||||
|
||||
class Sub(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x.shape, y.shape)
|
||||
return ctx.binary_op(BinaryOps.SUB, x, y)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
shape_x, shape_y = ctx.saved_tensors
|
||||
neg_grad_output = ctx.unary_op(UnaryOps.NEG, grad_output)
|
||||
return unbroadcast(ctx, grad_output, shape_x) if ctx.needs_input_grad[0] else None, \
|
||||
unbroadcast(ctx, neg_grad_output, shape_y) if ctx.needs_input_grad[1] else None
|
||||
return grad_output if ctx.needs_input_grad[0] else None, \
|
||||
ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None
|
||||
|
||||
class Mul(Function):
|
||||
def forward(ctx, x, y):
|
||||
@@ -94,8 +91,8 @@ class Mul(Function):
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, y, grad_output), x.shape) if ctx.needs_input_grad[0] else None
|
||||
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, x, grad_output), y.shape) if ctx.needs_input_grad[1] else None
|
||||
grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None
|
||||
grad_y = ctx.binary_op(BinaryOps.MUL, x, grad_output) if ctx.needs_input_grad[1] else None
|
||||
return grad_x, grad_y
|
||||
|
||||
class Pow(Function):
|
||||
@@ -106,15 +103,28 @@ class Pow(Function):
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x,y,powxy = ctx.saved_tensors
|
||||
tmp = ctx.binary_op(BinaryOps.DIV, x, powxy) # pow(x,y)/x
|
||||
tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
|
||||
grad_x = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), x.shape) if ctx.needs_input_grad[0] else None
|
||||
tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
|
||||
grad_y = unbroadcast(ctx, ctx.binary_op(BinaryOps.MUL, grad_output, tmp), y.shape) if ctx.needs_input_grad[1] else None
|
||||
grad_x, grad_y = None, None
|
||||
if ctx.needs_input_grad[0]:
|
||||
tmp = ctx.binary_op(BinaryOps.DIV, x, powxy) # pow(x,y)/x
|
||||
tmp = ctx.binary_op(BinaryOps.MUL, y, tmp) # y * pow(x,y)/x
|
||||
grad_x = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
|
||||
if ctx.needs_input_grad[1]:
|
||||
tmp = ctx.binary_op(BinaryOps.MUL, ctx.unary_op(UnaryOps.LOG, x), powxy) # log(x) * pow(x,y)
|
||||
grad_y = ctx.binary_op(BinaryOps.MUL, grad_output, tmp)
|
||||
return grad_x, grad_y
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
# NOTE: this is sum in reverse
|
||||
class Expand(Function):
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
return ctx.movement_op(MovementOps.EXPAND, x, shape)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
in_shape, = ctx.saved_tensors
|
||||
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
|
||||
|
||||
class Reshape(Function):
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# TODO: move Device to here and proxy buffer call
|
||||
from enum import Enum
|
||||
UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
|
||||
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
|
||||
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
|
||||
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
|
||||
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
|
||||
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND"])
|
||||
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
|
||||
|
||||
import os
|
||||
@@ -45,7 +45,6 @@ def log_op(op, ret, inp):
|
||||
G.nodes[nm(ret)]['fillcolor'] = top_colors[top]
|
||||
G.nodes[nm(ret)]['style'] = 'filled'
|
||||
|
||||
from tinygrad.helpers import binary_broadcast
|
||||
class Ops:
|
||||
def unary_op(ctx, op:UnaryOps, x):
|
||||
ret = ctx.buffer(x.shape)
|
||||
@@ -60,13 +59,14 @@ class Ops:
|
||||
return ret
|
||||
|
||||
def binary_op(ctx, op:ReduceOps, x, y):
|
||||
ret = ctx.buffer(binary_broadcast(x.shape, y.shape))
|
||||
assert x.shape == y.shape
|
||||
ret = ctx.buffer(x.shape)
|
||||
ctx.op.binary_op(op, x, y, ret)
|
||||
log_op(op, ret, [x] if op == BinaryOps.A else [x, y])
|
||||
log_op(op, ret, [x, y])
|
||||
return ret
|
||||
|
||||
def movement_op(ctx, op:MovementOps, x, arg=None):
|
||||
if op == MovementOps.RESHAPE: new_shape = arg
|
||||
if op in [MovementOps.RESHAPE, MovementOps.EXPAND]: new_shape = arg
|
||||
if op == MovementOps.PERMUTE: new_shape = [x.shape[i] for i in arg]
|
||||
if op == MovementOps.SLICE: new_shape = [y-x for x,y in arg]
|
||||
ret = ctx.buffer(new_shape)
|
||||
|
||||
@@ -349,10 +349,34 @@ class Tensor:
|
||||
ret = x._conv2d(weight, stride=stride, groups=groups)
|
||||
return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
|
||||
|
||||
# ***** broadcasted binary ops *****
|
||||
|
||||
@staticmethod
|
||||
def broadcasted(fxn, x, y):
|
||||
tt = [arg for arg in [x,y] if isinstance(arg, Tensor)][0] # this is the prototype tensor
|
||||
if not isinstance(x, Tensor): x = Tensor(np.array([x], dtype=tt.dtype), device=tt.device, requires_grad=False)
|
||||
if not isinstance(y, Tensor): y = Tensor(np.array([y], dtype=tt.dtype), device=tt.device, requires_grad=False)
|
||||
|
||||
n_dims = max(len(x.shape), len(y.shape))
|
||||
if len(x.shape) != n_dims: x = x.reshape(list(x.shape) + [1]*(n_dims-len(x.shape)))
|
||||
if len(y.shape) != n_dims: y = y.reshape(list(y.shape) + [1]*(n_dims-len(y.shape)))
|
||||
|
||||
shape_ret = tuple([int(x) for x in np.maximum(x.shape, y.shape)])
|
||||
if x.shape != shape_ret: x = x.expand(shape_ret)
|
||||
if y.shape != shape_ret: y = y.expand(shape_ret)
|
||||
return fxn(x, y)
|
||||
|
||||
# TODO: are these the only ones that can take number arguments?
|
||||
def add(self, x): return Tensor.broadcasted(Tensor._add, self, x)
|
||||
def sub(self, x): return Tensor.broadcasted(Tensor._sub, self, x)
|
||||
def mul(self, x): return Tensor.broadcasted(Tensor._mul, self, x)
|
||||
def pow(self, x): return Tensor.broadcasted(Tensor._pow, self, x)
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def reshape(self, shape):
|
||||
return self._reshape(shape=shape)
|
||||
# TODO: fix the kwargs problem
|
||||
def reshape(self, shape): return self._reshape(shape=shape)
|
||||
def expand(self, shape): return self._expand(shape=shape)
|
||||
|
||||
def linear(self, weight, bias):
|
||||
shp = [1] * (len(self.shape)-1) + [-1]
|
||||
@@ -391,13 +415,8 @@ class Function(Ops):
|
||||
|
||||
@classmethod
|
||||
def apply(cls, *x, **kwargs):
|
||||
tt = [arg for arg in x if isinstance(arg, Tensor)][0] # this is the prototype tensor
|
||||
|
||||
# create tensors from number arguments
|
||||
x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x]
|
||||
assert all([tt.device == t.device for t in x]), "All tensors are not on the same device"
|
||||
|
||||
ctx = cls(tt.device, *x)
|
||||
assert all([isinstance(arg, Tensor) for arg in x])
|
||||
ctx = cls(x[0].device, *x)
|
||||
with ProfileOp(ctx, ctx.__class__.__name__, x) as po:
|
||||
ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
|
||||
device=ctx.device, requires_grad=ctx.requires_grad)
|
||||
@@ -419,6 +438,7 @@ for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), i
|
||||
if name[0] != "_" and name != "Function" and not name.endswith("Ops"): register(name.lower(), cls)
|
||||
|
||||
# register the operators
|
||||
# TODO: add div
|
||||
def register_op(name, fxn):
|
||||
setattr(Tensor, f"__{name}__", fxn)
|
||||
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(fxn(self,x)))
|
||||
|
||||
Reference in New Issue
Block a user