mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
cleanup test_conv2d_ceildiv_edge_case [pr] (#10317)
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import unittest
|
||||
from tinygrad import Variable
|
||||
from tinygrad import Tensor, Variable
|
||||
from tinygrad.shape.shapetracker import View
|
||||
from tinygrad.helpers import Context, GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.ops import sym_infer
|
||||
from examples.gpt2 import Attention
|
||||
import numpy as np
|
||||
|
||||
@@ -224,16 +223,16 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_conv2d_ceildiv_edge_case(self):
|
||||
def eval_uops(a): return a.sym_infer(dict(v.unbind() for v in a.vars()))
|
||||
|
||||
v = Variable('qwe', 11, 50_000).bind(39601)
|
||||
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v)
|
||||
v = Variable('v', 11, 50_000)
|
||||
val = 39601
|
||||
x = Tensor.randn(1, 22, 39601).reshape(1, 22, v.bind(val))
|
||||
weight = Tensor.randn(256, 22, 12)
|
||||
|
||||
result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
|
||||
shape = tuple(eval_uops(i) if isinstance(i, UOp) else i for i in result.shape)
|
||||
self.assertEqual(shape, (1, 256, 6600)) # fails if ceildiv is incorrect
|
||||
|
||||
var_val = {v: val}
|
||||
shape = tuple(sym_infer(s, var_val) for s in result.shape)
|
||||
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
|
||||
# TODO: test output is correct
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user