From 58d3d5030bf6f45671c76dfa96cc009338d21bb7 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 1 Jan 2024 18:12:38 -0500 Subject: [PATCH] vars_from_ast -> LazyOp.vars (#2965) --- examples/handcode_resnet50_opt.py | 4 ++-- test/external/fuzz_linearizer.py | 5 ++--- tinygrad/codegen/kernel.py | 4 ++-- tinygrad/codegen/linearizer.py | 4 ++-- tinygrad/device.py | 4 ++-- tinygrad/features/search.py | 6 +++--- tinygrad/lazy.py | 5 ++--- tinygrad/ops.py | 5 ++--- 8 files changed, 17 insertions(+), 20 deletions(-) diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py index 31c58bafd1..1b4c4339d7 100644 --- a/examples/handcode_resnet50_opt.py +++ b/examples/handcode_resnet50_opt.py @@ -1,7 +1,7 @@ from typing import List from extra.models.resnet import ResNet50 from tinygrad.tensor import Tensor -from tinygrad.ops import LoadOps, vars_from_ast +from tinygrad.ops import LoadOps from tinygrad.device import Device, Compiled from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin @@ -57,7 +57,7 @@ if __name__ == "__main__": choices = [] for lin in lins: tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10) - gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm + gflops = sym_infer(lin.info.flops, {k:k.min for k in lin.ast.vars()})*1e-9/tm choices.append((tm, gflops, lin.linearize())) # print all kernels diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 0efb4b78ff..f7933fbbbe 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -8,7 +8,6 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin from tinygrad.graph import print_tree from tinygrad.helpers import getenv from tinygrad.device import Device, Compiled, Interpreted -from tinygrad.ops import vars_from_ast from tinygrad.codegen.linearizer import UOp def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops]) @@ -17,7 +16,7 @@ device = Device[Device.DEFAULT] def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None): if rawbufs is None: rawbufs = bufs_from_lin(lin) - if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)} + if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()} # TODO: images needs required_optimization try: @@ -66,7 +65,7 @@ def fuzz_linearizer(lin: Linearizer): print(lin.colored_shape()) # get a new output buffer rawbufs[0] = type(rawbufs[0])(Device.DEFAULT, rawbufs[0].size, rawbufs[0].dtype) - var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)} + var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()} if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS": print(f"{lin.applied_opts=}") return msg diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 2b9f2eefb6..5782ce15e6 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -1,7 +1,7 @@ from __future__ import annotations import os, math, itertools from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union -from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, vars_from_ast +from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps from tinygrad.device import Device, Compiled from tinygrad.dtype import dtypes, ImageDType, DType from tinygrad.helpers import dedup, colored, ansilen, getenv, prod, DEBUG, round_up @@ -452,7 +452,7 @@ class Kernel: assert not self.dont_use_locals, "already not using locals" self.dont_use_locals = True elif opt.op == OptOps.PADTO: - assert not vars_from_ast(self.ast), "does not work with symbolic shape" + assert not self.ast.vars(), "does not work with symbolic shape" assert axis < self.first_reduce, "cannot pad a reduce axis" padded = False for i,st in enumerate(self.sts): diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index f250105acc..08d1c8cc01 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.helpers import colored, DEBUG, prod, getenv, all_same, to_function_name, flatten -from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast, get_lazyop_info +from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode from tinygrad.codegen.kernel import LocalBuffer, Kernel @@ -201,7 +201,7 @@ class Linearizer(Kernel): if isinstance(buf, MemBuffer): self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}") # add var vals - for var in vars_from_ast(self.ast): + for var in self.ast.vars(): assert var.expr is not None self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr) # define local buffers diff --git a/tinygrad/device.py b/tinygrad/device.py index f10b26c7e5..8130d765e0 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, time, re, ctypes from tinygrad.dtype import DType, dtypes, ImageDType from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put from tinygrad.shape.symbolic import Variable, sym_infer, sint -from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, vars_from_ast, GlobalCounters +from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters if TYPE_CHECKING: from tinygrad.codegen.linearizer import Linearizer @@ -248,7 +248,7 @@ class CompiledASTRunner(JITRunner): if ast: info = get_lazyop_info(ast) self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - self.vars = vars_from_ast(ast) + self.vars = ast.vars() assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}" def build(self, compiler, runtime): diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 2a27cddf08..5eaf7b064d 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -1,7 +1,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable import itertools, random, math, time, multiprocessing, traceback, signal from tinygrad.device import Device, Compiled, Buffer -from tinygrad.ops import MemBuffer, vars_from_ast +from tinygrad.ops import MemBuffer from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name from tinygrad.dtype import ImageDType from tinygrad.codegen.linearizer import Linearizer @@ -116,7 +116,7 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea pool = multiprocessing.Pool(multiprocessing.cpu_count(), init_worker) if getenv("PARALLEL", default_parallel) else None try: - var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)} + var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()} exiting, st = False, time.perf_counter() dev = Device[Device.DEFAULT] assert isinstance(dev, Compiled) @@ -165,7 +165,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT} # noqa: E501 if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) - var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)} + var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()} lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin) tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501 diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index b0b0a23cf2..7d52fa8f24 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -4,8 +4,7 @@ import numpy as np from typing import Union, Optional, Any, Tuple, List, Set, Dict from tinygrad.dtype import dtypes, DType, ImageDType from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same -from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps -from tinygrad.ops import Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem, vars_from_ast +from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem from tinygrad.shape.symbolic import sint, Variable from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer @@ -211,7 +210,7 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify())) return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + \ - [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in vars_from_ast(op)})] + [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})] # recursively search the entire graph for all LazyBuffers, insert realizes after expands def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer]): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e190b346b7..04c348333b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -65,9 +65,8 @@ class LazyOp: def __hash__(self): return self.hash @functools.cached_property def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops]) - -def vars_from_ast(ast:LazyOp) -> List[Variable]: - return sorted(set.union(*[x.arg.st.vars() for x in ast.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr)) + def vars(self) -> List[Variable]: + return sorted(set.union(*[x.arg.st.vars() for x in self.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr)) # **************** independent FlopCounter ****************