add test_embedding to test_softmax_fusion (#9832)

This commit is contained in:
chenyu
2025-04-10 08:25:34 -04:00
committed by GitHub
parent 995d20673a
commit 7fa5f29582

View File

@@ -61,6 +61,20 @@ class TestFuse(unittest.TestCase):
c = (Tensor.rand(N,N)-0.5).realize()
self._test_fuse(lambda a,b,c: a@b@c, a, b, c, atol=1e-5)
def test_embedding(self):
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
vocab_sz = 123
embed_sz = 16
weight = (Tensor.rand(vocab_sz, embed_sz)-0.5).realize()
a = Tensor([1, 1, 2, 3]).realize()
def embedding(idx:Tensor):
arange = Tensor.arange(vocab_sz).unsqueeze(-1)
big_shp = idx.shape + (vocab_sz, embed_sz)
arange, vals = arange.expand(big_shp), weight.expand(big_shp)
idx = idx.reshape(idx.shape+(1, 1)).expand(big_shp)
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
self._test_fuse(embedding, a, atol=1e-5)
@unittest.skip("still broken")
def test_flash_attention(self):
BS = 4