cleanup test_conv2d_ceildiv_edge_case [pr] (#10317)

This commit is contained in:
chenyu
2025-05-15 11:35:28 +08:00
committed by GitHub
parent 50d7162acd
commit f6cf25fce4

View File

@@ -1,9 +1,8 @@
import unittest
from tinygrad import Variable
from tinygrad import Tensor, Variable
from tinygrad.shape.shapetracker import View
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.ops import UOp
from tinygrad.ops import sym_infer
from examples.gpt2 import Attention
import numpy as np
@@ -224,16 +223,16 @@ class TestSymbolicOps(unittest.TestCase):
@unittest.expectedFailure
def test_conv2d_ceildiv_edge_case(self):
def eval_uops(a): return a.sym_infer(dict(v.unbind() for v in a.vars()))
v = Variable('qwe', 11, 50_000).bind(39601)
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v)
v = Variable('v', 11, 50_000)
val = 39601
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val))
weight = Tensor.randn(256, 22, 12)
result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
shape = tuple(eval_uops(i) if isinstance(i, UOp) else i for i in result.shape)
self.assertEqual(shape, (1, 256, 6600)) # fails if ceildiv is incorrect
var_val = {v: val}
shape = tuple(sym_infer(s, var_val) for s in result.shape)
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
# TODO: test output is correct
if __name__ == '__main__':
unittest.main()