From b268755d51d73543ee7bb7586bbae847ff39cd5c Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 26 Aug 2025 11:56:16 -0700
Subject: [PATCH] small changes from postopt (#11854)

---
NOTE(review): this copy of the patch was recovered from a version whose
line breaks and indentation had been collapsed. Line structure was rebuilt
from the hunk headers and the diffstat (50 insertions, 26 deletions);
indentation is inferred (2-space Python) and should be confirmed against
the upstream commit before applying.

NOTE(review): in tinygrad/codegen/late/expander.py the unchanged context
comment still reads "for any range args of STORE/REDUCE", while the
condition above it also covers BUFFERIZE and, with this change, WMMA.
In tinygrad/uop/ops.py, "self.op in range_start.keys()" is equivalent to
"self.op in range_start".

 test/test_rangeify.py             | 34 +++++++++++++++++++++++--------
 test/test_tiny.py                 |  5 ++++-
 tinygrad/codegen/__init__.py      |  4 ++--
 tinygrad/codegen/late/expander.py |  2 +-
 tinygrad/codegen/opt/kernel.py    |  2 +-
 tinygrad/runtime/ops_null.py      |  1 +
 tinygrad/schedule/rangeify.py     | 14 ++++++++++---
 tinygrad/uop/ops.py               | 14 +++++--------
 8 files changed, 50 insertions(+), 26 deletions(-)

diff --git a/test/test_rangeify.py b/test/test_rangeify.py
index 88c388095e..3daf41f14b 100644
--- a/test/test_rangeify.py
+++ b/test/test_rangeify.py
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor
-from tinygrad.helpers import RANGEIFY
+from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
 
 N = 256
 
@@ -96,14 +96,30 @@ class TestRangeify(unittest.TestCase):
     out.realize()
 
   def test_flash_attention(self):
-    BS = 4
-    HEADS = 2
-    MATDIM = 16
-    EMB = 8
-    q = Tensor.empty(BS, HEADS, MATDIM, EMB)
-    k = Tensor.empty(BS, HEADS, MATDIM, EMB)
-    v = Tensor.empty(BS, HEADS, MATDIM, EMB)
-    q.scaled_dot_product_attention(k, v).realize()
+    BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
+
+    # bigger
+    #BS, HEADS, SEQLEN, EMB = 4, 16, 128, 64
+
+    # llama 8B
+    #BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
+
+    def fa():
+      Tensor.manual_seed(1337)
+      with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
+      return q.scaled_dot_product_attention(k, v).realize()
+
+    with Context(DEBUG=4):
+      GlobalCounters.reset()
+      ret = fa()
+    with Context(RANGEIFY=0):
+      with Context(DEBUG=2):
+        GlobalCounters.reset()
+        cmp = fa()
+    with Context(DEBUG=0):
+      mse = ((cmp-ret)**2).sum().item()
+    print(f"mse: {mse}")
+    self.assertLessEqual(mse, 1e-6)
 
 from tinygrad import dtypes
 from tinygrad.uop.ops import UOp
diff --git a/test/test_tiny.py b/test/test_tiny.py
index e86a726944..bc133a0dcf 100644
--- a/test/test_tiny.py
+++ b/test/test_tiny.py
@@ -30,7 +30,10 @@ class TestTiny(unittest.TestCase):
   def test_gemm(self, N=64, out_dtype=dtypes.float):
     a = Tensor.ones(N,N).contiguous()
     b = Tensor.eye(N).contiguous()
-    self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
+    lst = (out:=a@b).tolist()
+    for y in range(N):
+      for x in range(N):
+        self.assertEqual(lst[y][x], 1.0, msg=f"mismatch at ({y},{x})")
     if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
 
   # *** randomness ***
diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py
index c90c9b8000..88005c97b2 100644
--- a/tinygrad/codegen/__init__.py
+++ b/tinygrad/codegen/__init__.py
@@ -18,7 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
 from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
 from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
 from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
-from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
+from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
 
 @dataclass
 class RewriteStep:
@@ -72,7 +72,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
   ret.append(RewriteStep(sym+expander, name="expander"))
 
   # add locals
-  ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
+  ret.append(RewriteStep(pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
 
   # ** devectorizer (full_graph_rewrite) **
   # remove reduce
diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py
index 46fb5d990e..bd88548d87 100644
--- a/tinygrad/codegen/late/expander.py
+++ b/tinygrad/codegen/late/expander.py
@@ -50,7 +50,7 @@ def do_expand(root:UOp):
       if root.op is Ops.IF or src.op is Ops.IF:
         # for the first arg of IF, just pass them through ignoring UNROLLS
         new_srcs.append(src)
-      elif (root.op is Ops.STORE and i >= 2) or (root.op in {Ops.REDUCE, Ops.BUFFERIZE} and i >= 1):
+      elif (root.op is Ops.STORE and i >= 2) or (root.op in {Ops.REDUCE, Ops.BUFFERIZE} and i >= 1) or (root.op is Ops.WMMA and i >= 3):
         # for any range args of STORE/REDUCE, pass them through
         new_srcs.append(src)
       elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
diff --git a/tinygrad/codegen/opt/kernel.py b/tinygrad/codegen/opt/kernel.py
index 80e792db60..2d3af8d3d1 100644
--- a/tinygrad/codegen/opt/kernel.py
+++ b/tinygrad/codegen/opt/kernel.py
@@ -245,7 +245,7 @@ class Kernel:
       if axis is None: return -1
       if op is OptOps.UNROLL: return self.unrollable_dims[axis]
       if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
-      check(axis < self.shape_len, "invalid axis")
+      check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
       return axis
     except IndexError as e: raise KernelOptError from e
 
diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py
index 10ae40f4ed..8223e5ddee 100644
--- a/tinygrad/runtime/ops_null.py
+++ b/tinygrad/runtime/ops_null.py
@@ -7,6 +7,7 @@ class NullRenderer(CStyleLanguage):
   device = "NULL"
   has_local = False
   float4 = "float4"
+  barrier = "// BARRIER"
   code_for_op = {**CStyleLanguage.code_for_op, Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
 
 class NullProgram:
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 75b95c8dfb..54a808381d 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -329,7 +329,7 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
 
 # BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
 # NOTE: this has been fixed up a bit
-def bufferize_to_store(x:UOp):
+def bufferize_to_store(x:UOp, locals_allowed=False):
   rngs = x.src[1:]
   shape = tuple([int(r.vmax+1) for r in rngs])
   size = prod(shape)
@@ -339,10 +339,18 @@ def bufferize_to_store(x:UOp):
     assign_target, assign_src = x.src[0].src
     assert assign_target.op is Ops.INDEX
     return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
-  if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg, size, x.dtype)
-  else: buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg[1])
+  # NOTE: the DEFINE_LOCAL needs to be disambiguated here
+  if sdtype.addrspace == AddrSpace.GLOBAL:
+    buf = UOp.new_buffer(x.arg, size, x.dtype)
+  else:
+    if not locals_allowed: return None
+    buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg[1])
   return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
 
+pm_add_buffers_local = pm_mops+PatternMatcher([
+  (UPat(Ops.BUFFERIZE, name="x"), lambda x: bufferize_to_store(x, True)),
+])
+
 pm_add_buffers = pm_mops+PatternMatcher([
   (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
 
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 70cb97f45a..1c5654620a 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -202,17 +202,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @functools.cached_property
   def ranges(self) -> dict[UOp, None]:
     if self.op is Ops.RANGE: return {self:None}
-    if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
-      ret = self.src[0].ranges.copy()
-      for s in self.src[1:]:
-        if s in ret: del ret[s]
-    elif self.op in {Ops.STORE}:
-      ret = self.src[0].ranges.copy()
-      ret.update(self.src[1].ranges)
-      for s in self.src[2:]:
+    range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
+    ret: dict[UOp, None] = {}
+    if self.op in range_start.keys():
+      for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
+      for s in self.src[range_start[self.op]:]:
         if s in ret: del ret[s]
     else:
-      ret = {}
       for s in self.src: ret.update(s.ranges)
     return ret
 