mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
vars_from_ast -> LazyOp.vars (#2965)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import List
|
||||
from extra.models.resnet import ResNet50
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps, vars_from_ast
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
|
||||
@@ -57,7 +57,7 @@ if __name__ == "__main__":
|
||||
choices = []
|
||||
for lin in lins:
|
||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
||||
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
|
||||
gflops = sym_infer(lin.info.flops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
|
||||
choices.append((tm, gflops, lin.linearize()))
|
||||
|
||||
# print all kernels
|
||||
|
||||
5
test/external/fuzz_linearizer.py
vendored
5
test/external/fuzz_linearizer.py
vendored
@@ -8,7 +8,6 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
||||
from tinygrad.graph import print_tree
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.device import Device, Compiled, Interpreted
|
||||
from tinygrad.ops import vars_from_ast
|
||||
from tinygrad.codegen.linearizer import UOp
|
||||
|
||||
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
|
||||
@@ -17,7 +16,7 @@ device = Device[Device.DEFAULT]
|
||||
|
||||
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
|
||||
if rawbufs is None: rawbufs = bufs_from_lin(lin)
|
||||
if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
|
||||
if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()}
|
||||
|
||||
# TODO: images needs required_optimization
|
||||
try:
|
||||
@@ -66,7 +65,7 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
print(lin.colored_shape())
|
||||
# get a new output buffer
|
||||
rawbufs[0] = type(rawbufs[0])(Device.DEFAULT, rawbufs[0].size, rawbufs[0].dtype)
|
||||
var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)}
|
||||
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}
|
||||
if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
|
||||
print(f"{lin.applied_opts=}")
|
||||
return msg
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, math, itertools
|
||||
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
|
||||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, vars_from_ast
|
||||
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.dtype import dtypes, ImageDType, DType
|
||||
from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up
|
||||
@@ -452,7 +452,7 @@ class Kernel:
|
||||
assert not self.dont_use_locals, "already not using locals"
|
||||
self.dont_use_locals = True
|
||||
elif opt.op == OptOps.PADTO:
|
||||
assert not vars_from_ast(self.ast), "does not work with symbolic shape"
|
||||
assert not self.ast.vars(), "does not work with symbolic shape"
|
||||
assert axis < self.first_reduce, "cannot pad a reduce axis"
|
||||
padded = False
|
||||
for i,st in enumerate(self.sts):
|
||||
|
||||
@@ -7,7 +7,7 @@ from dataclasses import dataclass
|
||||
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name, flatten
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast, get_lazyop_info
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
|
||||
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
||||
@@ -201,7 +201,7 @@ class Linearizer(Kernel):
|
||||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}")
|
||||
# add var vals
|
||||
for var in vars_from_ast(self.ast):
|
||||
for var in self.ast.vars():
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
|
||||
# define local buffers
|
||||
|
||||
@@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, time, re, ctypes
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer, sint
|
||||
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast, GlobalCounters
|
||||
from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
@@ -248,7 +248,7 @@ class CompiledASTRunner(JITRunner):
|
||||
if ast:
|
||||
info = get_lazyop_info(ast)
|
||||
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
|
||||
self.vars = vars_from_ast(ast)
|
||||
self.vars = ast.vars()
|
||||
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
|
||||
|
||||
def build(self, compiler, runtime):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, random, math, time, multiprocessing, traceback, signal
|
||||
from tinygrad.device import Device, Compiled, Buffer
|
||||
from tinygrad.ops import MemBuffer, vars_from_ast
|
||||
from tinygrad.ops import MemBuffer
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
@@ -116,7 +116,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
|
||||
pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None
|
||||
|
||||
try:
|
||||
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[Device.DEFAULT]
|
||||
assert isinstance(dev, Compiled)
|
||||
@@ -165,7 +165,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
|
||||
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501
|
||||
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
||||
|
||||
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
||||
lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin)
|
||||
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
|
||||
|
||||
|
||||
@@ -4,8 +4,7 @@ import numpy as np
|
||||
from typing import Union, Optional, Any, Tuple, List, Set, Dict
|
||||
from tinygrad.dtype import dtypes, DType, ImageDType
|
||||
from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps
|
||||
from tinygrad.ops import Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem, vars_from_ast
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
@@ -211,7 +210,7 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
|
||||
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify()))
|
||||
|
||||
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + \
|
||||
[ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in vars_from_ast(op)})]
|
||||
[ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
|
||||
|
||||
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer]):
|
||||
|
||||
@@ -65,9 +65,8 @@ class LazyOp:
|
||||
def __hash__(self): return self.hash
|
||||
@functools.cached_property
|
||||
def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
|
||||
|
||||
def vars_from_ast(ast:LazyOp) -> List[Variable]:
|
||||
return sorted(set.union(*[x.arg.st.vars() for x in ast.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr))
|
||||
def vars(self) -> List[Variable]:
|
||||
return sorted(set.union(*[x.arg.st.vars() for x in self.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr))
|
||||
|
||||
# **************** independent FlopCounter ****************
|
||||
|
||||
|
||||
Reference in New Issue
Block a user