mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 09:33:03 +08:00
add cache on reshape (#13466)
* remove cache on divmod, way less objects
* _apply_reshape
* reshape
* no gc on realize
* wow that cache is fast
This commit is contained in:
3
test/external/external_uop_gc.py
vendored
3
test/external/external_uop_gc.py
vendored
@@ -1,7 +1,7 @@
|
||||
import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.engine.realize import method_cache, get_program
|
||||
from tinygrad.schedule.indexing import apply_movement_op
|
||||
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
from test.test_tiny import TestTiny
|
||||
|
||||
@@ -70,6 +70,7 @@ if __name__ == "__main__":
|
||||
# these caches will keep uops alive
|
||||
method_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
_apply_reshape.cache_clear()
|
||||
fold_divmod_general.cache_clear()
|
||||
Tensor._device_seeds.clear()
|
||||
Tensor._device_rng_counters.clear()
|
||||
|
||||
@@ -3,7 +3,7 @@ import time, pprint, random, itertools, math
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
|
||||
from tinygrad.helpers import unwrap, disable_gc
|
||||
from tinygrad.helpers import unwrap
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
@@ -13,7 +13,6 @@ from tinygrad.codegen.opt import Opt
|
||||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
@disable_gc()
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
|
||||
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
|
||||
"""
|
||||
|
||||
@@ -2,10 +2,10 @@ import time
|
||||
from typing import cast
|
||||
from dataclasses import dataclass, field, replace
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, disable_gc
|
||||
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten
|
||||
|
||||
# **** ScheduleItem return type
|
||||
|
||||
@@ -113,7 +113,6 @@ from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.multi import get_multi_map
|
||||
|
||||
@disable_gc()
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
@@ -139,5 +138,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
# remove all AFTERs, after scheduling, the tensors are just buffers
|
||||
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
|
||||
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
|
||||
return tensor_map, schedule, var_vals
|
||||
|
||||
@@ -119,6 +119,21 @@ pm_apply_rangeify = PatternMatcher([
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
|
||||
])
|
||||
|
||||
@functools.cache
|
||||
def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:UOp) -> UOp:
|
||||
acc = 1
|
||||
axes_in:list[UOp] = []
|
||||
for s,src in list(zip(out_shape, urngs.src))[::-1]:
|
||||
axes_in.append(acc*src)
|
||||
acc *= s
|
||||
combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
|
||||
axes_out:list[UOp] = []
|
||||
for s in in_shape[::-1]:
|
||||
axes_out.append(combined_axes % s)
|
||||
combined_axes //= s
|
||||
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
|
||||
return graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape")
|
||||
|
||||
# this is the definition of the movement ops
|
||||
@functools.cache
|
||||
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
|
||||
@@ -134,18 +149,9 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))),
|
||||
symbolic+pm_simplify_valid, name="pad").where(r-s, UOp.invalid()) for r,sh,(s,e) in zip(rngs, in_shape, arg))
|
||||
case Ops.RESHAPE:
|
||||
acc = 1
|
||||
axes_in:list[UOp] = []
|
||||
for s,src in list(zip(arg, rngs))[::-1]:
|
||||
axes_in.append(acc*src)
|
||||
acc *= s
|
||||
combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
|
||||
axes_out:list[UOp] = []
|
||||
for s in in_shape[::-1]:
|
||||
axes_out.append(combined_axes % s)
|
||||
combined_axes //= s
|
||||
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
|
||||
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape").src
|
||||
sink = UOp.sink(*rngs)
|
||||
sub_array = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}
|
||||
rngs = _apply_reshape(in_shape, arg, sink.substitute(sub_array)).substitute({v:k for k,v in sub_array.items()}).src
|
||||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
return rngs
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
|
||||
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
|
||||
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
|
||||
from tinygrad.helpers import suppress_finalizing
|
||||
from tinygrad.helpers import suppress_finalizing, disable_gc
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.mixin import OpMixin
|
||||
from tinygrad.mixin.movement import _align_left
|
||||
@@ -241,6 +241,7 @@ class Tensor(OpMixin):
|
||||
assert len(var_vals) == 0
|
||||
return schedule
|
||||
|
||||
@disable_gc()
|
||||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
class AxisType(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto(); OUTER = auto() # noqa: E702
|
||||
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
|
||||
Reference in New Issue
Block a user