vars_from_ast -> LazyOp.vars (#2965)

This commit is contained in:
chenyu
2024-01-01 18:12:38 -05:00
committed by GitHub
parent 980f421442
commit 58d3d5030b
8 changed files with 17 additions and 20 deletions

View File

@@ -1,7 +1,7 @@
from typing import List
from extra.models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, vars_from_ast
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
@@ -57,7 +57,7 @@ if __name__ == "__main__":
choices = []
for lin in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
gflops = sym_infer(lin.info.flops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
choices.append((tm, gflops, lin.linearize()))
# print all kernels

View File

@@ -8,7 +8,6 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.graph import print_tree
from tinygrad.helpers import getenv
from tinygrad.device import Device, Compiled, Interpreted
from tinygrad.ops import vars_from_ast
from tinygrad.codegen.linearizer import UOp
def tuplize_uops(uops:List[UOp]) -> Tuple:
  """Flatten a list of uops into a nested tuple.

  Each uop becomes `(uop, dtype, inputs, arg)`, where `inputs` holds, for every
  entry of the uop's `vin`, the position of that entry inside `uops` (as found
  by `list.index`, so the first equal element wins). `list.index` raises
  ValueError if an input is not present in `uops`.
  """
  rows = []
  for u in uops:
    # refer to inputs by position instead of by object
    input_positions = tuple(uops.index(parent) for parent in u.vin)
    rows.append((u.uop, u.dtype, input_positions, u.arg))
  return tuple(rows)
@@ -17,7 +16,7 @@ device = Device[Device.DEFAULT]
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()}
# TODO: images needs required_optimization
try:
@@ -66,7 +65,7 @@ def fuzz_linearizer(lin: Linearizer):
print(lin.colored_shape())
# get a new output buffer
rawbufs[0] = type(rawbufs[0])(Device.DEFAULT, rawbufs[0].size, rawbufs[0].dtype)
var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)}
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}
if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
print(f"{lin.applied_opts=}")
return msg

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, vars_from_ast
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
from tinygrad.device import Device, Compiled
from tinygrad.dtype import dtypes, ImageDType, DType
from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up
@@ -452,7 +452,7 @@ class Kernel:
assert not self.dont_use_locals, "already not using locals"
self.dont_use_locals = True
elif opt.op == OptOps.PADTO:
assert not vars_from_ast(self.ast), "does not work with symbolic shape"
assert not self.ast.vars(), "does not work with symbolic shape"
assert axis < self.first_reduce, "cannot pad a reduce axis"
padded = False
for i,st in enumerate(self.sts):

View File

@@ -7,7 +7,7 @@ from dataclasses import dataclass
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name, flatten
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast, get_lazyop_info
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
from tinygrad.codegen.kernel import LocalBuffer, Kernel
@@ -201,7 +201,7 @@ class Linearizer(Kernel):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}")
# add var vals
for var in vars_from_ast(self.ast):
for var in self.ast.vars():
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
# define local buffers

View File

@@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.dtype import DType, dtypes, ImageDType
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast, GlobalCounters
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
if TYPE_CHECKING:
from tinygrad.codegen.linearizer import Linearizer
@@ -248,7 +248,7 @@ class CompiledASTRunner(JITRunner):
if ast:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
self.vars = vars_from_ast(ast)
self.vars = ast.vars()
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
def build(self, compiler, runtime):

View File

@@ -1,7 +1,7 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time, multiprocessing, traceback, signal
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.ops import MemBuffer, vars_from_ast
from tinygrad.ops import MemBuffer
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.codegen.linearizer import Linearizer
@@ -116,7 +116,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None
try:
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
exiting, st = False, time.perf_counter()
dev = Device[Device.DEFAULT]
assert isinstance(dev, Compiled)
@@ -165,7 +165,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501

View File

@@ -4,8 +4,7 @@ import numpy as np
from typing import Union, Optional, Any, Tuple, List, Set, Dict
from tinygrad.dtype import dtypes, DType, ImageDType
from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps
from tinygrad.ops import Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem, vars_from_ast
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -211,7 +210,7 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify()))
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + \
[ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in vars_from_ast(op)})]
[ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer]):

View File

@@ -65,9 +65,8 @@ class LazyOp:
# Hashing returns the value already stored in the `hash` attribute; nothing is computed here.
def __hash__(self): return self.hash
@functools.cached_property
def lazyops(self) -> List[LazyOp]:
  """This node followed by the `lazyops` of each entry in `src`, passed through `dedup`.

  Cached per instance via `functools.cached_property`.
  """
  collected = [self]
  for child in self.src:
    collected.extend(child.lazyops)
  return dedup(collected)
def vars_from_ast(ast:LazyOp) -> List[Variable]:
  """Collect the variables used by the buffer ops of `ast`.

  Takes `arg.st.vars()` from every entry of `ast.lazyops` whose op is in
  `BufferOps`, unions the results, and returns them sorted by `str(expr)`.
  """
  per_buffer = [node.arg.st.vars() for node in ast.lazyops if node.op in BufferOps]
  # the trailing empty set keeps set.union callable when no buffer op was found
  merged = set.union(*per_buffer, set())
  return sorted(merged, key=lambda var: str(var.expr))
def vars(self) -> List[Variable]:
  """Collect the variables used by the buffer ops reachable from this node.

  Takes `arg.st.vars()` from every entry of `self.lazyops` whose op is in
  `BufferOps`, unions the results, and returns them sorted by `str(expr)`.
  """
  var_sets = [node.arg.st.vars() for node in self.lazyops if node.op in BufferOps]
  # the trailing empty set keeps set.union callable when no buffer op was found
  unique_vars = set.union(*var_sets, set())
  return sorted(unique_vars, key=lambda v: str(v.expr))
# **************** independent FlopCounter ****************