mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
rangeify: fix contiguous multi (#12278)
* rangeify: fix contiguous multi * when it's changing root, it should construct a new UOp
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -529,7 +529,9 @@ jobs:
|
||||
- name: Test const folding
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 test/test_const_folding.py -k "not test_cast_padded and not TestReduceOpsConstFolding and not TestMultiConstFolding"
|
||||
- name: Test multitensor
|
||||
run: CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
|
||||
run: |
|
||||
CPU=1 RANGEIFY=1 python3 test/test_multitensor.py TestMultiTensor.test_matmul_shard_1_1 TestMultiTensor.test_simple_add_W
|
||||
CPU=1 RANGEIFY=1 python3 -m pytest test/test_multitensor.py::TestMultiAssign -k 'not (multi_assign_piece_noncontig or multi_assign_var_offset)'
|
||||
- name: Test CPU=1 RANGEIFY=2
|
||||
run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
|
||||
# slow (and still wrong on beautiful_mnist)
|
||||
|
||||
@@ -211,7 +211,7 @@ def assign_multi(dest:UOp, src:UOp):
|
||||
return dest.src[0].assign(src.src[0]).multi(src.axis)
|
||||
|
||||
def passthrough_multi(root:UOp, multi:UOp):
|
||||
return root.replace(src=(multi.src[0],)).multi(multi.axis)
|
||||
return UOp(root.op, root.dtype, (multi.src[0],), root.arg).multi(multi.axis)
|
||||
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
|
||||
@@ -582,7 +582,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
# if it's not tagged by here, it's out
|
||||
tsink = UOp.sink(*[x for x in tsink.parents if (x.op is Ops.BUFFERIZE or x.base.op in {Ops.CONST}) and x.tag is not None])
|
||||
tsink = UOp.sink(*[x for x in tsink.parents if x.base.op in {Ops.BUFFERIZE, Ops.CONST} and x.tag is not None])
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user