mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
add symbolic_simple to the scheduler [pr] (#8419)
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -295,7 +295,7 @@ jobs:
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot alt model correctness (float32)
|
||||
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
|
||||
@@ -42,10 +42,18 @@ class TestImageDType(unittest.TestCase):
|
||||
def test_image_and_back(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
tst = data.numpy()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
|
||||
assert isinstance(it.lazydata.base.realized.dtype, ImageDType)
|
||||
np.testing.assert_equal(tst, it.numpy())
|
||||
|
||||
def test_image_cast_and_back_collapses(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
tst = data.numpy()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).realize()
|
||||
# the underlying UOp is identical
|
||||
self.assertIs(it.lazydata.base.realized, data.lazydata.base.realized)
|
||||
np.testing.assert_equal(tst, it.numpy())
|
||||
|
||||
def test_image_and_back_wrong_shape(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
tst = data.numpy()
|
||||
@@ -59,7 +67,8 @@ class TestImageDType(unittest.TestCase):
|
||||
np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
|
||||
|
||||
def test_mul_stays_image(self):
|
||||
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
|
||||
# NOTE: contiguous is needed otherwise this folds
|
||||
it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
|
||||
out = (it*2).realize()
|
||||
assert isinstance(out.lazydata.base.realized.dtype, ImageDType)
|
||||
|
||||
@@ -88,15 +97,15 @@ class TestImageDType(unittest.TestCase):
|
||||
|
||||
def test_no_lru_alloc(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
|
||||
b1 = it.lazydata.base.realized._buf
|
||||
del it
|
||||
it = data.cast(dtypes.imagef((10,27,4))).realize()
|
||||
it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
|
||||
assert it.lazydata.base.realized._buf != b1
|
||||
|
||||
def test_no_lru_alloc_dtype(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).realize()
|
||||
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
|
||||
b1 = it.lazydata.base.realized._buf
|
||||
del it
|
||||
it = data.cast(dtypes.imageh((9,27,4))).realize()
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import List, Optional, Union, cast
|
||||
|
||||
from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic
|
||||
@@ -1405,8 +1405,11 @@ class TestSchedule(unittest.TestCase):
|
||||
x = Tensor.randn((9, 9)).realize()
|
||||
y = Tensor.randn((9, 9)).realize()
|
||||
out = x@y
|
||||
run_schedule(check_schedule(out, 4))
|
||||
run_schedule(check_schedule(out, 3))
|
||||
np.testing.assert_allclose(out.numpy(), x.numpy()@y.numpy(), atol=1e-4, rtol=1e-4)
|
||||
self.assertIsInstance(out.dtype, ImageDType)
|
||||
self.assertIsNotNone(out.lazydata.base.realized)
|
||||
self.assertIsInstance(out.lazydata.base.realized.dtype, ImageDType)
|
||||
|
||||
def _test_fusion(self, shapes, f, cnt):
|
||||
with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys, atexit, functools, pickle
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
|
||||
from tinygrad.ops import identity_element, buffers, exec_alu, type_verify
|
||||
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify
|
||||
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, ContextVar
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes
|
||||
@@ -34,7 +34,7 @@ tensor_uop_spec = PatternMatcher([
|
||||
(isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
|
||||
# TODO: "make things that can't be images not images" can override the source dtype
|
||||
# is there a clean way to update its _mop children?
|
||||
(isinstance(mv.dtype, ImageDType) and x.dtype == mv.dtype.base and x.is_realized)),
|
||||
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
|
||||
@@ -414,21 +414,6 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
||||
case _: return None
|
||||
return UOp.const(reduce.dtype, ret)
|
||||
|
||||
def simplify_alu(alu:UOp):
|
||||
if not all(x.is_unrealized_unmasked_const() for x in alu.src): return None
|
||||
# this needs to have a VIEW next (it has to, right?)
|
||||
return UOp.const(alu.dtype, exec_alu(alu.op, alu.dtype, [s.const_arg for s in alu.src]))
|
||||
|
||||
def simplify_binop(binop:UOp, x:UOp, y:UOp):
|
||||
if all_int(x.shape) and x.is_unrealized_unmasked_const(): other, const = y, x
|
||||
elif all_int(y.shape) and y.is_unrealized_unmasked_const():
|
||||
if binop.op is Ops.IDIV and y.const_arg == 1: return x
|
||||
other, const = x, y
|
||||
else: return None
|
||||
if binop.op is Ops.ADD and const.const_arg == 0: return other
|
||||
if binop.op is Ops.MUL and const.const_arg == 1: return other
|
||||
if binop.op is Ops.MUL and const.const_arg == 0: return UOp.const(binop.dtype, 0)
|
||||
|
||||
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
|
||||
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
|
||||
old_base = contig.src[0].src[0]
|
||||
@@ -439,18 +424,13 @@ def replace_contiguous(ctx:ScheduleContext, alu:UOp):
|
||||
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
|
||||
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
|
||||
|
||||
ops_folding = PatternMatcher([
|
||||
ops_folding = symbolic_simple+PatternMatcher([
|
||||
# op with size 0 is zero
|
||||
(UPatScheduled(), lambda b,to_store,base: base.const_like(0) if base.size == 0 else None),
|
||||
# if the uop folded to a CONST we can delete the BUFFER
|
||||
(UPatScheduled(Ops.CONST, name="const"), lambda b,base,const: base.const_like(const.const_arg)),
|
||||
# DETACH is a NOOP here
|
||||
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
|
||||
# elementwise const folding
|
||||
(UPat(GroupOp.ALU, name="alu"), simplify_alu),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, name="binop", src=(UPat.var("x"), UPat.var("y"))), simplify_binop),
|
||||
(UPat(Ops.CAST, src=(UPat.var("x"),), name="cast"),
|
||||
lambda x,cast: UOp.const(cast.dtype, x.const_arg) if all_int(x.shape) and x.is_unrealized_unmasked_const() else None),
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
|
||||
@@ -956,11 +956,11 @@ spec = PatternMatcher([
|
||||
|
||||
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
||||
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype == y.dtype),
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE), dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y"))), lambda x,y: x.dtype.base == y.dtype.base),
|
||||
# and SHL/SHR, the shift distance can be an int
|
||||
(UPat((Ops.SHL, Ops.SHR), src=(UPat(name="x"), UPat(name="y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
|
||||
(UPat(Ops.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
|
||||
|
||||
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
|
||||
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user