mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
onnx : add a whole bunch of ops
This commit is contained in:
@@ -93,7 +93,7 @@ def get_run_onnx(onnx_model):
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else (input_tensors[x] if x != str() else None)) for x in n.input]
|
||||
opt = attribute_dict[num]
|
||||
if debug: print(f"{num}: op {n.op_type} shape {[x.shape if x is not None else None for x in inp]} opt {opt}")
|
||||
if debug: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
|
||||
|
||||
# free ones
|
||||
if n.op_type == "Relu": ret = inp[0].relu()
|
||||
@@ -125,7 +125,7 @@ def get_run_onnx(onnx_model):
|
||||
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
|
||||
elif n.op_type in ["Sum"]:
|
||||
ret = functools.reduce(Tensor.__add__, inp)
|
||||
elif n.op_type in ["Add", "Sub", "Mul"]:
|
||||
elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
|
||||
# TODO: add this to tinygrad? i don't think it's in torch
|
||||
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
|
||||
inp[1] = inp[1].reshape(inp[0].shape)
|
||||
@@ -134,6 +134,7 @@ def get_run_onnx(onnx_model):
|
||||
if n.op_type == "Add": ret = inp[0] + inp[1]
|
||||
if n.op_type == "Sub": ret = inp[0] - inp[1]
|
||||
if n.op_type == "Mul": ret = inp[0] * inp[1]
|
||||
if n.op_type == "Pow": ret = inp[0] ** inp[1]
|
||||
elif n.op_type == "Split":
|
||||
if 'split' not in opt: opt['split'] = [int(x) for x in safe_numpy(inp[1])] # split can be a tensor
|
||||
i = 0
|
||||
|
||||
@@ -80,19 +80,30 @@ def Expand(input, shape):
|
||||
return input.reshape(x_shape) #.expand(shape_ret)
|
||||
|
||||
def Identity(input): return input
|
||||
def Neg(input): return -input
|
||||
def Sqrt(input): return input.sqrt()
|
||||
def Sign(input): return input.sign()
|
||||
def Abs(input): return input.abs()
|
||||
def Exp(input): return input.exp()
|
||||
def Log(input): return input.log()
|
||||
def Mish(input): return input.mish()
|
||||
def HardSigmoid(input, alpha=0.2, beta=0.5): return (alpha*input + beta).clip(0, 1)
|
||||
def HardSwish(input): return input * HardSigmoid(input, 1/6, 0.5)
|
||||
def Selu(X, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
|
||||
def Softplus(X): return X.softplus()
|
||||
def PRelu(X, slope): return X.leakyrelu(slope)
|
||||
def LeakyRelu(X, alpha=0.01): return X.leakyrelu(alpha)
|
||||
def Softmax(input, axis=-1): return input.softmax(axis)
|
||||
def LogSoftmax(input, axis=-1): return input.log_softmax(axis)
|
||||
def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max)
|
||||
|
||||
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None else ([] if noop_with_empty_axes else None)
|
||||
|
||||
# ReduceProd would require a new llop
|
||||
def ReduceMax(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceMin(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.min(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceSum(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceMean(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.mean(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceSumSquare(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceL1(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.abs().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
|
||||
def ReduceL2(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.square().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).sqrt()
|
||||
|
||||
@@ -37,6 +37,83 @@ class TinygradBackend(Backend):
|
||||
|
||||
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
|
||||
|
||||
# no binaryops min or max (needs llop, should add and replace relu)
|
||||
backend_test.exclude('test_min_*')
|
||||
backend_test.exclude('test_max_*')
|
||||
|
||||
# add support for SoftmaxCrossEntropyLoss and NegativeLogLikelihoodLoss
|
||||
backend_test.exclude('test_sce_*')
|
||||
|
||||
# we only support float32
|
||||
backend_test.exclude('test_add_uint8_*')
|
||||
backend_test.exclude('test_div_uint8_*')
|
||||
backend_test.exclude('test_cast_*')
|
||||
backend_test.exclude('test_castlike_*')
|
||||
|
||||
# no support for nan or inf
|
||||
backend_test.exclude('test_isinf_*')
|
||||
backend_test.exclude('test_isnan_*')
|
||||
|
||||
# no support for mod
|
||||
backend_test.exclude('test_mod_*')
|
||||
|
||||
# no trig ops
|
||||
backend_test.exclude('test_acos_*')
|
||||
backend_test.exclude('test_acosh_*')
|
||||
backend_test.exclude('test_asin_*')
|
||||
backend_test.exclude('test_asinh_*')
|
||||
backend_test.exclude('test_atan_*')
|
||||
backend_test.exclude('test_atanh_*')
|
||||
backend_test.exclude('test_cos_*')
|
||||
backend_test.exclude('test_cosh_*')
|
||||
backend_test.exclude('test_sin_*')
|
||||
backend_test.exclude('test_sinh_*')
|
||||
backend_test.exclude('test_tan_*')
|
||||
|
||||
# no boolean ops (2d, 3d, 4d)
|
||||
backend_test.exclude('test_and*')
|
||||
backend_test.exclude('test_xor*')
|
||||
backend_test.exclude('test_or*')
|
||||
backend_test.exclude('test_bitshift_*')
|
||||
|
||||
# no scatter gather
|
||||
backend_test.exclude('test_gather_*')
|
||||
backend_test.exclude('test_gathernd_*')
|
||||
backend_test.exclude('test_scatter_*')
|
||||
backend_test.exclude('test_scatternd_*')
|
||||
|
||||
# unsupported (strange) ops
|
||||
backend_test.exclude('test_adagrad_*')
|
||||
backend_test.exclude('test_adam_*')
|
||||
backend_test.exclude('test_argmax_*')
|
||||
backend_test.exclude('test_argmin_*')
|
||||
backend_test.exclude('test_bitwise_*')
|
||||
backend_test.exclude('test_blackmanwindow_*')
|
||||
backend_test.exclude('test_bernoulli_*')
|
||||
backend_test.exclude('test_cumsum_*')
|
||||
backend_test.exclude('test_tril_*')
|
||||
backend_test.exclude('test_triu_*')
|
||||
backend_test.exclude('test_convinteger_*')
|
||||
backend_test.exclude('test_col2im_*')
|
||||
backend_test.exclude('test_hammingwindow_*')
|
||||
backend_test.exclude('test_hannwindow_*')
|
||||
backend_test.exclude('test_hardmax_*')
|
||||
backend_test.exclude('test_gru_*')
|
||||
backend_test.exclude('test_gridsample_*')
|
||||
backend_test.exclude('test_if_*')
|
||||
backend_test.exclude('test_compress_*')
|
||||
backend_test.exclude('test_dequantizelinear_*')
|
||||
backend_test.exclude('test_dynamicquantizelinear_*')
|
||||
backend_test.exclude('test_det_*')
|
||||
backend_test.exclude('test_dft_*')
|
||||
backend_test.exclude('test_einsum_*')
|
||||
backend_test.exclude('test_erf_*')
|
||||
backend_test.exclude('test_strnorm_*')
|
||||
backend_test.exclude('test_unique_*')
|
||||
backend_test.exclude('test_sequence_*')
|
||||
|
||||
backend_test.include('test_selu_*')
|
||||
|
||||
# the node tests
|
||||
#for x in backend_test.test_suite:
|
||||
# if 'OnnxBackendNodeModelTest' in str(type(x)):
|
||||
@@ -65,6 +142,7 @@ backend_test.include('test_softplus*')
|
||||
|
||||
# requires cast
|
||||
#backend_test.include('test_reduce_log_sum*')
|
||||
#backend_test.include('test_pow_*')
|
||||
|
||||
# almost passing node tests
|
||||
#backend_test.include('test_PReLU*')
|
||||
@@ -89,6 +167,8 @@ backend_test.include('test_tanh_*')
|
||||
# requires CastLike?
|
||||
#backend_test.include('test_relu_*')
|
||||
#backend_test.include('test_elu_*')
|
||||
#backend_test.include('test_leakyrelu_*')
|
||||
#backend_test.include('test_hardsigmoid_*')
|
||||
|
||||
# failing for lack of type support
|
||||
#backend_test.include('test_add_*')
|
||||
|
||||
Reference in New Issue
Block a user