mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 01:15:49 +08:00
simpler onnx
This commit is contained in:
@@ -6,6 +6,7 @@ import onnx
|
||||
from extra.utils import fetch
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.nn import batch_normalize
|
||||
|
||||
def run_onnx(onnx_model, inputs={}, debug=False):
|
||||
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
|
||||
@@ -52,47 +53,43 @@ def run_onnx(onnx_model, inputs={}, debug=False):
|
||||
tensors[inp.name] = Tensor(inputs[inp.name])
|
||||
else:
|
||||
raise Exception(f"no data for {inp.name} with shape {shape}")
|
||||
#print(f"filling {inp.name} shape {shape} with 0")
|
||||
#tensors[inp.name] = Tensor.zeros(*shape)
|
||||
|
||||
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
if debug: print(f"{num}: op {n.op_type}")
|
||||
inp = [tensors[x] for x in n.input]
|
||||
opt = attribute_to_dict(n.attribute)
|
||||
if n.op_type == "Conv":
|
||||
|
||||
# free ones
|
||||
if n.op_type == "Relu": ret = inp[0].relu()
|
||||
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
|
||||
elif n.op_type == "Tanh": ret = inp[0].tanh()
|
||||
elif n.op_type == "Softmax": ret = inp[0].softmax()
|
||||
# one liners
|
||||
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
|
||||
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt['min'], opt['max'])))
|
||||
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
|
||||
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
|
||||
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
|
||||
elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
|
||||
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
|
||||
elif n.op_type == "BatchNormalization": ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt.get('epsilon', 1e-5))
|
||||
elif n.op_type == "MatMul": ret = inp[0].dot(inp[1])
|
||||
elif n.op_type == "Gemm": ret = inp[0].linear(inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1], inp[2])
|
||||
elif n.op_type == "Conv":
|
||||
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
|
||||
assert 'dilations' not in opt or opt['dilations'] == (1,1)
|
||||
# pads are in different order
|
||||
pads = (opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3])
|
||||
ret = x.pad2d(pads).conv2d(w, b, stride=opt['strides'], groups=opt['group'] if 'group' in opt else 1)
|
||||
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
|
||||
elif n.op_type == "Relu": ret = inp[0].relu()
|
||||
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
|
||||
elif n.op_type == "Tanh": ret = inp[0].tanh()
|
||||
elif n.op_type == "Softmax": ret = inp[0].softmax()
|
||||
elif n.op_type in ["Add", "Sub", "Mul"]:
|
||||
# TODO: add this to tinygrad
|
||||
# TODO: add this to tinygrad? i don't think it's in torch
|
||||
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
|
||||
inp[1] = inp[1].reshape(inp[0].shape)
|
||||
# TODO: is this right?
|
||||
if 'broadcast' in opt:
|
||||
new_shape = [1 for x in range(len(inp[0].shape))]
|
||||
new_shape[opt['broadcast']] = -1
|
||||
#print(inp[1].shape, new_shape)
|
||||
inp[1] = inp[1].reshape(new_shape)
|
||||
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
|
||||
if n.op_type == "Add": ret = inp[0] + inp[1]
|
||||
if n.op_type == "Sub": ret = inp[0] - inp[1]
|
||||
if n.op_type == "Mul": ret = inp[0] * inp[1]
|
||||
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
|
||||
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
|
||||
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
|
||||
elif n.op_type == "Squeeze":
|
||||
ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
|
||||
elif n.op_type == "Clip":
|
||||
if 'min' in opt and 'max' in opt: ret = inp[0].clip(opt['min'], opt['max'])
|
||||
else: ret = inp[0].clip(inp[1], inp[2])
|
||||
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
|
||||
elif n.op_type == "Split":
|
||||
i = 0
|
||||
arg = [(0,x) for x in inp[0].shape]
|
||||
@@ -101,16 +98,6 @@ def run_onnx(onnx_model, inputs={}, debug=False):
|
||||
tensors[o] = inp[0].slice(arg=arg)
|
||||
i = i+s
|
||||
continue
|
||||
elif n.op_type in ["Gemm", "MatMul"]:
|
||||
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
|
||||
#print(a.shape, w.shape, b.shape)
|
||||
if 'transB' in opt and opt['transB'] == 1: w = w.transpose()
|
||||
ret = x.dot(w) if b is None else x.linear(w,b)
|
||||
elif n.op_type == "BatchNormalization":
|
||||
from tinygrad.nn import batch_normalize
|
||||
# does ONNX really specify a default eps?
|
||||
#print(n)
|
||||
ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt['epsilon'] if 'epsilon' in opt else 1e-5)
|
||||
elif n.op_type == "AveragePool":
|
||||
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
|
||||
ret = inp[0].avg_pool2d(opt['kernel_shape'])
|
||||
|
||||
Reference in New Issue
Block a user