add symbolic_simple to the scheduler [pr] (#8419)

This commit is contained in:
qazal
2024-12-26 14:05:08 +02:00
committed by GitHub
parent 6bb54eb532
commit 9defbc7d54
5 changed files with 25 additions and 33 deletions

View File

@@ -295,7 +295,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx

View File

@@ -42,10 +42,18 @@ class TestImageDType(unittest.TestCase):
def test_image_and_back(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()
it = data.cast(dtypes.imagef((9,27,4))).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
assert isinstance(it.lazydata.base.realized.dtype, ImageDType)
np.testing.assert_equal(tst, it.numpy())
def test_image_cast_and_back_collapses(self):
  """An image cast that is realized without `.contiguous()` reuses the source's realized object."""
  src = Tensor.randn(9*27*4).realize()
  expected = src.numpy()
  img = src.cast(dtypes.imagef((9,27,4))).realize()
  # nothing new was allocated: both tensors are backed by the very same realized object
  self.assertIs(img.lazydata.base.realized, src.lazydata.base.realized)
  # and the values read back unchanged
  np.testing.assert_equal(expected, img.numpy())
def test_image_and_back_wrong_shape(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()
@@ -59,7 +67,8 @@ class TestImageDType(unittest.TestCase):
np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
def test_mul_stays_image(self):
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
# NOTE: contiguous is needed otherwise this folds
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
out = (it*2).realize()
assert isinstance(out.lazydata.base.realized.dtype, ImageDType)
@@ -88,15 +97,15 @@ class TestImageDType(unittest.TestCase):
def test_no_lru_alloc(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
b1 = it.lazydata.base.realized._buf
del it
it = data.cast(dtypes.imagef((10,27,4))).realize()
it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
assert it.lazydata.base.realized._buf != b1
def test_no_lru_alloc_dtype(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
b1 = it.lazydata.base.realized._buf
del it
it = data.cast(dtypes.imageh((9,27,4))).realize()

View File

@@ -10,7 +10,7 @@ from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic
@@ -1405,8 +1405,11 @@ class TestSchedule(unittest.TestCase):
x = Tensor.randn((9, 9)).realize()
y = Tensor.randn((9, 9)).realize()
out = x@y
run_schedule(check_schedule(out, 4))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
self.assertIsInstance(out.dtype, ImageDType)
self.assertIsNotNone(out.lazydata.base.realized)
self.assertIsInstance(out.lazydata.base.realized.dtype, ImageDType)
def _test_fusion(self, shapes, f, cnt):
with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]

View File

@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
from tinygrad.dtype import DType, ImageDType, dtypes
@@ -34,7 +34,7 @@ tensor_uop_spec = PatternMatcher([
(isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
# TODO: "make things that can't be images not images" can override the source dtype
# is there a clean way to update its _mop children?
(isinstance(mv.dtype, ImageDType) and x.dtype == mv.dtype.base and x.is_realized)),
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
# Tensor variable bindings
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
@@ -414,21 +414,6 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
case _: return None
return UOp.const(reduce.dtype, ret)
def simplify_alu(alu:UOp):
  """Constant-fold an ALU op when every source is an unrealized, unmasked const.

  Returns `UOp.const(alu.dtype, ...)` holding the value `exec_alu` computes from
  the sources' `const_arg`s, or None as soon as one source is not such a const.
  """
  for operand in alu.src:
    if not operand.is_unrealized_unmasked_const(): return None
  # this needs to have a VIEW next (it has to, right?)
  folded = exec_alu(alu.op, alu.dtype, [operand.const_arg for operand in alu.src])
  return UOp.const(alu.dtype, folded)
def simplify_binop(binop:UOp, x:UOp, y:UOp):
  """Fold a binary op against an unrealized, unmasked const operand with an all-int shape.

  Handles `v + 0 -> v`, `v * 1 -> v`, `v * 0 -> const 0` with the const on either
  side, and `x // 1 -> x` only when the const is `y`. Returns None when neither
  operand qualifies or when no rule applies.
  """
  if all_int(x.shape) and x.is_unrealized_unmasked_const():
    const, other = x, y
  elif all_int(y.shape) and y.is_unrealized_unmasked_const():
    # division by one is only checked when the const is the right-hand operand
    if binop.op is Ops.IDIV and y.const_arg == 1: return x
    const, other = y, x
  else:
    return None
  if binop.op is Ops.ADD:
    # additive identity
    return other if const.const_arg == 0 else None
  if binop.op is Ops.MUL:
    # multiplicative identity, then the absorbing zero
    if const.const_arg == 1: return other
    if const.const_arg == 0: return UOp.const(binop.dtype, 0)
  return None
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
old_base = contig.src[0].src[0]
@@ -439,18 +424,13 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
ops_folding = PatternMatcher([
ops_folding = symbolic_simple+PatternMatcher([
# op with size 0 is zero
(UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None),
# if the uop folded to a CONST we can delete the BUFFER
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
# DETACH is a NOOP here
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
# elementwise const folding
(UPat(GroupOp.ALU, name="alu"), simplify_alu),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, name="binop", src=(UPat.var("x"), UPat.var("y"))), simplify_binop),
(UPat(Ops.CAST, src=(UPat.var("x"),), name="cast"),
lambda x,cast: UOp.const(cast.dtype, x.const_arg) if all_int(x.shape) and x.is_unrealized_unmasked_const() else None),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),

View File

@@ -956,11 +956,11 @@ spec = PatternMatcher([
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype.base == y.dtype.base),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),