From 7149eabb345de4713b79c8a42ee9e46290f0108f Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 29 Oct 2024 13:29:29 +0200 Subject: [PATCH] assert set equality in TestTensorMetadata [pr] (#7364) --- test/test_tensor.py | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 51e8002002..80cca781c0 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -709,41 +709,36 @@ class TestInferenceMode(unittest.TestCase): f(x, m, W) class TestTensorMetadata(unittest.TestCase): + def setUp(self) -> None: _METADATA.set(None) def test_matmul(self): - _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) W = Tensor.rand(3, 3, requires_grad=True) out = x.matmul(W) self.assertEqual(out.lazydata.metadata.name, "matmul") - s = create_schedule([out.lazydata]) - self.assertEqual(len(s[-1].metadata), 1) - self.assertEqual(s[-1].metadata[0].name, "matmul") + si = create_schedule([out.lazydata])[-1] + self.assertEqual(len(si.metadata), 1) + self.assertEqual(si.metadata[0].name, "matmul") def test_relu(self): - _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) out = x.relu() self.assertEqual(out.lazydata.metadata.name, "relu") - s = create_schedule([out.lazydata]) - self.assertEqual(len(s[-1].metadata), 1) - self.assertEqual(s[-1].metadata[0].name, "relu") + si = create_schedule([out.lazydata])[-1] + self.assertEqual(len(si.metadata), 1) + self.assertEqual(si.metadata[0].name, "relu") def test_complex(self): - _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) y = Tensor.rand(3, requires_grad=True) out = x.relu() * y.sigmoid() self.assertEqual(out.lazydata.metadata.name, "__mul__") self.assertEqual(out.lazydata.srcs[0].metadata.name, "relu") self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid") - s = create_schedule([out.lazydata]) - self.assertEqual(len(s[-1].metadata), 3) - self.assertEqual(s[-1].metadata[0].name, "relu") - self.assertEqual(s[-1].metadata[1].name, "sigmoid") - self.assertEqual(s[-1].metadata[2].name, "__mul__") + si = create_schedule([out.lazydata])[-1] + self.assertEqual(len(si.metadata), 3) + self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"}) def test_complex_backward(self): - _METADATA.set(None) x = Tensor.rand(3, requires_grad=True) y = Tensor.rand(3, requires_grad=True) out = (x.relu() * y.sigmoid()).sum() @@ -753,12 +748,12 @@ class TestTensorMetadata(unittest.TestCase): self.assertTrue(x.grad.lazydata.metadata.backward) self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid") self.assertTrue(y.grad.lazydata.metadata.backward) - s = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata]) - self.assertEqual(len(s[-1].metadata), 3) - self.assertEqual(s[-1].metadata[0].name, "sigmoid") - self.assertEqual(s[-1].metadata[1].name, "sigmoid") - self.assertTrue(s[-1].metadata[1].backward) - self.assertEqual(s[-1].metadata[2].name, "relu") + si = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])[-1] + self.assertEqual(len(si.metadata), 3) + self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"}) + bw = [m for m in si.metadata if m.backward] + self.assertEqual(len(bw), 1) + self.assertEqual(bw[0].name, "sigmoid") if __name__ == '__main__': unittest.main()