move to SPEC=3

This commit is contained in:
George Hotz
2026-04-30 12:42:33 -07:00
parent 7e329c5219
commit cc21351428
2 changed files with 10 additions and 7 deletions

View File

@@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
from tinygrad.uop.ops import consumer_map_from_toposort, gate_kernel_sink
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, Context, SPEC
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.AFTER, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
@@ -265,8 +265,8 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# assign to the range map. rngs are the input ranges, out_rngs are the output ranges, from the x op.
rctx.range_map[x] = (rngs, out_rngs)
# NOTE: SPEC=2 is broken here with shape
with Context(SPEC=1):
# NOTE: SPEC=3 is broken here with shape
with Context(SPEC=min(SPEC.value, 2)):
tsink = graph_rewrite(tsink, pm_apply_rangeify, ctx=rctx, bottom_up=True, name="apply rangeify")
return tsink, rctx

View File

@@ -100,8 +100,11 @@ class UOpMetaClass(type):
buffers[created] = _buffer
if SPEC > 1:
from tinygrad.uop.spec import full_spec, test_pyrender
if SPEC > 2: test_pyrender(created)
_ = created._shape
if SPEC > 2:
# SPEC=3 checks the shape
_ = created._shape
if SPEC > 3:
test_pyrender(created)
with Context(CHECK_OOB=0): fret = cast(bool|None, full_spec.rewrite(created))
if fret is not True: raise RuntimeError(f"SPEC ISSUE {fret}: {created}")
return created
@@ -213,7 +216,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | \
Ops.CONTRACT | Ops.SINK | Ops.END | Ops.REWRITE_ERROR | Ops.PTRCAT | Ops.ENDIF | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
@@ -249,7 +252,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# these can have shape if it has a vec dtype
if self.dtype.count > 1: return (self.dtype.count,)
return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
case Ops.VCONST: return (len(self.arg),)
case Ops.BUFFER: return (self.arg,)
case Ops.BUFFER_VIEW: return (self.arg[0],)