rangeify uses symbolic_flat (#12786)

* symbolic_simple -> symbolic_flat

* remove expected failures
This commit is contained in:
Sieds Lykles
2025-10-19 12:27:14 +02:00
committed by GitHub
parent 89e7f2fa00
commit fd6ef4801c
2 changed files with 2 additions and 5 deletions

View File

@@ -24,7 +24,6 @@ class TestUnaryOpsConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
_check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
@unittest.expectedFailure # no two level fold
def test_neg_folding(self):
_check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
_check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))
@@ -83,10 +82,8 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
def test_div_tensor_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
@unittest.expectedFailure # TODO: fix
def test_idiv_literal_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
@unittest.expectedFailure # TODO: fix
def test_idiv_tensor_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.uop.symbolic import symbolic_flat
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
@@ -501,7 +501,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink, rctx = run_rangeify(tsink, getenv("DEBUG_RANGEIFY", 0))
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# TODO: can you substitute and remove costly buffers at the same time?
tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")