From 63a23dfe80da09e9d27c0e1629aba33289e995c0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 19 Oct 2025 09:15:49 -0400 Subject: [PATCH] test step 0 in TestTrainingOnnxOps (#12790) and tighter rtol --- test/external/external_test_onnx_ops.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index 0abb334c57..3e1cc9503f 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -282,11 +282,11 @@ class TestTrainingOnnxOps(TestOnnxOps): tiny_out = runner(inps) onnx_out = onnx_fxn(**inps, **opts) for (nm, t_out), o_out in zip(tiny_out.items(), onnx_out): - np.testing.assert_allclose(t_out.numpy(), o_out, rtol=1e-3, atol=1e-6, err_msg=f"{nm} failed") + np.testing.assert_allclose(t_out.numpy(), o_out, rtol=1e-6, atol=1e-6, err_msg=f"{nm} failed") - def test_adagrad_t_greater_than_zero(self): + def test_adagrad_t(self): from onnx.backend.test.case.node.adagrad import apply_adagrad - for t in [1, 3, 100]: + for t in [0, 1, 3, 100]: inputs = { "r": np.array(0.01, dtype=np.float32), "t": np.array(t, dtype=np.int32), @@ -298,10 +298,10 @@ class TestTrainingOnnxOps(TestOnnxOps): outputs = ["X_out", "H_out"] self._validate_training("Adagrad", apply_adagrad, inputs, attributes, outputs) - def test_momentum_t_greater_than_zero(self): + def test_momentum(self): from onnx.backend.test.case.node.momentum import apply_momentum, apply_nesterov for onnx_fxn, mode in ((apply_momentum, "standard"), (apply_nesterov, "nesterov")): - for t in [1, 3, 100]: + for t in [0, 1, 3, 100]: inputs = { "r": np.array(0.01, dtype=np.float32), "t": np.array(t, dtype=np.int32), @@ -313,9 +313,9 @@ class TestTrainingOnnxOps(TestOnnxOps): outputs = ["X_out", "V_out"] self._validate_training("Momentum", onnx_fxn, inputs, attributes, outputs) - def test_adam_t_greater_than_zero(self): + def test_adam(self): from onnx.backend.test.case.node.adam import apply_adam - for t in [1, 3, 100]: + for t in [0, 1, 3, 100]: inputs = { "r": np.array(0.01, dtype=np.float32), "t": np.array(t, dtype=np.int32),