From 152ef7fc795856027871b3dee2c1e63f69b2ef57 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 14 Jan 2024 02:15:24 -0500 Subject: [PATCH] minor cleanups of onnx_ops (#3116) --- extra/onnx_ops.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index cf028d2564..c4ea89a5a8 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -332,7 +332,7 @@ def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None): if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode) if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. rng = np.random.RandomState(seed) - ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio + ratio = ratio.item() if isinstance(ratio, Tensor) else ratio mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device) return data * mask * (1/(1.0 - ratio)), mask @@ -356,7 +356,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind weight = (mask * weight).sum(axis=-1) if ignore_index is not None: cond = target == ignore_index - weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1) + weight = cond.where(0, weight) if weight is not None else cond.where(0, 1) mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2)) loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight) if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum() @@ -368,14 +368,17 @@ def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels) mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions)) y = scores.log_softmax(axis=1) - if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)])) - loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights + loss = (mask * -y).sum(1) + if weights is not None: + weights = weights[labels, ...] + loss = loss * weights if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum() elif reduction == "sum": loss = loss.sum() return loss, y def ArrayFeatureExtractor(x: Tensor, indices: Tensor): - return x.__getitem__(tuple([slice(None) if i != (x.ndim-1) else indices for i in range(x.ndim)])) + return x[tuple([slice(None) if i != (x.ndim-1) else indices for i in range(x.ndim)])] + def Gather(x: Tensor, indices: Tensor, axis=0): if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices x_sh = list(x.shape) @@ -385,7 +388,7 @@ def Gather(x: Tensor, indices: Tensor, axis=0): args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return x.__getitem__(tuple([slice(None) if i != axis else indices for i in range(x.ndim)])) + return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] def GatherElements(x: Tensor, indices: Tensor, axis): indices = (indices < 0).where(x.shape[axis], 0) + indices @@ -541,7 +544,7 @@ def Compress(inp: Tensor, condition: Tensor, axis=None): con_np = safe_numpy(condition) con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor - return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)])) + return inp[tuple([slice(None) if i != axis else con for i in range(inp.ndim)])] def EyeLike(x: Tensor, dtype=None, k=0): if dtype is None: dtype = x.dtype