Fix up latest openpilot model (#1976)

* fix gemv triggering for gemm

* fixup_openpilot

* external test issues
This commit is contained in:
George Hotz
2023-10-05 05:24:28 -07:00
committed by GitHub
parent 1862e14a4f
commit 2d0c1037b1
10 changed files with 58 additions and 25 deletions

View File

@@ -151,6 +151,9 @@ jobs:
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot model correctness (float32)
run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot alt model correctness (float32)
run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'openpilot' }}
name: Test tensor core ops
run: GPU=1 TC=2 python -m pytest -n=auto test/test_ops.py

View File

@@ -145,10 +145,5 @@ def compile(dat, output_fn):
# UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 python3 openpilot/compile.py
# 22.59 ms
# Script entry point for openpilot/compile.py.
# NOTE(review): this span is a rendered diff hunk (-145,10 +145,5) whose +/- markers and
# indentation were stripped, so the pre-change body and the post-change body both appear below.
if __name__ == "__main__":
# Pre-change body (the next 7 lines): with at least two CLI arguments the model bytes are read
# from the local file sys.argv[1] and compiled to sys.argv[2]; otherwise OPENPILOT_MODEL is
# fetched and compiled to /tmp/output.thneed.
if len(sys.argv) >= 3:
with open(sys.argv[1], "rb") as f:
dat = f.read()
compile(dat, sys.argv[2])
else:
dat = fetch(OPENPILOT_MODEL)
compile(dat, "/tmp/output.thneed")
# Post-change body (the next 2 lines): the model location is OPENPILOT_MODEL when no argument is
# given, else sys.argv[1], and it always goes through fetch(); the output path is sys.argv[2]
# when present, else /tmp/output.thneed. A single argument is therefore enough to pick the model.
dat = fetch(OPENPILOT_MODEL if len(sys.argv) == 1 else sys.argv[1])
compile(dat, sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed")

View File

@@ -1,2 +1,2 @@
#!/bin/bash
# Runs openpilot/compile.py with the environment flags set inline on the command.
# NOTE(review): rendered diff hunk (-1,2 +1,2) without +/- markers: the first command line is the
# pre-change version, the second is the post-change version, which no longer sets VALIDHACKS=1
# or ENABLE_METHOD_CACHE=1.
FLOAT16=1 DEBUGCL=1 VALIDHACKS=1 IMAGE=2 GPU=1 ENABLE_METHOD_CACHE=1 python3 openpilot/compile.py
FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py

View File

@@ -90,7 +90,7 @@ def check_gc():
# Walk x recursively and turn every LoadOps.RAND LazyOp into LoadOps.EMPTY, mutating in place.
# Returns the same object x that was passed in.
def derandomize(x):
if isinstance(x, LazyOp):
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
# NOTE(review): the two `x.src =` lines are the pre-change and post-change versions of one
# statement (hunk -90,7 +90,7, markers stripped). The change stores the derandomized sources
# as a tuple instead of a list.
x.src = [derandomize(s) for s in x.src]
x.src = tuple([derandomize(s) for s in x.src])
else:
# not a LazyOp: descend into whatever x.op holds and store the result back
x.op = derandomize(x.op)
return x

View File

@@ -14,7 +14,7 @@ from examples.llama import Transformer
# Walk x recursively and turn every LoadOps.RAND LazyOp into LoadOps.EMPTY, mutating in place.
# Returns the same object x that was passed in.
def derandomize(x):
if isinstance(x, LazyOp):
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
# NOTE(review): the two `x.src =` lines are the pre-change and post-change versions of one
# statement (hunk -14,7 +14,7, markers stripped). The change stores the derandomized sources
# as a tuple instead of a list.
x.src = [derandomize(s) for s in x.src]
x.src = tuple([derandomize(s) for s in x.src])
else:
# not a LazyOp: descend into whatever x.op holds and store the result back
x.op = derandomize(x.op)
return x

View File

@@ -55,7 +55,7 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed):
# Walk x recursively and turn every LoadOps.RAND LazyOp into LoadOps.EMPTY, mutating in place.
# Returns the same object x that was passed in.
def derandomize(x):
if isinstance(x, LazyOp):
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
# NOTE(review): the two `x.src =` lines are the pre-change and post-change versions of one
# statement (hunk -55,7 +55,7, markers stripped). The change stores the derandomized sources
# as a tuple instead of a list.
x.src = [derandomize(s) for s in x.src]
x.src = tuple([derandomize(s) for s in x.src])
# unlike a bare `else`, objects without an `op` attribute are returned untouched
elif hasattr(x, "op"):
x.op = derandomize(x.op)
return x

View File

@@ -4,7 +4,7 @@ import itertools, math, functools
from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same, getenv
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
@@ -88,7 +88,7 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup
# This is the slow part
# This part is for brute forcing all possible values of idx, idy and valid
# If valid is both 0 and 1 for the same (idx, idy) we can not delete the valid
if valid.min == 0 and not isinstance(idx, ModNode):
if getenv("VALIDHACKS", 1) and valid.min == 0 and not isinstance(idx, ModNode):
variables = tuple(val_vars | idy_vars | idx_vars)
val_infer, idx_infer, idy_infer = valid.expand(variables), idx.expand(variables), idy.expand(variables)
val_dict: Dict[int, Set[Tuple[int,int]]] = {0:set(), 1:set()}

View File

@@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, partition, dedup, merge_dicts
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, partition, dedup, merge_dicts
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@@ -75,7 +75,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
# **** lazy operations ****
# get_single_root: follow op.src[0] for as long as the buffer has a truthy `op` attribute with
# exactly one source; return the first buffer where that stops holding.
# NOTE(review): the two consecutive definitions are the pre-change and post-change versions of the
# same function (diff markers stripped). The second additionally requires op.src[0] to be a
# LazyBuffer before recursing into it.
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root
# get_movementroot: while root is unrealized and is either a MovementOps buffer, or (only when
# allow_contiguous is set) a LoadOps.CONTIGUOUS whose source has a contiguous st, recurse into
# op.src[0]; otherwise return root.
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
# get_movementroot_contiguous: recurse through unrealized LoadOps.CONTIGUOUS buffers; once that
# stops, a MovementOps buffer with a contiguous st is resolved via get_movementroot(x, True),
# and anything else is returned as is.
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
@@ -152,7 +152,7 @@ class LazyBuffer:
# buffers: for a LazyBuffer this is the one-element tuple holding the buffer itself.
@property
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
# map_buffers: return the entry registered for this buffer in real_srcs, or the buffer itself
# when it has no entry.
# NOTE(review): the two definitions are the pre-change and post-change versions (diff markers
# stripped). Only the annotation differs: the mapping's key type is widened from LazyBuffer to Any.
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
# get_lazyops: a LazyBuffer contributes no LazyOps of its own, so the list is always empty.
def get_lazyops(self) -> List[LazyOp]: return []
# *** scheduling ***
@@ -168,12 +168,6 @@ class LazyBuffer:
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)
# HACK: image shape can be wrong, hot cast it back to a normal float
if isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if op.op == MovementOps.RESHAPE: op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, op.src, (dtypes.float32, False)),), op.arg)
else: op = LazyOp(UnaryOps.CAST, (op,), (dtypes.float32, False))
self.dtype = dtypes.float32
# realize the past and exec the AST
ret = []
for x in op.buffers: ret += x.schedule(seen)

View File

@@ -56,7 +56,7 @@ class LazyOp:
# key: a tuple of (op, per-source keys, arg key). A source or arg without a `key` attribute
# stands in for itself.
@property
def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
# map_buffers: build a new LazyOp with the same op and arg whose sources have been remapped.
# NOTE(review): the two definitions are the pre-change and post-change versions (diff markers
# stripped). The first always delegates to each source's own map_buffers. The second first checks
# whether the source itself is a key of real_srcs and, if so, substitutes the mapped value
# directly without recursing; the key annotation is widened to Any accordingly.
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg)
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) if y not in real_srcs else real_srcs[y] for y in self.src]), self.arg)
# get_lazyops: this op first, followed by the get_lazyops() results of each source in order.
def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()]
def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer':
@@ -241,6 +241,7 @@ class Compiled:
def get_program():
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, self.linearizer_opts, var_vals)
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
from tinygrad.codegen.search import kernel_optimize
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, self.linearizer_opts, var_vals), self.to_program, rawbuffers, key)
elif not getenv("NOOPT"): k.hand_coded_optimizations()

View File

@@ -1,16 +1,56 @@
from typing import List, Tuple, cast, Dict, Callable
import numpy as np
from tinygrad.ops import LazyOp, LoadOps, BufferOps, Device
from tinygrad.ops import LazyOp, LoadOps, Device, UnaryOps, BufferOps, MemBuffer, get_lazyop_info
from tinygrad.graph import log_schedule_item
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE, ImageDType, dtypes
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
P2P = getenv("P2P", 0)
def fix_schedule_for_images(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
  """Return a copy of `schedule` in which images that cannot be used as images are float32.

  `schedule` is a list of (ast, output buffer, input buffers) items. Three passes run over it:
    1. outputs that are unwritable as images and inputs that are unreadable as images get their
       dtype switched from an ImageDType to float32 (the buffer objects are mutated in place)
    2. a LoadOps.CONTIGUOUS whose output and first input disagree on dtype gets both set to float32
    3. a new schedule is built where every MEM op carries the current dtype of the buffer it reads,
       and every non-LoadOps ast whose inferred dtype differs from its output is wrapped in a CAST

  NOTE(review): the statement nesting was reconstructed from syntax because the text this was taken
  from carried no indentation; the MEM scan in pass 1 is assumed to run for every schedule item,
  not only for items whose output was demoted -- confirm against the real file.
  """
  def mem_loads(ast): return [node for node in ast.get_lazyops() if node.op == BufferOps.MEM]
  def no_axis_divisible_by_4(shape, st): return not any(shape[axis]%4 == 0 for axis in st.unit_stride_axes())

  # this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
  for ast, output, inputs in schedule:
    if isinstance(output.dtype, ImageDType):
      if prod(output.shape) != prod(output.dtype.shape) or no_axis_divisible_by_4(output.shape, output.st):
        output.dtype = dtypes.float32
    for load in mem_loads(ast):
      # MemBuffer idx is offset by one relative to the position in the inputs tuple
      source = inputs[load.arg.idx-1]
      if not isinstance(source.dtype, ImageDType): continue
      view = load.arg.st
      if view.real_offset() % 4 != 0 or no_axis_divisible_by_4(view.shape, view):
        source.dtype = dtypes.float32

  # fix the contiguous dtype, no cast required
  for ast, output, inputs in schedule:
    if ast.op == LoadOps.CONTIGUOUS:
      if output.dtype != inputs[0].dtype:
        output.dtype = dtypes.float32
        inputs[0].dtype = dtypes.float32

  # now fix up the schedule to reflect the new dtypes
  fixed_schedule = []
  for ast, output, inputs in schedule:
    # fix input dtypes to match what they actually are
    replacements = {load: LazyOp(BufferOps.MEM, (), MemBuffer(load.arg.idx, inputs[load.arg.idx-1].dtype, load.arg.st))
                    for load in mem_loads(ast) if load.arg.dtype != inputs[load.arg.idx-1].dtype}
    if replacements: ast = ast.map_buffers(replacements)

    # fix the ops to create the output dtype
    if ast.op not in LoadOps and get_lazyop_info(ast).dtype != output.dtype:
      ast = LazyOp(UnaryOps.CAST, (ast,), (output.dtype, False))

    # put this in the fixed schedule
    fixed_schedule.append((ast, output, inputs))
  return fixed_schedule
def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
# HACK: images can be not usable due to shape
if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
# NOTE: if you for loop the schedule it's slow because nothing frees
while len(schedule):
op,out,buffers = schedule.pop(0)