add cache on reshape (#13466)

* remove cache on divmod, way less objects

* _apply_reshape

* reshape

* no gc on realize

* wow that cache is fast
This commit is contained in:
George Hotz
2025-11-26 18:57:40 -08:00
committed by GitHub
parent f4123b66df
commit 05cd2279d0
6 changed files with 28 additions and 21 deletions

View File

@@ -1,7 +1,7 @@
import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general
from test.test_tiny import TestTiny
@@ -70,6 +70,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
method_cache.clear()
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()
fold_divmod_general.cache_clear()
Tensor._device_seeds.clear()
Tensor._device_rng_counters.clear()

View File

@@ -3,7 +3,7 @@ import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
from tinygrad.helpers import unwrap, disable_gc
from tinygrad.helpers import unwrap
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -13,7 +13,6 @@ from tinygrad.codegen.opt import Opt
# **************** Program Creation ****************
@disable_gc()
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
"""

View File

@@ -2,10 +2,10 @@ import time
from typing import cast
from dataclasses import dataclass, field, replace
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, disable_gc
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten
# **** ScheduleItem return type
@@ -113,7 +113,6 @@ from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
@disable_gc()
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
@@ -139,5 +138,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# remove all AFTERs, after scheduling, the tensors are just buffers
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
return tensor_map, schedule, var_vals

View File

@@ -119,6 +119,21 @@ pm_apply_rangeify = PatternMatcher([
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
])
@functools.cache
def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:UOp) -> UOp:
  """
  Translate the per-axis indices of a reshaped view back into per-axis indices of its input.

  The entries of `urngs.src` are paired with `out_shape` and folded into one linear index (the last axis has
  stride 1, each earlier axis is scaled by the product of the sizes after it). That linear index is then split
  along `in_shape` with `%` and `//`, again starting from the last axis.

  Returns the simplified sink UOp; its sources are the index expressions in `in_shape` axis order.
  Results are memoized by `functools.cache`, keyed on the argument tuple.
  """
  # flatten: visit the output axes last-to-first, scaling each index by the running product of sizes
  stride = 1
  terms:list[UOp] = []
  for size,idx in reversed(list(zip(out_shape, urngs.src))):
    terms.append(stride*idx)
    stride *= size
  linear = sum(terms, start=UOp.const(dtypes.index, 0))
  # unflatten: peel one input axis off the linear index at a time, last axis first
  in_idxs:list[UOp] = []
  for size in reversed(in_shape):
    in_idxs.append(linear % size)
    linear //= size
  # the simplification below carries most of the weight: it stands in for the old reshape view merging code
  return graph_rewrite(UOp.sink(*reversed(in_idxs)), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape")
# this is the definition of the movement ops
@functools.cache
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
@@ -134,18 +149,9 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))),
symbolic+pm_simplify_valid, name="pad").where(r-s, UOp.invalid()) for r,sh,(s,e) in zip(rngs, in_shape, arg))
case Ops.RESHAPE:
acc = 1
axes_in:list[UOp] = []
for s,src in list(zip(arg, rngs))[::-1]:
axes_in.append(acc*src)
acc *= s
combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
axes_out:list[UOp] = []
for s in in_shape[::-1]:
axes_out.append(combined_axes % s)
combined_axes //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape").src
sink = UOp.sink(*rngs)
sub_array = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}
rngs = _apply_reshape(in_shape, arg, sink.substitute(sub_array)).substitute({v:k for k,v in sub_array.items()}).src
case _: raise RuntimeError(f"{op} is not a MovementOp")
return rngs

View File

@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
from tinygrad.mixin.movement import _align_left
@@ -241,6 +241,7 @@ class Tensor(OpMixin):
assert len(var_vals) == 0
return schedule
@disable_gc()
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto(); OUTER = auto() # noqa: E702
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",