diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py
index 01d874cc5d..6e05445dba 100644
--- a/test/test_softmax_fusion.py
+++ b/test/test_softmax_fusion.py
@@ -86,6 +86,18 @@ class TestFuse(unittest.TestCase):
       return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
     self._test_fuse(embedding, a, atol=1e-5)
 
+  def test_attention_kernel_count(self):
+    wq = Tensor.empty(32, 32)
+    wk = Tensor.empty(32, 32)
+    wv = Tensor.empty(32, 32)
+    x = Tensor.empty(2, 100, 32)
+    q = (x @ wq).contiguous()
+    k = (x @ wk).contiguous()
+    v = (x @ wv).contiguous()
+    attn = q.scaled_dot_product_attention(k, v).fuse()
+    s = attn.schedule()
+    self.assertEqual(len(s), 4) # 3 matmul and 1 attention
+
   def test_flash_attention(self):
     BS = 4
     HEADS = 2
diff --git a/tinygrad/schedule/kernelize.py b/tinygrad/schedule/kernelize.py
index fe70bd9540..0013f827ce 100644
--- a/tinygrad/schedule/kernelize.py
+++ b/tinygrad/schedule/kernelize.py
@@ -345,7 +345,7 @@ pm_fuse = PatternMatcher([
 def do_fusion(x:UOp):
   found_contiguous = {}
   def gate_contiguous(x):
-    if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st),))
+    if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
     return not is_contiguous
   x.toposort(gate=gate_contiguous)
   del gate_contiguous