reshape uses valid when simplifying (#12597)

* reshape uses valid when simplifying

* try with IGNORE_OOB=0

* is it this test?

* skipif gpuocelot
This commit is contained in:
Sieds Lykles
2025-10-11 17:02:54 +02:00
committed by GitHub
parent 08e62454b6
commit 772a8dfe31
5 changed files with 18 additions and 13 deletions

View File

@@ -377,7 +377,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=41 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2092 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot alt model correctness (float32)
run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- name: Test openpilot fastvits model correctness (float32)

View File

@@ -8,9 +8,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import CUDARenderer
MOCKGPU = getenv("MOCKGPU")
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@@ -314,7 +316,7 @@ class TestLinearizer(unittest.TestCase):
a.realize()
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
@unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. might be ok?")
def test_where_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink(((1, 2), None)).pad(((1, 2), None))

View File

@@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x,w, padding=1)
out.mean().backward()
backward_schedule = Tensor.schedule(x.grad, w.grad)
self.assertEqual(len(backward_schedule), 4)
self.assertEqual(len(backward_schedule), 5)
def test_counters(self):
IC, OC, X, Y = 4,4,9,9

View File

@@ -3,7 +3,7 @@ import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
from tinygrad.uop.symbolic import sym, symbolic
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid
from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
@@ -112,8 +112,8 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
case Ops.PAD:
# TODO: why is multiple graph_rewrites faster than one here?
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad")
for r,sh,(s,e) in zip(rngs, in_shape, arg))
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()),
symbolic+pm_simplify_valid, name="pad") for r,sh,(s,e) in zip(rngs, in_shape, arg))
case Ops.RESHAPE:
acc = 1
axes_in:list[UOp] = []
@@ -126,7 +126,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
axes_out.append(combined_axes % s)
combined_axes //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape").src
case _: raise RuntimeError(f"{op} is not a MovementOp")
return rngs

View File

@@ -437,7 +437,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
# try all the valids together (but only the whole expressions)
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
uop = s_uop.simplify(tracked=True).substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
# put the loads back in
uop = uop.substitute({v:k for k,v in load_subs.items()})
return uop
@@ -470,13 +470,16 @@ def reduce_mul_chain(r:UOp):
if len(outside) == 0: return None
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)
# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+PatternMatcher([
pm_simplify_valid = PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
(UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
])
# this is symbolic 2.0
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),