mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
14
test/external/external_test_onnx_ops.py
vendored
14
test/external/external_test_onnx_ops.py
vendored
@@ -282,11 +282,11 @@ class TestTrainingOnnxOps(TestOnnxOps):
|
||||
tiny_out = runner(inps)
|
||||
onnx_out = onnx_fxn(**inps, **opts)
|
||||
for (nm, t_out), o_out in zip(tiny_out.items(), onnx_out):
|
||||
np.testing.assert_allclose(t_out.numpy(), o_out, rtol=1e-3, atol=1e-6, err_msg=f"{nm} failed")
|
||||
np.testing.assert_allclose(t_out.numpy(), o_out, rtol=1e-6, atol=1e-6, err_msg=f"{nm} failed")
|
||||
|
||||
def test_adagrad_t_greater_than_zero(self):
|
||||
def test_adagrad_t(self):
|
||||
from onnx.backend.test.case.node.adagrad import apply_adagrad
|
||||
for t in [1, 3, 100]:
|
||||
for t in [0, 1, 3, 100]:
|
||||
inputs = {
|
||||
"r": np.array(0.01, dtype=np.float32),
|
||||
"t": np.array(t, dtype=np.int32),
|
||||
@@ -298,10 +298,10 @@ class TestTrainingOnnxOps(TestOnnxOps):
|
||||
outputs = ["X_out", "H_out"]
|
||||
self._validate_training("Adagrad", apply_adagrad, inputs, attributes, outputs)
|
||||
|
||||
def test_momentum_t_greater_than_zero(self):
|
||||
def test_momentum(self):
|
||||
from onnx.backend.test.case.node.momentum import apply_momentum, apply_nesterov
|
||||
for onnx_fxn, mode in ((apply_momentum, "standard"), (apply_nesterov, "nesterov")):
|
||||
for t in [1, 3, 100]:
|
||||
for t in [0, 1, 3, 100]:
|
||||
inputs = {
|
||||
"r": np.array(0.01, dtype=np.float32),
|
||||
"t": np.array(t, dtype=np.int32),
|
||||
@@ -313,9 +313,9 @@ class TestTrainingOnnxOps(TestOnnxOps):
|
||||
outputs = ["X_out", "V_out"]
|
||||
self._validate_training("Momentum", onnx_fxn, inputs, attributes, outputs)
|
||||
|
||||
def test_adam_t_greater_than_zero(self):
|
||||
def test_adam(self):
|
||||
from onnx.backend.test.case.node.adam import apply_adam
|
||||
for t in [1, 3, 100]:
|
||||
for t in [0, 1, 3, 100]:
|
||||
inputs = {
|
||||
"r": np.array(0.01, dtype=np.float32),
|
||||
"t": np.array(t, dtype=np.int32),
|
||||
|
||||
Reference in New Issue
Block a user