From 74567c1958e9d06d5cc3243da7e9444dc201feff Mon Sep 17 00:00:00 2001 From: Sachith Shetty Date: Tue, 19 May 2026 11:16:33 -0700 Subject: [PATCH] fix: pass input device to ONNX helper internal tensors (#16242) * fix: pass input device to onnx methods internal tensors * test: onnx helper internal tensors use input device --- test/external/external_test_onnx_ops.py | 38 ++++++++++++++++++++++++- tinygrad/nn/onnx.py | 12 ++++---- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index 9f0f39912f..1156960d4d 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -4,7 +4,7 @@ from typing import Any import unittest, onnx, tempfile -from tinygrad import dtypes, Tensor +from tinygrad import dtypes, Tensor, Context from tinygrad.nn.onnx import OnnxRunner import numpy as np from extra.onnx_helpers import validate @@ -364,6 +364,19 @@ class TestMainOnnxOps(TestOnnxOps): inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100} self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"]) + def test_same_device_as_input(self): + from tinygrad.nn.onnx import onnx_ops + EyeLike = onnx_ops["EyeLike"] + Shape = onnx_ops["Shape"] + Compress = onnx_ops["Compress"] + with Context(DEV="CPU"): + x = Tensor.arange(4, device="PYTHON").reshape(2,2) + self.assertEqual(EyeLike(x).device, x.device) + self.assertEqual(Shape(x).device, x.device) + out = Compress(x, [True, False, True, False]) + self.assertEqual(out.device, x.device) + self.assertEqual(out.tolist(), [0, 2]) + class TestTrainingOnnxOps(TestOnnxOps): # NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN @@ -581,5 +594,28 @@ class TestContribOnnxOps(TestOnnxOps): outputs = ["C"] self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs) + def test_same_device_as_input(self): + from tinygrad.nn.onnx import onnx_ops, OpSetId, Domain + EmbedLayerNormalization = onnx_ops["EmbedLayerNormalization"] + Attention = onnx_ops["Attention"] + with Context(DEV="CPU"): + input_ids = Tensor([[1, 2]], device="PYTHON", dtype=dtypes.int32) + segment_ids = Tensor([[0, 0]], device="PYTHON", dtype=dtypes.int32) + word = Tensor.ones(4, 3, device="PYTHON") + pos = Tensor.ones(5, 3, device="PYTHON") + seg = Tensor.ones(1, 3, device="PYTHON") + gamma, beta = Tensor.ones(3, device="PYTHON"), Tensor.zeros(3, device="PYTHON") + out, _, _ = EmbedLayerNormalization(input_ids, segment_ids, word, pos, seg, gamma, beta) + self.assertEqual(out.device, input_ids.device) + out.realize() + + attn = Attention[OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1)] + x = Tensor.ones(1, 2, 4, device="PYTHON") + w = Tensor.ones(4, 12, device="PYTHON") + mask = Tensor([2, 0], device="PYTHON", dtype=dtypes.int32) + out, _ = attn(x, w, mask_index=mask, num_heads=1, unidirectional=1) + self.assertEqual(out.device, x.device) + out.realize() + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 0f2db2279a..2bb71c09ce 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -586,7 +586,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT raise ValueError(f"pixel_format={pixel_format!r} is not supported.") def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=OnnxDataType(dtype).to_dtype() if dtype is not None else x.dtype) + ret = Tensor.eye(cast(int, min(x.shape)), dtype=OnnxDataType(dtype).to_dtype() if dtype is not None else x.dtype, device=x.device) return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) @@ -597,7 +597,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT return value.expand(shape) def Size(data:Tensor): return data.numel() - def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) + def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64, device=data.device) # ***** Unary Ops (math) ***** def Not(x:Tensor): return x.logical_not() @@ -934,7 +934,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight # bert embedding layer - if position_ids is None: position_ids = Tensor.arange(seq_length).unsqueeze(0).expand(*input_shape) + if position_ids is None: position_ids = Tensor.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(*input_shape) wrd_embedding_res = embedding(input_ids, vocab_size, word_embedding) pos_embedding_res = embedding(position_ids, max_position_embeddings, position_embedding) @@ -1036,14 +1036,14 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT elif mask_index.shape[0] == 2*batch_size: end_positions = mask_index[:batch_size] start_positions = mask_index[batch_size:] - arange = Tensor.arange(seq_len).unsqueeze(0) + arange = Tensor.arange(seq_len, device=mask_index.device).unsqueeze(0) mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1)) else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented") while mask.ndim < 4: mask = mask.unsqueeze(1) attn_scores = mask.where(attn_scores, mask_filter_value) if unidirectional: - causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril() + causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool, device=attn_scores.device).tril() attn_scores = causal_mask.where(attn_scores, mask_filter_value) output = attn_scores.softmax(-1) @ v @@ -1199,7 +1199,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT inp = inp.flatten() axis = 0 axis = inp._resolve_dim(axis) - con = Tensor([i for i,cond in enumerate(condition) if cond]) # compress in python + con = Tensor([i for i,cond in enumerate(condition) if cond], device=inp.device) # compress in python return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] # ***** Quantization Ops *****