mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 17:05:35 +08:00
view substitute [pr] (#10360)
This commit is contained in:
@@ -359,6 +359,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def __int__(self): return self._eval(dtypes.ints, int)
|
||||
def __float__(self): return self._eval(dtypes.floats, float)
|
||||
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None):
|
||||
dvars = {k:v for k,v in dvars.items() if k is not v}
|
||||
if len(dvars) == 0: return self
|
||||
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
|
||||
return graph_rewrite(self, _substitute, dvars, bottom_up=True, name=name)
|
||||
|
||||
@@ -106,6 +106,7 @@ class ShapeTracker:
|
||||
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
||||
if all(len(x) == 0 for x in var_vals): return self, {}
|
||||
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
||||
def substitute(self, dvars:dict[UOp, UOp]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))
|
||||
|
||||
def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]:
|
||||
with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
|
||||
|
||||
@@ -141,12 +141,15 @@ class View:
|
||||
def unbind(self) -> tuple[View, dict[Variable, int]]:
|
||||
var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND]
|
||||
unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
|
||||
def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
|
||||
new_shape = tuple(map(substitute, self.shape))
|
||||
new_strides = tuple(map(substitute, self.strides))
|
||||
new_offset = substitute(self.offset)
|
||||
new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
|
||||
return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
|
||||
return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val)
|
||||
|
||||
def substitute(self, dvars:dict[UOp, UOp]):
|
||||
def _substitute(x:sint): return x if isinstance(x, int) else x.substitute(dvars)
|
||||
new_shape = tuple(map(_substitute, self.shape))
|
||||
new_strides = tuple(map(_substitute, self.strides))
|
||||
new_offset = _substitute(self.offset)
|
||||
new_mask = tuple((_substitute(x[0]), _substitute(x[1])) for x in self.mask) if self.mask is not None else None
|
||||
return View.create(new_shape, new_strides, new_offset, new_mask)
|
||||
|
||||
@functools.cache # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, vm1:View) -> Optional[View]:
|
||||
|
||||
Reference in New Issue
Block a user