From cc21351428a7271ee2f663372b0874b0ddfda950 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 30 Apr 2026 12:42:33 -0700 Subject: [PATCH] move to SPEC=3 --- tinygrad/schedule/indexing.py | 6 +++--- tinygrad/uop/ops.py | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 4631967497..5ee7789a49 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses -from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context +from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM, @@ -265,8 +265,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op. rctx.range_map[x] = (rngs, out_rngs) - # NOTE: SPEC=2 is broken here with shape - with Context(SPEC=1): + # NOTE: SPEC=3 is broken here with shape + with Context(SPEC=min(SPEC.value, 2)): tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify") return tsink, rctx diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index c4cf7d6f2e..db84130887 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -100,8 +100,11 @@ class UOpMetaClass(type): buffers[created] = _buffer if SPEC > 1: from tinygrad.uop.spec import full_spec, test_pyrender - if SPEC > 2: test_pyrender(created) - _ = created._shape + if SPEC > 2: + # SPEC=3 checks the shape + _ = created._shape + if SPEC > 3: + test_pyrender(created) with Context(CHECK_OOB=0): fret = cast(bool|None, full_spec.rewrite(created)) if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}") return created @@ -213,7 +216,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): match self.op: # late ops don't have shape case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \ - Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | \ + Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \ Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION: return None @@ -249,7 +252,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # these can have shape if it has a vec dtype if self.dtype.count > 1: return (self.dtype.count,) return () - case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return () + case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return () case Ops.VCONST: return (len(self.arg),) case Ops.BUFFER: return (self.arg,) case Ops.BUFFER_VIEW: return (self.arg[0],)