From 04846b91aadd3d52f2a128a397afe0b963bedebe Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Fri, 24 Jan 2025 02:18:54 +0800 Subject: [PATCH] reorder and categorize onnx_ops (#8731) * new order * remove a todo * constant node is definitely requires_grad false * one new line spacing * property and graph * oops linter --- extra/onnx_ops.py | 643 ++++++++++++++++++++++------------------------ 1 file changed, 312 insertions(+), 331 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 4f745e680b..165de3154b 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -1,41 +1,15 @@ import functools, io, math from typing import cast, Literal from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr -from tinygrad.dtype import ImageDType, dtypes +from tinygrad.dtype import ImageDType, dtypes, DType from tinygrad.helpers import prod, flatten, make_tuple from extra.onnx import dtype_parse, _cached_to_python_const import numpy as np -# **************** Free Ops **************** - +# ***** Property/Graph Ops ***** def Identity(x:Tensor): return x -# TODO: fix buffer_parse -def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype) -def Sub(x:Tensor|int, other:Tensor): return x - other # some test has input as int -def Less(x:Tensor,y:Tensor): return x < y -def LessOrEqual(x:Tensor,y:Tensor): return x <= y -def Greater(x:Tensor,y:Tensor): return x > y -def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y -def Equal(x:Tensor,y:Tensor): return x == y -def BitwiseNot(x:Tensor): return ~x -def BitwiseOr(x:Tensor, y:Tensor): return x | y -def BitwiseAnd(x:Tensor, y:Tensor): return x & y -def BitwiseXor(x:Tensor, y:Tensor): return x ^ y -def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) -def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) -def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) -def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) -def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) -def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) - -# **************** Simple Ops **************** - -# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py -def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype) - -def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, - value_floats:list[float]|None=None, value_int:int|None=None, value_ints:list[int]|None=None, - value_string:str|None=None, value_strings:list[str]|None=None): +def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, + value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): if value is not None: return value if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) @@ -44,21 +18,79 @@ def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float: if value_string is not None or value_strings is not None and sparse_value is not None: raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') +def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) + +def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): + try: import PIL.Image + except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e + img = PIL.Image.open(io.BytesIO(encoded_stream)) + if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] + if pixel_format == "RGB": return Tensor(np.array(img)) + if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) + raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + +def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): + ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) + return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) + +def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) +def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) +def ConstantOfShape(shape:list[int], value:Tensor|None=None): + if value is None: value = Tensor(0, dtype=dtypes.float32) + return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) + +def Size(data:Tensor): return data.numel() +def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) + +# ***** Unary Ops (math) ***** +def Not(x:Tensor): return x.logical_not() +def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): + return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) + +# ***** Unary Ops (activation) ***** +def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) +def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) +Softmax = {1:Softmax_1, 13:Softmax_13} def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) +def FastGelu(x:Tensor, bias:Tensor|None=None): + # this is tanh approximated + return (x + bias).gelu() if bias is not None else x.gelu() # TODO: fix this def PRelu(X:Tensor, slope:Tensor): - slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE + slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope return (X > 0).where(X, X * slope) def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha) def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) -def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) -def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) -Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) -def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): # noqa: A002 - return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) +def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() +# ***** Unary Ops (broadcasted) ***** +def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype) +def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int +def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype) +def Less(x:Tensor,y:Tensor): return x < y +def LessOrEqual(x:Tensor,y:Tensor): return x <= y +def Greater(x:Tensor,y:Tensor): return x > y +def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y +def Equal(x:Tensor,y:Tensor): return x == y +def And(x:Tensor,y:Tensor): return (x==y).where(x, False) +def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) +def BitwiseAnd(x:Tensor,y:Tensor): return x & y +def BitwiseOr(x:Tensor,y:Tensor): return x | y +def BitwiseXor(x:Tensor,y:Tensor): return x ^ y +def BitwiseNot(x:Tensor): return ~x + +# ***** Casting Ops ***** +# TODO: saturate +def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) +def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) + +# ***** Reduce Ops ***** +def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) +def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) +def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) +def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) @@ -80,27 +112,27 @@ def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_wit return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() +def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): + if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) +def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): + return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) -def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) -def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) -def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) -def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) - -def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) -def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) -def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) -def Size(data:Tensor): return data.numel() -def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) +# ***** Movement Ops ***** def Reshape(data:Tensor, shape:list[int], allowzero:int=0): return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) +def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) -def And(x:Tensor, y:Tensor): return (x==y).where(x, False) -def Or(x:Tensor, y:Tensor): return (x==y).where(x, True) -def Not(x:Tensor): return x.logical_not() +def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) -def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) +# TODO: add test for when axes is None +def Squeeze(data:Tensor, axes:list[int]|None=None): + return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) +def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) +def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) +def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): axes = axes or list(range(data.ndim)) steps = steps or [1]*data.ndim @@ -113,83 +145,9 @@ def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0) if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] return data.split(split, axis) -# TODO: add test for when axes is None -def Squeeze(data:Tensor, axes:list[int]|None=None): - return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) -def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) - -def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() - -def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): - if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) - return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) -def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) - -def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) -def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) - -def ConstantOfShape(shape:list[int], value:Tensor|None=None): - if value is None: value = Tensor(0, dtype=dtypes.float32) - return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) - -# **************** Complex Ops **************** - -def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): - ret = alpha * (A.transpose(transA) @ B.transpose(transB)) - if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) - return ret - -def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) - -def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): - axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) - if reverse: X = X.flip(axis) - if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ - .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) - return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) - -# TODO: this is copied from tinygrad/nn/__init__.py -# spatial is from opset 7 and has since been removed -def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, - training_mode:int=0, spatial=1, is_test=0): - if training_mode: - x_detached = X.detach() - current_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) - current_var = (y*y).mean(axis=(0,2,3)) - current_invstd = current_var.add(epsilon).rsqrt() - - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) - - return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var - invstd = (input_var + epsilon).rsqrt() - return X.batchnorm(scale, B, input_mean, invstd) - -def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): - axis = tuple(range(2, x.ndim)) - mean = x.mean(axis=axis, keepdim=True) - invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() - return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) - -def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): - assert stash_type == 1, "only float32 is supported" - axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) - mean = x.mean(axis=axes, keepdim=True) - return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() - -def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): - return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) - -# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) -def _onnx_pads_to_tiny_pads(pads): return flatten(reversed([(pB,pA) for pB, pA in zip(pads, pads[len(pads)//2:])])) - -AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] -# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) -def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): - if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] - return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] - +def _onnx_pads_to_tiny_pads(pads): + # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) + return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): value = constant_value or value @@ -198,6 +156,21 @@ def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[ for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) +def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): + shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim + pad_arg:list[None|tuple[int,int]] = [None] * t.ndim + for s, x in zip(shape, axes or range(t.ndim)): + tx = t.shape[x] + if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) + elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) + return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + +# ***** Processing Ops ***** +AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] +def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): + # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) + if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] + return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) @@ -206,25 +179,22 @@ def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - return X.avg_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad) + return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), + ceil_mode=ceil_mode, count_include_pad=count_include_pad) def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, storage_order:int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - ret = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode) + ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode) # tests expect indices with int64 dtype # TODO: if there are repeated values, this is wrong indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape) return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad) - return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=tuple(pads)) + kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): + return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, + padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) -# src: https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose -# another src: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_conv_transpose.py def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, strides:list[int]|int=1): @@ -247,80 +217,28 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shap if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER") return ret.pad(_onnx_pads_to_tiny_pads(pads)) -def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): - return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) -def SpaceToDepth(X:Tensor, blocksize:int): - return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) +def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) +def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) -# Reimplemented here because you need legacy RNG for passing ONNX tests. -def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): - if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. - mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) - return data * mask * (1/(1.0 - ratio)), mask -# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx -def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) -Dropout = {6:Dropout_6, 7:Dropout_7} +def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): + ret = alpha * (A.transpose(transA) @ B.transpose(transB)) + if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) + return ret -def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): - pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) - return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) +def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) -def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): - return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) +def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): + axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) + if reverse: X = X.flip(axis) + if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ + .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) + return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) -def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - return x.nll_loss(target, weight, ignore_index, reduction) - -def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - log_probs = scores.log_softmax(1) - return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs - -def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] - -def Gather(x:Tensor, indices:Tensor, axis:int=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - x_sh = list(x.shape) - ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore - return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] -def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated - -def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): - if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] - x_shape, i_shape = x.shape, indices.shape - b = math.prod(x.shape[dim] for dim in range(batch_dims)) - # NOTE: each batched dim of both input and indices are equal - x = x.reshape(b, *x.shape[batch_dims:]) - indices = indices.reshape(b, *indices.shape[batch_dims:]) - b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) - ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] - return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) -def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): - assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] - x = x.contiguous() - for index, u in zip(indices.split(1, 0), updates.split(1, 0)): - i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) - u = u.squeeze(0) - if reduction == "none": x[i] = u - elif reduction == "add": x[i] += u - elif reduction == "mul": x[i] *= u - else: raise NotImplementedError("reduction doesn't support max or min") - return x - -def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) -def GatherElements(x:Tensor, indices:Tensor, axis:int): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.gather(axis, indices) +def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, - axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, - extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): + axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, + extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): def _apply_nearest_mode(index: Tensor, input_dim, mode: str): if mode == "round_prefer_floor": index = (index - 0.5).ceil() elif mode == "round_prefer_ceil": index = (index + 0.5).floor() @@ -377,116 +295,43 @@ def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, si X = X.gather(i, low).lerp(X.gather(i, high), perc) if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented") return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X - -def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): - shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim - pad_arg:list[None|tuple[int,int]] = [None] * t.ndim - for s, x in zip(shape, axes or range(t.ndim)): - tx = t.shape[x] - if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) - elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) - return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) - -def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): - # Scalar or Rank 1 tensor containing exactly one element - depth = int(depth[0] if isinstance(depth, list) else depth) - indices = (indices < 0).where(indices+depth, indices) - return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) - -def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): - if axis is None: - inp = inp.flatten() - axis = 0 - if axis < 0: axis += inp.ndim - con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor - return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] - -def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) - return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) - def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated -def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): - if axis < 0: axis += x.ndim - if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) - if block_size == 0: - shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) - return scale.reshape(shape), zero_point.reshape(shape) - return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) +# ***** Neural Network Ops ***** +# TODO: try to factor out common implementations for these ops +# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0 +def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, + training_mode:int=0, spatial=1, is_test=0): + if training_mode: + x_detached = X.detach() + current_mean = x_detached.mean(axis=(0,2,3)) + y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) + current_var = (y*y).mean(axis=(0,2,3)) + current_invstd = current_var.add(epsilon).rsqrt() -def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): - out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 - y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) - return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous() - -def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): - x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) - return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) - -def _quantize_linear(y:Tensor, y_scale:Tensor, y_zero_point:Tensor): - assert y_scale.dtype is dtypes.float32 and y_zero_point.dtype in {dtypes.uint8, dtypes.int8}, "used only for qlinear ops" - y = (y / y_scale + y_zero_point).round() - return y.clamp(dtypes.min(y_zero_point.dtype), dtypes.max(y_zero_point.dtype)).cast(y_zero_point.dtype) - -def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point: Tensor|int, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:int|list[int]=1, group:int=1, - kernel_shape:list[int]|None=None, pads:int|list[int]=0, strides:int|list[int]=1): - x = x.int() - x_zero_point - w = w.int() - w_zero_point - y = Conv(x, w, B, auto_pad, dilations, group, kernel_shape, pads, strides) - y_scale = y_scale / (x_scale * w_scale) - return _quantize_linear(y, y_scale, y_zero_point) - -def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point:Tensor|int) -> Tensor: - a = a.int() - a_zero_point - b = b.int() - b_zero_point - y = Tensor.matmul(a, b, acc_dtype=dtypes.int32) - y_scale = y_scale / (a_scale * b_scale) - return _quantize_linear(y, y_scale, y_zero_point) - -def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, - auto_pad: AUTO_PAD_OPTIONS = "NOTSET", dilations: int | list[int] = 1, group: int = 1, kernel_shape: list[int] | None = None, - pads: int | list[int] = 0, strides: int | list[int] = 1) -> Tensor: - x_int = x.int() - x_zero_point - w_int = w.int() - w_zero_point - return Conv(x_int, w_int, B, auto_pad, dilations, group, kernel_shape, pads, strides) - -def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: - A_int = A.int() - a_zero_point - B_int = B.int() - b_zero_point - return Tensor.matmul(A_int, B_int, acc_dtype=dtypes.int32) - -# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py -def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): - try: import PIL.Image - except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e - img = PIL.Image.open(io.BytesIO(encoded_stream)) - if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] - if pixel_format == "RGB": return Tensor(np.array(img)) - if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) - raise ValueError(f"pixel_format={pixel_format!r} is not supported.") - -def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): - N, _, *spatial_dims = size - def generate_grid(steps): - return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) - grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) - base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) - base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) - return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) - -# **************** com.microsoft Ops **************** + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var + invstd = (input_var + epsilon).rsqrt() + return X.batchnorm(scale, B, input_mean, invstd) +def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): + axis = tuple(range(2, x.ndim)) + mean = x.mean(axis=axis, keepdim=True) + invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() + return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) +def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): + assert stash_type == 1, "only float32 is supported" + axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) + mean = x.mean(axis=axes, keepdim=True) + return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() +def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): + return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) +def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): + return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): x = x + skip + bias return x.layernorm(eps=epsilon) * gamma + beta, None, None, x - -def FastGelu(x:Tensor, bias:Tensor|None=None): - # this is tanh approximated - return (x + bias).gelu() if bias is not None else x.gelu() - def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): @@ -513,6 +358,45 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embeddin out = embedding_sum.layernorm(eps=epsilon) * gamma + beta return out, None, embedding_sum +def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): + # Scalar or Rank 1 tensor containing exactly one element + depth = int(depth[0] if isinstance(depth, list) else depth) + indices = (indices < 0).where(indices+depth, indices) + return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) + +def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): + return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) +def SpaceToDepth(X:Tensor, blocksize:int): + return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) + +# Reimplemented here because you need legacy RNG for passing ONNX tests. +def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): + if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. + mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) + return data * mask * (1/(1.0 - ratio)), mask +# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx +def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) +Dropout = {6:Dropout_6, 7:Dropout_7} + +def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): + pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) + return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) + +def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + return x.nll_loss(target, weight, ignore_index, reduction) +def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + log_probs = scores.log_softmax(1) + return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs + +def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): + N, _, *spatial_dims = size + def generate_grid(steps): + return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) + grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) + base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) + return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) + def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None, relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None, mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None, @@ -552,37 +436,132 @@ def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past: out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) return out, present if past is not None else out +# ***** Indexing Ops ***** +def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] + +def Gather(x:Tensor, indices:Tensor, axis:int=0): + if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices + x_sh = list(x.shape) + ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] + if indices.ndim > 1: indices = indices.flatten() + indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] + args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore + return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) + # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot + return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] +def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated + +def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): + if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] + x_shape, i_shape = x.shape, indices.shape + b = math.prod(x.shape[dim] for dim in range(batch_dims)) + # NOTE: each batched dim of both input and indices are equal + x = x.reshape(b, *x.shape[batch_dims:]) + indices = indices.reshape(b, *indices.shape[batch_dims:]) + b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) + ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] + return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) +def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): + assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] + x = x.contiguous() + for index, u in zip(indices.split(1, 0), updates.split(1, 0)): + i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) + u = u.squeeze(0) + if reduction == "none": x[i] = u + elif reduction == "add": x[i] += u + elif reduction == "mul": x[i] *= u + else: raise NotImplementedError("reduction doesn't support max or min") + return x + +def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) +def GatherElements(x:Tensor, indices:Tensor, axis:int): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.gather(axis, indices) + +def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): + if axis is None: + inp = inp.flatten() + axis = 0 + if axis < 0: axis += inp.ndim + con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor + return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] + +# ***** Quantization Ops ***** +def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) + +def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): + if axis < 0: axis += x.ndim + if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) + if block_size == 0: + shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) + return scale.reshape(shape), zero_point.reshape(shape) + return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) + +def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): + out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 + y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) + return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() + +def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): + x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) + return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) + +def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): + adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] + return op(*adjusted_inputs, **opts) + +def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in quantized int + out = _op_integer(op, inputs, zero_points, **opts) + assert dtypes.is_int(out.dtype), "quantized op should've done math in int" + out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + +def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in float32 + dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] + out = op(*dequantized_inputs, **opts) + assert dtypes.is_float(out.dtype), "op should've done math in float" + out_quantized = (out / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + +def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point: Tensor|int, B:Tensor|None=None, **opts): + return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) + +def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point:Tensor|int) -> Tensor: + return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) + def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): - a = a.int() - a_zero_point - b = b.int() - b_zero_point - c = (a * a_scale + b * b_scale) - return _quantize_linear(c, c_scale, c_zero_point) + return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): - assert channels_last in {0, 1} - if channels_last == 1: X = X.permute(0, 2, 3, 1) - X = (X.int() - x_zero_point) * x_scale - y = GlobalAveragePool(X) - return _quantize_linear(y, y_scale, y_zero_point) + assert channels_last == 0, "unsure what this does" + return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) -# **************** ai.onnx.preview.training Ops **************** +def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor: + return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) + +def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: + return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) + +# ***** Training Ops ***** # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code - -from tinygrad.nn.optim import Adam as TinyAdam -from tinygrad.nn.optim import SGD - -def onnx_training(input_group_size): - def _decorator(func): - def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): +def _onnx_training(input_group_size): + def __decorator(func): + def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): R = R.detach() groups = len(inputs) // input_group_size ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] return tuple(flatten(zip(*ret))) - return __wrapper - return _decorator + return ___wrapper + return __decorator -@onnx_training(3) +@_onnx_training(3) def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): X, G, H = (i.detach() for i in inputs) grad = norm_coefficient * X + G @@ -592,9 +571,10 @@ def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:flo X.assign(X.detach() - r * up) return [X, H] -@onnx_training(4) +@_onnx_training(4) def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, - norm_coefficient_post:float=0.0): + norm_coefficient_post:float=0.0): + from tinygrad.nn.optim import Adam as TinyAdam X, G, V, H = inputs G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches X.grad = norm_coefficient * X.detach() + G @@ -610,8 +590,9 @@ def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, eps X = (1 - norm_coefficient_post) * X return [X, V, H] -@onnx_training(3) +@_onnx_training(3) def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): + from tinygrad.nn.optim import SGD X, G, V = inputs G, V = G.detach(), V.detach() X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1) @@ -620,6 +601,6 @@ def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, opt.step() return [X, V] -def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **__): +def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): intermediate_tensors[y].backward() return tuple([t.grad for t in inputs])