From 7703dbef99876ce9906f7ad50626d3221f8a91c5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 16 May 2025 15:08:24 -0700 Subject: [PATCH] view substitute [pr] (#10360) --- tinygrad/ops.py | 1 + tinygrad/shape/shapetracker.py | 1 + tinygrad/shape/view.py | 15 +++++++++------ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cb6dfd95de..d7fdb84a3f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -359,6 +359,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def __int__(self): return self._eval(dtypes.ints, int) def __float__(self): return self._eval(dtypes.floats, float) def substitute(self, dvars:dict[UOp, UOp], name:str|None=None): + dvars = {k:v for k,v in dvars.items() if k is not v} if len(dvars) == 0: return self with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)): return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 97538ddf15..4dd8b8ce3c 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -106,6 +106,7 @@ class ShapeTracker: unbound_views, var_vals = zip(*[v.unbind() for v in self.views]) if all(len(x) == 0 for x in var_vals): return self, {} return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals) + def substitute(self, dvars:dict[UOp, UOp]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views)) def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index c93977410b..a7bd6710ec 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -141,12 +141,15 @@ class View: def unbind(self) -> tuple[View, dict[Variable, int]]: var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND] unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val} - def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars) - new_shape = tuple(map(substitute, self.shape)) - new_strides = tuple(map(substitute, self.strides)) - new_offset = substitute(self.offset) - new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None - return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val) + return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val) + + def substitute(self, dvars:dict[UOp, UOp]): + def _substitute(x:sint): return x if isinstance(x, int) else x.substitute(dvars) + new_shape = tuple(map(_substitute, self.shape)) + new_strides = tuple(map(_substitute, self.strides)) + new_offset = _substitute(self.offset) + new_mask = tuple((_substitute(x[0]), _substitute(x[1])) for x in self.mask) if self.mask is not None else None + return View.create(new_shape, new_strides, new_offset, new_mask) @functools.cache # pylint: disable=method-cache-max-size-none def __add__(self, vm1:View) -> Optional[View]: