rangeify: fix contiguous multi (#12278)

* rangeify: fix contiguous multi

* when it's changing root, it should construct a new UOp
This commit is contained in:
qazal
2025-09-23 20:05:29 +03:00
committed by GitHub
parent 5f4eeb054c
commit 2f145a98e0
3 changed files with 5 additions and 3 deletions

View File

@@ -529,7 +529,9 @@ jobs:
- name: Test const folding
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
- name: Test multitensor
run: CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
run: |
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
CPU=1 RANGEIFY=1 python3 -m pytest test/test_multitensor.py::TestMultiAssign -k 'not (multi_assign_piece_noncontig or multi_assign_var_offset)'
- name: Test CPU=1 RANGEIFY=2
run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
# slow (and still wrong on beautiful_mnist)

View File

@@ -211,7 +211,7 @@ def assign_multi(dest:UOp, src:UOp):
return dest.src[0].assign(src.src[0]).multi(src.axis)
def passthrough_multi(root:UOp, multi:UOp):
return root.replace(src=(multi.src[0],)).multi(multi.axis)
return UOp(root.op, root.dtype, (multi.src[0],), root.arg).multi(multi.axis)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([

View File

@@ -582,7 +582,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
# if it's not tagged by here, it's out
tsink = UOp.sink(*[x for x in tsink.parents if (x.op is Ops.BUFFERIZE or x.base.op in {Ops.CONST}) and x.tag is not None])
tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.CONST} and x.tag is not None])
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")