From 867004fbeb22d213d74fb72747cef7b520ccb3fe Mon Sep 17 00:00:00 2001 From: eliotgolding <177857289+eliotgolding@users.noreply.github.com> Date: Sun, 12 Jan 2025 15:25:55 +0000 Subject: [PATCH] use unravel in views_to_indexed_uops [pr] (#8560) * use unravel in shape * make process replay work * earlier View.minify() * fix * fix tests * mypy * get rid of early minify * fix * linter * clean and add test --------- Co-authored-by: chenyu --- test/unit/test_uop_symbolic.py | 11 +++++++++++ tinygrad/ops.py | 19 +++++++++---------- tinygrad/shape/shapetracker.py | 10 +++------- tinygrad/shape/view.py | 10 +++++----- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 578156737a..e03d16e7b9 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -398,6 +398,17 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)") self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1") + def test_divmod_variable_denom_fold_to_const(self): + x = Variable("x", 20, 23) + y = Variable("y", 8, 10) + self.helper_test_variable(x//y, 2, 2, "2") + self.helper_test_variable(x%y, 0, 7, "(x+(y*-2))") + # ensure all 4 corners are checked + x = Variable("x", -10, 10) + y = Variable("y", -8, 9) + self.helper_test_variable(x//y, -2147483648, 2147483647, "(x//y)") + self.helper_test_variable(x%y, -2147483648, 2147483647, "(x%y)") + # TODO: simplify the expression def test_div_neg_all_range(self): gidx = Variable("gidx", 0, 124) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 310c7a564f..e3932e56ad 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -993,14 +993,13 @@ def split_uop(x:UOp, sep:Ops): for s in x.src: yield from split_uop(s, sep) else: yield x -def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None: - # simplify x // c or x % c, None means no change, c must be > 0 - assert c > 0 - if x.dtype.count > 1: return None +def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None: + # simplify x // y or x % y, None means no change # simple cancel div/mod case - if (q:=x.vmin//c) == (x.vmax//c): - if which is Ops.MOD: return x - q*c - return x.const_like(q) + if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax: + return x - q*y if which is Ops.MOD else x.const_like(q) + + if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False for u in split_uop(x, Ops.ADD): @@ -1039,7 +1038,7 @@ def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split if gcd != 1: something_changed = True if not something_changed: - if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, div, Ops.IDIV)) is not None: return newx//(c//div) + if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div) return None quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd) for q,r,f,v in zip(quotients, remainders, factors, svars): @@ -1259,10 +1258,10 @@ symbolic = symbolic_simple+PatternMatcher([ # ** div ** # div folding ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d) - (UPat.var("x", dtypes.sints) // UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.IDIV) if 0 < c.arg else None), + (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)), # ** mod ** # mod folding - (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: div_and_mod_folding(x,c.arg,Ops.MOD) if 0 < c.arg else None), + (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)), ]) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index d86688fa0f..a2ed622816 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -4,20 +4,16 @@ from dataclasses import dataclass import functools from typing import Optional, Callable from tinygrad.helpers import merge_dicts, getenv -from tinygrad.shape.view import View, strides_for_shape +from tinygrad.shape.view import View, strides_for_shape, unravel from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid +from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid, sint_to_uop @functools.lru_cache(None) def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]: idx, valid = views[-1].to_indexed_uops(_idxs) for view in reversed(views[0:-1]): view = view.minify() - acc, idxs = 1, [] - for d in reversed(view.shape): - idxs.append((idx//acc)%d) - acc *= d - idx, valid = view.to_indexed_uops(idxs[::-1], valid) + idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid) return idx, valid @functools.lru_cache(None) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index eb24e54fc4..0ba1044ea5 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -73,11 +73,11 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple def unravel(shape:tuple[sint, ...], offset:sint) -> list[sint]: # find the position of offset on each dimension based on shape # similar to unravel_index in numpy/torch - ret = [] - for stride in strides_for_shape(shape): - ret.append(offset // stride if stride != 0 else 0) - offset -= ret[-1] * stride - return ret + acc, idxs = 1, [] + for d in reversed(shape): + idxs.append((offset//acc)%d) + acc *= d + return idxs[::-1] @dataclass(frozen=True) class View: