From 7fa5f29582ac2012be63014eea8741a863bfa7b6 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 10 Apr 2025 08:25:34 -0400 Subject: [PATCH] add test_embedding to test_softmax_fusion (#9832) --- test/test_softmax_fusion.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index 9d8f264714..49dbac7aa4 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -61,6 +61,20 @@ class TestFuse(unittest.TestCase): c = (Tensor.rand(N,N)-0.5).realize() self._test_fuse(lambda a,b,c: a@b@c, a, b, c, atol=1e-5) + def test_embedding(self): + with Context(TRACK_MATCH_STATS=0, DEBUG=0): + vocab_sz = 123 + embed_sz = 16 + weight = (Tensor.rand(vocab_sz, embed_sz)-0.5).realize() + a = Tensor([1, 1, 2, 3]).realize() + def embedding(idx:Tensor): + arange = Tensor.arange(vocab_sz).unsqueeze(-1) + big_shp = idx.shape + (vocab_sz, embed_sz) + arange, vals = arange.expand(big_shp), weight.expand(big_shp) + idx = idx.reshape(idx.shape+(1, 1)).expand(big_shp) + return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype) + self._test_fuse(embedding, a, atol=1e-5) + @unittest.skip("still broken") def test_flash_attention(self): BS = 4