fix: pass input device to ONNX helper internal tensors (#16242)

* fix: pass input device to onnx methods internal tensors

* test: onnx helper internal tensors use input device
This commit is contained in:
Sachith Shetty
2026-05-19 11:16:33 -07:00
committed by GitHub
parent a178301dbe
commit 74567c1958
2 changed files with 43 additions and 7 deletions

View File

@@ -4,7 +4,7 @@
from typing import Any
import unittest, onnx, tempfile
from tinygrad import dtypes, Tensor
from tinygrad import dtypes, Tensor, Context
from tinygrad.nn.onnx import OnnxRunner
import numpy as np
from extra.onnx_helpers import validate
@@ -364,6 +364,19 @@ class TestMainOnnxOps(TestOnnxOps):
inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
def test_same_device_as_input(self):
from tinygrad.nn.onnx import onnx_ops
EyeLike = onnx_ops["EyeLike"]
Shape = onnx_ops["Shape"]
Compress = onnx_ops["Compress"]
with Context(DEV="CPU"):
x = Tensor.arange(4, device="PYTHON").reshape(2,2)
self.assertEqual(EyeLike(x).device, x.device)
self.assertEqual(Shape(x).device, x.device)
out = Compress(x, [True, False, True, False])
self.assertEqual(out.device, x.device)
self.assertEqual(out.tolist(), [0, 2])
class TestTrainingOnnxOps(TestOnnxOps):
# NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
@@ -581,5 +594,28 @@ class TestContribOnnxOps(TestOnnxOps):
outputs = ["C"]
self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
def test_same_device_as_input(self):
from tinygrad.nn.onnx import onnx_ops, OpSetId, Domain
EmbedLayerNormalization = onnx_ops["EmbedLayerNormalization"]
Attention = onnx_ops["Attention"]
with Context(DEV="CPU"):
input_ids = Tensor([[1, 2]], device="PYTHON", dtype=dtypes.int32)
segment_ids = Tensor([[0, 0]], device="PYTHON", dtype=dtypes.int32)
word = Tensor.ones(4, 3, device="PYTHON")
pos = Tensor.ones(5, 3, device="PYTHON")
seg = Tensor.ones(1, 3, device="PYTHON")
gamma, beta = Tensor.ones(3, device="PYTHON"), Tensor.zeros(3, device="PYTHON")
out, _, _ = EmbedLayerNormalization(input_ids, segment_ids, word, pos, seg, gamma, beta)
self.assertEqual(out.device, input_ids.device)
out.realize()
attn = Attention[OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1)]
x = Tensor.ones(1, 2, 4, device="PYTHON")
w = Tensor.ones(4, 12, device="PYTHON")
mask = Tensor([2, 0], device="PYTHON", dtype=dtypes.int32)
out, _ = attn(x, w, mask_index=mask, num_heads=1, unidirectional=1)
self.assertEqual(out.device, x.device)
out.realize()
if __name__ == "__main__":
unittest.main()