mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
add test_embedding to test_softmax_fusion (#9832)
This commit is contained in:
@@ -61,6 +61,20 @@ class TestFuse(unittest.TestCase):
|
||||
c = (Tensor.rand(N,N)-0.5).realize()
|
||||
self._test_fuse(lambda a,b,c: a@b@c, a, b, c, atol=1e-5)
|
||||
|
||||
def test_embedding(self):
|
||||
with Context(TRACK_MATCH_STATS=0, DEBUG=0):
|
||||
vocab_sz = 123
|
||||
embed_sz = 16
|
||||
weight = (Tensor.rand(vocab_sz, embed_sz)-0.5).realize()
|
||||
a = Tensor([1, 1, 2, 3]).realize()
|
||||
def embedding(idx:Tensor):
|
||||
arange = Tensor.arange(vocab_sz).unsqueeze(-1)
|
||||
big_shp = idx.shape + (vocab_sz, embed_sz)
|
||||
arange, vals = arange.expand(big_shp), weight.expand(big_shp)
|
||||
idx = idx.reshape(idx.shape+(1, 1)).expand(big_shp)
|
||||
return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
|
||||
self._test_fuse(embedding, a, atol=1e-5)
|
||||
|
||||
@unittest.skip("still broken")
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
|
||||
Reference in New Issue
Block a user