From e8a153e4e943c2e680eda40e63bd1936c2f307b5 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 24 Feb 2023 12:00:03 -0800 Subject: [PATCH] onnx : add a whole bunch of ops --- extra/onnx.py | 5 +- extra/onnx_ops.py | 11 ++++ test/external_test_onnx_backend.py | 80 ++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 2 deletions(-) diff --git a/extra/onnx.py b/extra/onnx.py index 7e6ea0026f..d2994c31c9 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -93,7 +93,7 @@ def get_run_onnx(onnx_model): for num,n in enumerate(onnx_model.graph.node): inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else (input_tensors[x] if x != str() else None)) for x in n.input] opt = attribute_dict[num] - if debug: print(f"{num}: op {n.op_type} shape {[x.shape if x is not None else None for x in inp]} opt {opt}") + if debug: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}") # free ones if n.op_type == "Relu": ret = inp[0].relu() @@ -125,7 +125,7 @@ def get_run_onnx(onnx_model): ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed elif n.op_type in ["Sum"]: ret = functools.reduce(Tensor.__add__, inp) - elif n.op_type in ["Add", "Sub", "Mul"]: + elif n.op_type in ["Add", "Sub", "Mul", "Pow"]: # TODO: add this to tinygrad? i don't think it's in torch if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape): inp[1] = inp[1].reshape(inp[0].shape) @@ -134,6 +134,7 @@ def get_run_onnx(onnx_model): if n.op_type == "Add": ret = inp[0] + inp[1] if n.op_type == "Sub": ret = inp[0] - inp[1] if n.op_type == "Mul": ret = inp[0] * inp[1] + if n.op_type == "Pow": ret = inp[0] ** inp[1] elif n.op_type == "Split": if 'split' not in opt: opt['split'] = [int(x) for x in safe_numpy(inp[1])] # split can be a tensor i = 0 diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 63a488ad01..14b20b6e8a 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -80,19 +80,30 @@ def Expand(input, shape): return input.reshape(x_shape) #.expand(shape_ret) def Identity(input): return input +def Neg(input): return -input +def Sqrt(input): return input.sqrt() +def Sign(input): return input.sign() def Abs(input): return input.abs() def Exp(input): return input.exp() def Log(input): return input.log() +def Mish(input): return input.mish() +def HardSigmoid(input, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1) +def HardSwish(input): return input * HardSigmoid(input, 1/6, 0.5) +def Selu(X, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu()) def Softplus(X): return X.softplus() def PRelu(X, slope): return X.leakyrelu(slope) +def LeakyRelu(X, alpha=0.01): return X.leakyrelu(alpha) def Softmax(input, axis=-1): return input.softmax(axis) def LogSoftmax(input, axis=-1): return input.log_softmax(axis) def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max) def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None else ([] if noop_with_empty_axes else None) +# ReduceProd would require a new llop def ReduceMax(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) +def ReduceMin(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims) def ReduceSum(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) +def ReduceMean(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims) def ReduceSumSquare(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) def ReduceL1(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims) def ReduceL2(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt() diff --git a/test/external_test_onnx_backend.py b/test/external_test_onnx_backend.py index 5fdb8938ab..6c38172fac 100644 --- a/test/external_test_onnx_backend.py +++ b/test/external_test_onnx_backend.py @@ -37,6 +37,83 @@ class TinygradBackend(Backend): backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__) +# no binaryops min or max (needs llop, should add and replace relu) +backend_test.exclude('test_min_*') +backend_test.exclude('test_max_*') + +# add support for SoftmaxCrossEntropyLoss and NegativeLogLikelihoodLoss +backend_test.exclude('test_sce_*') + +# we only support float32 +backend_test.exclude('test_add_uint8_*') +backend_test.exclude('test_div_uint8_*') +backend_test.exclude('test_cast_*') +backend_test.exclude('test_castlike_*') + +# no support for nan or inf +backend_test.exclude('test_isinf_*') +backend_test.exclude('test_isnan_*') + +# no support for mod +backend_test.exclude('test_mod_*') + +# no trig ops +backend_test.exclude('test_acos_*') +backend_test.exclude('test_acosh_*') +backend_test.exclude('test_asin_*') +backend_test.exclude('test_asinh_*') +backend_test.exclude('test_atan_*') +backend_test.exclude('test_atanh_*') +backend_test.exclude('test_cos_*') +backend_test.exclude('test_cosh_*') +backend_test.exclude('test_sin_*') +backend_test.exclude('test_sinh_*') +backend_test.exclude('test_tan_*') + +# no boolean ops (2d, 3d, 4d) +backend_test.exclude('test_and*') +backend_test.exclude('test_xor*') +backend_test.exclude('test_or*') +backend_test.exclude('test_bitshift_*') + +# no scatter gather +backend_test.exclude('test_gather_*') +backend_test.exclude('test_gathernd_*') +backend_test.exclude('test_scatter_*') +backend_test.exclude('test_scatternd_*') + +# unsupported (strange) ops +backend_test.exclude('test_adagrad_*') +backend_test.exclude('test_adam_*') +backend_test.exclude('test_argmax_*') +backend_test.exclude('test_argmin_*') +backend_test.exclude('test_bitwise_*') +backend_test.exclude('test_blackmanwindow_*') +backend_test.exclude('test_bernoulli_*') +backend_test.exclude('test_cumsum_*') +backend_test.exclude('test_tril_*') +backend_test.exclude('test_triu_*') +backend_test.exclude('test_convinteger_*') +backend_test.exclude('test_col2im_*') +backend_test.exclude('test_hammingwindow_*') +backend_test.exclude('test_hannwindow_*') +backend_test.exclude('test_hardmax_*') +backend_test.exclude('test_gru_*') +backend_test.exclude('test_gridsample_*') +backend_test.exclude('test_if_*') +backend_test.exclude('test_compress_*') +backend_test.exclude('test_dequantizelinear_*') +backend_test.exclude('test_dynamicquantizelinear_*') +backend_test.exclude('test_det_*') +backend_test.exclude('test_dft_*') +backend_test.exclude('test_einsum_*') +backend_test.exclude('test_erf_*') +backend_test.exclude('test_strnorm_*') +backend_test.exclude('test_unique_*') +backend_test.exclude('test_sequence_*') + +backend_test.include('test_selu_*') + # the node tests #for x in backend_test.test_suite: # if 'OnnxBackendNodeModelTest' in str(type(x)): @@ -65,6 +142,7 @@ backend_test.include('test_softplus*') # requires cast #backend_test.include('test_reduce_log_sum*') +#backend_test.include('test_pow_*') # almost passing node tests #backend_test.include('test_PReLU*') @@ -89,6 +167,8 @@ backend_test.include('test_tanh_*') # requires CastLike? #backend_test.include('test_relu_*') #backend_test.include('test_elu_*') +#backend_test.include('test_leakyrelu_*') +#backend_test.include('test_hardsigmoid_*') # failing for lack of type support #backend_test.include('test_add_*')