mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
rangeify uses symbolic_flat (#12786)
* symbolic_simple -> symbolic_flat * remove expected failures
This commit is contained in:
@@ -24,7 +24,6 @@ class TestUnaryOpsConstFolding(unittest.TestCase):
|
||||
_check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
|
||||
_check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
|
||||
|
||||
@unittest.expectedFailure # no two level fold
|
||||
def test_neg_folding(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
|
||||
_check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))
|
||||
@@ -83,10 +82,8 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
||||
def test_div_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
|
||||
|
||||
@unittest.expectedFailure # TODO: fix
|
||||
def test_idiv_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
|
||||
@unittest.expectedFailure # TODO: fix
|
||||
def test_idiv_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.uop.symbolic import symbolic_flat
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, REAL_SUBSTITUTE
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||
from tinygrad.codegen.opt import Opt
|
||||
@@ -501,7 +501,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
tsink, rctx = run_rangeify(tsink, getenv("DEBUG_RANGEIFY", 0))
|
||||
|
||||
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
|
||||
tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
|
||||
# TODO: can you substitute and remove costly buffers at the same time?
|
||||
tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")
|
||||
|
||||
Reference in New Issue
Block a user