From 4c9a930de286f256f46065be4b07eb3e4de9a2ca Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:59:19 +0800 Subject: [PATCH] rangeify attn tests (#12377) --- test/unit/test_attention.py | 7 +++++-- tinygrad/apps/llm.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test/unit/test_attention.py b/test/unit/test_attention.py index e6a3c487bd..5043f7335a 100644 --- a/test/unit/test_attention.py +++ b/test/unit/test_attention.py @@ -2,9 +2,11 @@ import unittest from tinygrad import Tensor, dtypes, TinyJit, UOp from tinygrad.helpers import RANGEIFY from tinygrad.apps.llm import apply_rope +#from tinygrad.engine.realize import run_schedule # TODO: test_scheduler, but just in uint class TestAttention(unittest.TestCase): + @unittest.skipIf(RANGEIFY > 0, "not half on rangeify") def test_half_qkv_buffers(self): BS, seqlen, dim = 10, 4, 100 q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize() @@ -12,11 +14,12 @@ class TestAttention(unittest.TestCase): v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize() attn = q.scaled_dot_product_attention(k, v) sched = attn.schedule() + #run_schedule(sched[:]) # attention has 5 kernels now self.assertEqual(len(sched), 4 if RANGEIFY else 5) softmax_inputs = sched[1:4] - for si in softmax_inputs: - assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}" + for i,si in enumerate(softmax_inputs): + assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=} in kernel {i}" def test_apply_rope(self): x = Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32) diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index 62801117bd..50649ffe4e 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -58,7 +58,8 @@ def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor: assert (Hd & 1) == 0, "RoPE requires an even head dimension" half = Hd // 2 angles = (Tensor.arange(T, dtype="float32") + start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :] - cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype), angles.sin().reshape(1, 1, T, half).cast(x.dtype) + # contiguous here allows RoPE to be pruned in the JIT + cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous() x_pairs = x.reshape(B, H, T, half, 2) return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin, x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)