mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
reshape uses valid when simplifying (#12597)
* reshape uses valid when simplifying
* try with IGNORE_OOB=0
* is it this test?
* skipif gpuocelot
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -377,7 +377,7 @@ jobs:
|
||||
llvm: 'true'
|
||||
- name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=41 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2092 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot alt model correctness (float32)
|
||||
run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Test openpilot fastvits model correctness (float32)
|
||||
|
||||
@@ -8,9 +8,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv
|
||||
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
MOCKGPU = getenv("MOCKGPU")
|
||||
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
@@ -314,7 +316,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a.realize()
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?")
|
||||
@unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. might be ok?")
|
||||
def test_where_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase):
|
||||
out = Tensor.conv2d(x,w, padding=1)
|
||||
out.mean().backward()
|
||||
backward_schedule = Tensor.schedule(x.grad, w.grad)
|
||||
self.assertEqual(len(backward_schedule), 4)
|
||||
self.assertEqual(len(backward_schedule), 5)
|
||||
|
||||
def test_counters(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, operator, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
|
||||
from tinygrad.uop.symbolic import sym, symbolic
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
@@ -112,8 +112,8 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
|
||||
case Ops.PAD:
|
||||
# TODO: why is multiple graph_rewrites faster than one here?
|
||||
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad")
|
||||
for r,sh,(s,e) in zip(rngs, in_shape, arg))
|
||||
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()),
|
||||
symbolic+pm_simplify_valid, name="pad") for r,sh,(s,e) in zip(rngs, in_shape, arg))
|
||||
case Ops.RESHAPE:
|
||||
acc = 1
|
||||
axes_in:list[UOp] = []
|
||||
@@ -126,7 +126,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
axes_out.append(combined_axes % s)
|
||||
combined_axes //= s
|
||||
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
|
||||
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src
|
||||
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape").src
|
||||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
return rngs
|
||||
|
||||
|
||||
@@ -437,7 +437,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
||||
|
||||
# try all the valids together (but only the whole expressions)
|
||||
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
|
||||
uop = s_uop.simplify(tracked=True).substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
|
||||
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
|
||||
# put the loads back in
|
||||
uop = uop.substitute({v:k for k,v in load_subs.items()})
|
||||
return uop
|
||||
@@ -470,13 +470,16 @@ def reduce_mul_chain(r:UOp):
|
||||
if len(outside) == 0: return None
|
||||
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)
|
||||
|
||||
# this is symbolic 2.0
|
||||
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
|
||||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
pm_simplify_valid = PatternMatcher([
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
(UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)),
|
||||
])
|
||||
|
||||
# this is symbolic 2.0
|
||||
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
|
||||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
|
||||
sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
|
||||
Reference in New Issue
Block a user