diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py index 005d978902..b15cdf12ce 100644 --- a/test/test_quantize_onnx.py +++ b/test/test_quantize_onnx.py @@ -3,6 +3,7 @@ import numpy as np import unittest from dataclasses import replace from tinygrad import Tensor, Context, Device, dtypes +from tinygrad.helpers import RANGEIFY from tinygrad.uop.ops import Ops from tinygrad.codegen.opt import Opt, OptOps from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item, get_program @@ -93,7 +94,8 @@ class TestQuantizeOnnx(unittest.TestCase): X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8)) W = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8)) out = X.conv2d(W, dtype=X.dtype) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] + # rangeify merges axis in a different order + opts = [Opt(op=OptOps.UPCAST, axis=0 if RANGEIFY else 1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] sexec(out, opts) def test_prequant_gemm(self):