diff --git a/test/test_onnx.py b/test/test_onnx.py index 5c26cdea77..7c5dc43453 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -6,6 +6,7 @@ import onnx from extra.utils import fetch from tinygrad.tensor import Tensor from tinygrad.helpers import prod +from tinygrad.nn import batch_normalize def run_onnx(onnx_model, inputs={}, debug=False): def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim) @@ -52,47 +53,43 @@ def run_onnx(onnx_model, inputs={}, debug=False): tensors[inp.name] = Tensor(inputs[inp.name]) else: raise Exception(f"no data for {inp.name} with shape {shape}") - #print(f"filling {inp.name} shape {shape} with 0") - #tensors[inp.name] = Tensor.zeros(*shape) - for num,n in enumerate(onnx_model.graph.node): if debug: print(f"{num}: op {n.op_type}") inp = [tensors[x] for x in n.input] opt = attribute_to_dict(n.attribute) - if n.op_type == "Conv": + + # free ones + if n.op_type == "Relu": ret = inp[0].relu() + elif n.op_type == "Sigmoid": ret = inp[0].sigmoid() + elif n.op_type == "Tanh": ret = inp[0].tanh() + elif n.op_type == "Softmax": ret = inp[0].softmax() + # one liners + elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha']) + elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt['min'], opt['max']))) + elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis']) + elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0) + elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm']) + elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']]) + elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True) + elif n.op_type == "BatchNormalization": ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt.get('epsilon', 1e-5)) + elif n.op_type == "MatMul": ret = inp[0].dot(inp[1]) + elif n.op_type == "Gemm": ret = inp[0].linear(inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1], inp[2]) + elif n.op_type == "Conv": x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None) assert 'dilations' not in opt or opt['dilations'] == (1,1) # pads are in different order pads = (opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]) ret = x.pad2d(pads).conv2d(w, b, stride=opt['strides'], groups=opt['group'] if 'group' in opt else 1) - elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha']) - elif n.op_type == "Relu": ret = inp[0].relu() - elif n.op_type == "Sigmoid": ret = inp[0].sigmoid() - elif n.op_type == "Tanh": ret = inp[0].tanh() - elif n.op_type == "Softmax": ret = inp[0].softmax() elif n.op_type in ["Add", "Sub", "Mul"]: - # TODO: add this to tinygrad + # TODO: add this to tinygrad? i don't think it's in torch if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape): inp[1] = inp[1].reshape(inp[0].shape) # TODO: is this right? - if 'broadcast' in opt: - new_shape = [1 for x in range(len(inp[0].shape))] - new_shape[opt['broadcast']] = -1 - #print(inp[1].shape, new_shape) - inp[1] = inp[1].reshape(new_shape) + if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))]) if n.op_type == "Add": ret = inp[0] + inp[1] if n.op_type == "Sub": ret = inp[0] - inp[1] if n.op_type == "Mul": ret = inp[0] * inp[1] - elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0) - elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis']) - elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm']) - elif n.op_type == "Squeeze": - ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']]) - elif n.op_type == "Clip": - if 'min' in opt and 'max' in opt: ret = inp[0].clip(opt['min'], opt['max']) - else: ret = inp[0].clip(inp[1], inp[2]) - elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True) elif n.op_type == "Split": i = 0 arg = [(0,x) for x in inp[0].shape] @@ -101,16 +98,6 @@ def run_onnx(onnx_model, inputs={}, debug=False): tensors[o] = inp[0].slice(arg=arg) i = i+s continue - elif n.op_type in ["Gemm", "MatMul"]: - x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None) - #print(a.shape, w.shape, b.shape) - if 'transB' in opt and opt['transB'] == 1: w = w.transpose() - ret = x.dot(w) if b is None else x.linear(w,b) - elif n.op_type == "BatchNormalization": - from tinygrad.nn import batch_normalize - # does ONNX really specify a default eps? - #print(n) - ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt['epsilon'] if 'epsilon' in opt else 1e-5) elif n.op_type == "AveragePool": assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1) ret = inp[0].avg_pool2d(opt['kernel_shape'])