diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index c78ad76a82..766c4c901b 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -1,9 +1,8 @@ import unittest -from tinygrad import Variable +from tinygrad import Tensor, Variable from tinygrad.shape.shapetracker import View from tinygrad.helpers import Context, GlobalCounters -from tinygrad.tensor import Tensor -from tinygrad.ops import UOp +from tinygrad.ops import sym_infer from examples.gpt2 import Attention import numpy as np @@ -224,16 +223,16 @@ class TestSymbolicOps(unittest.TestCase): @unittest.expectedFailure def test_conv2d_ceildiv_edge_case(self): - def eval_uops(a): return a.sym_infer(dict(v.unbind() for v in a.vars())) - - v = Variable('qwe', 11, 50_000).bind(39601) - x = Tensor.randn(1, 22, 39601).reshape(1, 22, v) + v = Variable('v', 11, 50_000) + val = 39601 + x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val)) weight = Tensor.randn(256, 22, 12) result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3)) - shape = tuple(eval_uops(i) if isinstance(i, UOp) else i for i in result.shape) - self.assertEqual(shape, (1, 256, 6600)) # fails if ceildiv is incorrect - + var_val = {v: val} + shape = tuple(sym_infer(s, var_val) for s in result.shape) + self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect + # TODO: test output is correct if __name__ == '__main__': unittest.main()