Files
tinygrad/test/external/external_test_schedule_scaling.py
nimlgen a5e9ea7a60 remove schedule batch 4 (#15927)
* remove schedule batch 4

* fini
2026-04-25 12:36:55 +03:00

135 lines
4.8 KiB
Python

import unittest, time
from tinygrad import Tensor
class TestScheduleScaling(unittest.TestCase):
"""Test that .schedule() scales linearly with graph size (no O(n^2) behavior)."""
def _assert_linear(self, fn, n_small=200, n_large=1000):
"""Assert schedule time scales at most ~linearly: time(n_large)/time(n_small) should be close to n_large/n_small."""
fn(n_small).schedule_linear() # warmup
t_small = min(self._time_schedule(fn, n) for n in [n_small]*3)
t_large = min(self._time_schedule(fn, n) for n in [n_large]*3)
size_ratio = n_large / n_small # 5.0
time_ratio = t_large / t_small
# O(n) -> time_ratio ~ 5, O(n^2) -> time_ratio ~ 25. threshold at 10 catches n^2 with margin.
self.assertLess(time_ratio / size_ratio, 2.0,
f"schedule appears superlinear: n={n_small} {t_small*1e3:.1f}ms, n={n_large} {t_large*1e3:.1f}ms "
f"(time grew {time_ratio:.1f}x for {size_ratio:.0f}x size, per-node ratio {time_ratio/size_ratio:.2f})")
@staticmethod
def _time_schedule(fn, n) -> float:
st = time.perf_counter()
fn(n).schedule_linear()
return time.perf_counter() - st
# *** rangeify: ending_ranges accumulation and consumer merge ***
# ending_ranges accumulation via sum([], []) and nested scan in run_rangeify.
# this creates reduce ops whose ending_ranges lists grow with graph depth, causing O(n^2) list copies.
def test_multi_reduce_scaling(self):
def multi_reduce(n):
x = Tensor.empty(256, 256)
for _ in range(n):
s = x.sum(axis=-1, keepdim=True)
x = x + s + s
return x
self._assert_linear(multi_reduce)
# reduce+elementwise chain stresses ending_ranges propagation and post-rangeify rewrites
def test_wide_reduce_scaling(self):
def wide_reduce(n):
x = Tensor.empty(256, 256)
for _ in range(n):
x = x + x.sum(axis=-1, keepdim=True)
return x
self._assert_linear(wide_reduce)
# expand ops inject into ending_ranges via the EXPAND path in run_rangeify
def test_expand_reduce_scaling(self):
def expand_reduce(n):
x = Tensor.empty(256, 1)
for _ in range(n):
y = x.expand(256, 256)
x = (y + y).sum(axis=-1, keepdim=True)
return x
self._assert_linear(expand_reduce)
# *** graph_rewrite: multi-consumer DAG patterns ***
# multi-consumer diamond pattern (fan-out/fan-in) stresses consumer_rngs merge in run_rangeify
def test_diamond_scaling(self):
def diamond(n):
x = Tensor.empty(256, 256)
for _ in range(n):
a = x + 1
b = x + 2
x = a + b
return x
self._assert_linear(diamond)
# elementwise chain baseline — should be trivially O(n)
def test_chain_scaling(self):
def chain(n):
x = Tensor.empty(256, 256)
for _ in range(n): x = x + 1
return x
self._assert_linear(chain)
# softmax has multi-consumer structure (x used for max, exp, and sum), stresses graph_rewrite on DAGs
def test_softmax_scaling(self):
def softmax_chain(n):
x = Tensor.empty(64, 256)
for _ in range(n): x = x.softmax(axis=-1)
return x
self._assert_linear(softmax_chain)
# *** post-rangeify: symbolic rewrites, kernel splitting ***
# matmul chain stresses symbolic+reduce_collapse and split_store
def test_matmul_scaling(self):
def matmul_chain(n):
xs = [Tensor.empty(32, 32) for _ in range(n + 1)]
result = xs[0]
for i in range(n): result = result @ xs[i + 1]
return result
self._assert_linear(matmul_chain)
# contiguous chain stresses remove_bufferize callbacks (toposort per BUFFERIZE node)
def test_contiguous_scaling(self):
def contiguous_chain(n):
x = Tensor.empty(256, 256)
for _ in range(n): x = (x + 1).contiguous()
return x
self._assert_linear(contiguous_chain)
# *** schedule: AFTER handling, assign ***
# assign chain stresses AFTER cycle detection (toposort inside toposort loop in get_rangeify_map)
def test_assign_scaling(self):
def assign_chain(n):
x = Tensor.empty(256, 256).realize()
for _ in range(n): x.assign(x + 1)
return x
self._assert_linear(assign_chain)
# layernorm has multi-consumer reduces (mean reused in variance), stresses consumer_rngs merge and symbolic rewrites
def test_layernorm_scaling(self):
def layernorm_chain(n):
x = Tensor.empty(64, 256)
for _ in range(n):
mean = x.mean(axis=-1, keepdim=True)
var = ((x - mean) ** 2).mean(axis=-1, keepdim=True)
x = (x - mean) / (var + 1e-5).sqrt()
return x
self._assert_linear(layernorm_chain)
# concat chain stresses MSTACK/MSELECT handling and wide SINK construction
def test_concat_scaling(self):
def concat_chain(n):
parts = [Tensor.empty(4, 256) + i for i in range(n)]
return parts[0].cat(*parts[1:])
self._assert_linear(concat_chain)
if __name__ == '__main__':
unittest.main(verbosity=2)