mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
fix: pass input device to ONNX helper internal tensors (#16242)
* fix: pass input device to ONNX methods' internal tensors
* test: ONNX helper internal tensors use input device
This commit is contained in:
38
test/external/external_test_onnx_ops.py
vendored
38
test/external/external_test_onnx_ops.py
vendored
@@ -4,7 +4,7 @@
|
||||
|
||||
from typing import Any
|
||||
import unittest, onnx, tempfile
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad import dtypes, Tensor, Context
|
||||
from tinygrad.nn.onnx import OnnxRunner
|
||||
import numpy as np
|
||||
from extra.onnx_helpers import validate
|
||||
@@ -364,6 +364,19 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
|
||||
self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
|
||||
|
||||
def test_same_device_as_input(self):
  """Check that EyeLike, Shape and Compress produce outputs on the device of their input.

  The input tensor is created on "PYTHON" while the surrounding context sets DEV="CPU"
  (presumably the default device -- confirm against tinygrad's Context docs), so each
  assertion fails if an op's output ends up on a device other than the input's.
  """
  from tinygrad.nn.onnx import onnx_ops
  # look the ops up before entering the context
  eye_like_op, shape_op, compress_op = onnx_ops["EyeLike"], onnx_ops["Shape"], onnx_ops["Compress"]
  with Context(DEV="CPU"):
    src = Tensor.arange(4, device="PYTHON").reshape(2,2)

    eye = eye_like_op(src)
    self.assertEqual(eye.device, src.device)

    dims = shape_op(src)
    self.assertEqual(dims.device, src.device)

    # the boolean condition keeps the values 0 and 2 of the input
    picked = compress_op(src, [True, False, True, False])
    self.assertEqual(picked.device, src.device)
    self.assertEqual(picked.tolist(), [0, 2])
|
||||
|
||||
class TestTrainingOnnxOps(TestOnnxOps):
|
||||
# NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
|
||||
DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
|
||||
@@ -581,5 +594,28 @@ class TestContribOnnxOps(TestOnnxOps):
|
||||
outputs = ["C"]
|
||||
self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
|
||||
|
||||
def test_same_device_as_input(self):
  """Check that EmbedLayerNormalization and Attention keep their output on the input device.

  Every input tensor is created on "PYTHON" while the surrounding context sets DEV="CPU"
  (presumably the default device -- confirm against tinygrad's Context docs). Each output
  is also realized, so the computation is actually run rather than only checked for placement.
  """
  from tinygrad.nn.onnx import onnx_ops, OpSetId, Domain
  embed_layer_norm_op = onnx_ops["EmbedLayerNormalization"]
  attention_versions = onnx_ops["Attention"]
  with Context(DEV="CPU"):
    dev = "PYTHON"

    # --- EmbedLayerNormalization ---
    input_ids = Tensor([[1, 2]], device=dev, dtype=dtypes.int32)
    segment_ids = Tensor([[0, 0]], device=dev, dtype=dtypes.int32)
    word_emb = Tensor.ones(4, 3, device=dev)
    pos_emb = Tensor.ones(5, 3, device=dev)
    seg_emb = Tensor.ones(1, 3, device=dev)
    gamma = Tensor.ones(3, device=dev)
    beta = Tensor.zeros(3, device=dev)
    embedded, _, _ = embed_layer_norm_op(input_ids, segment_ids, word_emb, pos_emb, seg_emb, gamma, beta)
    self.assertEqual(embedded.device, input_ids.device)
    embedded.realize()

    # --- Attention (Microsoft contrib domain, opset 1) ---
    attention_op = attention_versions[OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1)]
    hidden = Tensor.ones(1, 2, 4, device=dev)
    weights = Tensor.ones(4, 12, device=dev)
    mask_index = Tensor([2, 0], device=dev, dtype=dtypes.int32)
    attended, _ = attention_op(hidden, weights, mask_index=mask_index, num_heads=1, unidirectional=1)
    self.assertEqual(attended.device, hidden.device)
    attended.realize()
|
||||
|
||||
# Run the tests with unittest's command-line runner when this file is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user