onnx : add a whole bunch of ops

This commit is contained in:
George Hotz
2023-02-24 12:00:03 -08:00
parent f2486a7248
commit e8a153e4e9
3 changed files with 94 additions and 2 deletions

View File

@@ -93,7 +93,7 @@ def get_run_onnx(onnx_model):
for num,n in enumerate(onnx_model.graph.node):
inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else (input_tensors[x] if x != str() else None)) for x in n.input]
opt = attribute_dict[num]
if debug: print(f"{num}: op {n.op_type} shape {[x.shape if x is not None else None for x in inp]} opt {opt}")
if debug: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
# free ones
if n.op_type == "Relu": ret = inp[0].relu()
@@ -125,7 +125,7 @@ def get_run_onnx(onnx_model):
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
elif n.op_type in ["Sum"]:
ret = functools.reduce(Tensor.__add__, inp)
elif n.op_type in ["Add", "Sub", "Mul"]:
elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
inp[1] = inp[1].reshape(inp[0].shape)
@@ -134,6 +134,7 @@ def get_run_onnx(onnx_model):
if n.op_type == "Add": ret = inp[0] + inp[1]
if n.op_type == "Sub": ret = inp[0] - inp[1]
if n.op_type == "Mul": ret = inp[0] * inp[1]
if n.op_type == "Pow": ret = inp[0] ** inp[1]
elif n.op_type == "Split":
if 'split' not in opt: opt['split'] = [int(x) for x in safe_numpy(inp[1])] # split can be a tensor
i = 0

View File

@@ -80,19 +80,30 @@ def Expand(input, shape):
return input.reshape(x_shape) #.expand(shape_ret)
def Identity(input): return input
def Neg(input): return -input
def Sqrt(input): return input.sqrt()
def Sign(input): return input.sign()
def Abs(input): return input.abs()
def Exp(input): return input.exp()
def Log(input): return input.log()
def Mish(input): return input.mish()
def HardSigmoid(input, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
def HardSwish(input): return input * HardSigmoid(input, 1/6, 0.5)
def Selu(X, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def Softplus(X): return X.softplus()
def PRelu(X, slope): return X.leakyrelu(slope)
def LeakyRelu(X, alpha=0.01): return X.leakyrelu(alpha)
def Softmax(input, axis=-1): return input.softmax(axis)
def LogSoftmax(input, axis=-1): return input.log_softmax(axis)
def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max)
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None else ([] if noop_with_empty_axes else None)
# ReduceProd would require a new llop
def ReduceMax(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMin(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSum(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceMean(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceSumSquare(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL1(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
def ReduceL2(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()

View File

@@ -37,6 +37,83 @@ class TinygradBackend(Backend):
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
# no binaryops min or max (needs llop, should add and replace relu)
backend_test.exclude('test_min_*')
backend_test.exclude('test_max_*')
# add support for SoftmaxCrossEntropyLoss and NegativeLogLikelihoodLoss
backend_test.exclude('test_sce_*')
# we only support float32
backend_test.exclude('test_add_uint8_*')
backend_test.exclude('test_div_uint8_*')
backend_test.exclude('test_cast_*')
backend_test.exclude('test_castlike_*')
# no support for nan or inf
backend_test.exclude('test_isinf_*')
backend_test.exclude('test_isnan_*')
# no support for mod
backend_test.exclude('test_mod_*')
# no trig ops
backend_test.exclude('test_acos_*')
backend_test.exclude('test_acosh_*')
backend_test.exclude('test_asin_*')
backend_test.exclude('test_asinh_*')
backend_test.exclude('test_atan_*')
backend_test.exclude('test_atanh_*')
backend_test.exclude('test_cos_*')
backend_test.exclude('test_cosh_*')
backend_test.exclude('test_sin_*')
backend_test.exclude('test_sinh_*')
backend_test.exclude('test_tan_*')
# no boolean ops (2d, 3d, 4d)
backend_test.exclude('test_and*')
backend_test.exclude('test_xor*')
backend_test.exclude('test_or*')
backend_test.exclude('test_bitshift_*')
# no scatter gather
backend_test.exclude('test_gather_*')
backend_test.exclude('test_gathernd_*')
backend_test.exclude('test_scatter_*')
backend_test.exclude('test_scatternd_*')
# unsupported (strange) ops
backend_test.exclude('test_adagrad_*')
backend_test.exclude('test_adam_*')
backend_test.exclude('test_argmax_*')
backend_test.exclude('test_argmin_*')
backend_test.exclude('test_bitwise_*')
backend_test.exclude('test_blackmanwindow_*')
backend_test.exclude('test_bernoulli_*')
backend_test.exclude('test_cumsum_*')
backend_test.exclude('test_tril_*')
backend_test.exclude('test_triu_*')
backend_test.exclude('test_convinteger_*')
backend_test.exclude('test_col2im_*')
backend_test.exclude('test_hammingwindow_*')
backend_test.exclude('test_hannwindow_*')
backend_test.exclude('test_hardmax_*')
backend_test.exclude('test_gru_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_if_*')
backend_test.exclude('test_compress_*')
backend_test.exclude('test_dequantizelinear_*')
backend_test.exclude('test_dynamicquantizelinear_*')
backend_test.exclude('test_det_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_*')
backend_test.exclude('test_erf_*')
backend_test.exclude('test_strnorm_*')
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.include('test_selu_*')
# the node tests
#for x in backend_test.test_suite:
# if 'OnnxBackendNodeModelTest' in str(type(x)):
@@ -65,6 +142,7 @@ backend_test.include('test_softplus*')
# requires cast
#backend_test.include('test_reduce_log_sum*')
#backend_test.include('test_pow_*')
# almost passing node tests
#backend_test.include('test_PReLU*')
@@ -89,6 +167,8 @@ backend_test.include('test_tanh_*')
# requires CastLike?
#backend_test.include('test_relu_*')
#backend_test.include('test_elu_*')
#backend_test.include('test_leakyrelu_*')
#backend_test.include('test_hardsigmoid_*')
# failing for lack of type support
#backend_test.include('test_add_*')