fixup ast in kernel to be MetaOps.SINK [run_process_replay] (#5424)

* fixup ast in kernel to be MetaOps.SINK [run_process_replay]

* fix tests

* fix more tests
This commit is contained in:
George Hotz
2024-07-12 14:01:03 -07:00
committed by GitHub
parent b055ece550
commit 94599c0637
8 changed files with 35 additions and 28 deletions

View File

@@ -6,7 +6,7 @@ from tinygrad.helpers import tqdm
tactions = set()
def test_rebuild(lin):
linr = Linearizer(*lin.ast)
linr = Linearizer(lin.ast)
for o in lin.applied_opts:
assert o in actions, f"{o} is not in actions"
tactions.add(o)

View File

@@ -87,10 +87,10 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
if var_vals is None:
# TODO: handle symbolic max case
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}
if ground_truth is None and not has_bf16:
unoptimized = Linearizer(*lin.ast)
unoptimized = Linearizer(lin.ast)
unoptimized.required_optimizations()
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
@@ -121,7 +121,7 @@ def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2):
SEED = getenv("SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
for op in lin.ast: print_tree(op)
print_tree(lin.ast)
print(lin.colored_shape())
seen_uops = {}
last_lins = [lin]
@@ -178,8 +178,8 @@ def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2):
return failures
def _is_simple(lin: Linearizer) -> bool:
if len(lin.ast) > 1: return False
ast:LazyOp = lin.ast[0]
if len(lin.ast.src) > 1: return False
ast:LazyOp = lin.ast.src[0]
if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True
return False

View File

@@ -17,7 +17,7 @@ for offset in tqdm(range(0, row_count, page_size)):
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache}):
# try linearize
try:
k = Linearizer(*ast, opts=opts)
k = Linearizer(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
good_src = k.opts.render(name, k.linearize().uops)
except Exception as e:

View File

@@ -4,18 +4,19 @@ from tinygrad.codegen.linearizer import Linearizer
#from tinygrad.codegen.lowerer import Lowerer
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import DEBUG
from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, verify_lazyop
from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, MetaOps, verify_lazyop
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad import dtypes
from tinygrad.shape.view import View
class InvalidLazyOpException(Exception): pass
def lower(*ast:LazyOp):
sink_ast = LazyOp(MetaOps.SINK, ast)
if DEBUG >= 3:
for op in ast: print_tree(op)
try: verify_lazyop(*ast)
try: verify_lazyop(sink_ast)
except AssertionError: raise InvalidLazyOpException()
k = Linearizer(*ast)
k = Linearizer(sink_ast)
k.linearize()
if DEBUG >= 6: k.uops.print()
if DEBUG >= 4: print(k.to_program().src)

View File

@@ -8,7 +8,7 @@ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, Cons
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, flatten, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, get_contraction, to_function_name # noqa: E501
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, round_up, all_int, get_contraction, to_function_name # noqa: E501
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
@@ -61,22 +61,28 @@ class TensorCoreOptions:
class Kernel:
def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
if len(ast) > 1 or ast[0].op is BufferOps.STORE:
assert all(x.op is BufferOps.STORE for x in ast)
self.ast = LazyOp(MetaOps.SINK, ast)
else:
assert len(ast) == 1 and ast[0].op is MetaOps.SINK
self.ast = ast[0]
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
try: lazyop_sts_map = verify_lazyop(*ast)
try: lazyop_sts_map = verify_lazyop(self.ast)
except AssertionError as e:
print("INVALID AST")
for op in ast: print_tree(op)
raise e
self.ast = ast
self.lazyops = flatten([op.lazyops for op in self.ast])
self.lazyops = self.ast.lazyops
cached_ordered_lazyops: Dict[LazyOp, List[LazyOp]] = {}
def ordered_lazyops(op):
if op not in cached_ordered_lazyops: cached_ordered_lazyops[op] = dedup([item for x in op.src for item in ordered_lazyops(x)] + [op])
return cached_ordered_lazyops[op]
self.reduceops = dedup([x for out in self.ast for x in ordered_lazyops(out) if x.op in ReduceOps])
self.reduceops = dedup([x for x in ordered_lazyops(self.ast) if x.op in ReduceOps])
self.vars = flatten([x.vars() for x in self.ast])
self.vars = self.ast.vars()
self.bufs: List[Union[MemBuffer, ConstBuffer]] = dedup([x.arg for x in self.lazyops if x.op in BufferOps])
# get earlybufs, before any reduceops
@@ -645,7 +651,7 @@ class Kernel:
def name(self) -> str:
# kernel name (before late upcast)
name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \
(f"{len(self.ast)}_" if len(self.ast) > 1 else "_") + \
(f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
# name the function something unique
@@ -720,5 +726,4 @@ class Kernel:
else:
arg = op.arg
return LazyOp(op.op, tuple(fixup_ast(x) for x in op.src), arg)
return fixup_ast(LazyOp(MetaOps.SINK, src=self.ast)), \
KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted)
return fixup_ast(self.ast), KernelInfo(self.full_shape, self.global_dims, self.first_reduce, self.group_for_reduces, self.upcasted)

View File

@@ -88,7 +88,7 @@ class Lowerer(Kernel):
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype), (), ("temp", x.arg.st.size))
else:
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (),
(x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast)))
(x.arg.idx, any(x.arg.idx == y.arg.idx for y in self.ast.src)))
if x.op is BufferOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
@@ -168,7 +168,7 @@ class Lowerer(Kernel):
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_SHA', 'HEAD')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
info = get_lazyop_info(self.ast[0])
info = get_lazyop_info(self.ast.src[0]) # TODO: this should be removed
ops, mem = flops_mem(self.uops.uops)
run_count = prod((self.global_size or []) + (self.local_size or []))
return Program(self.name, src, self.opts.device, self.global_size, self.local_size,

View File

@@ -117,7 +117,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
global beam_pool
key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
@@ -136,7 +136,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
try:
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:
@@ -182,7 +182,7 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
return ret[1]
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
@@ -190,7 +190,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
assert dev.compiler is not None
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
p = lin.to_program()
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

View File

@@ -155,7 +155,8 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
# the living definition of LazyOps
def verify_lazyop(*ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
assert ast.op is MetaOps.SINK, "must be SINK"
sts: Dict[LazyOp, ShapeTracker] = {}
def dfs(op:LazyOp, st:ShapeTracker):
if op in sts: return
@@ -170,9 +171,9 @@ def verify_lazyop(*ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
else: st = sts[op.src[0]]
for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
sts[op] = st
for i, out in enumerate(ast):
for i, out in enumerate(ast.src):
assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
dfs(out, out.arg.st)
return sts