remove schedule batch 4 (#15927)

* remove schedule batch 4

* fini
This commit is contained in:
nimlgen
2026-04-25 12:36:55 +03:00
committed by GitHub
parent d2ab6ea7a6
commit a5e9ea7a60
8 changed files with 50 additions and 52 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Ops, UPat
from tinygrad.helpers import CI, DEBUG, OSX, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.realize import CompiledRunner, compile_linear, run_linear
from tinygrad.engine.realize import compile_linear, run_linear
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@@ -787,7 +787,7 @@ class TestSchedule(unittest.TestCase):
gc.collect()
base = GlobalCounters.mem_used
x = Tensor.ones(256).contiguous().realize()
(x+Tensor.ones(256).contiguous()).schedule()
(x+Tensor.ones(256).contiguous()).schedule_linear()
gc.collect()
self.assertEqual(GlobalCounters.mem_used-base, 1024)
@@ -797,9 +797,8 @@ class TestSchedule(unittest.TestCase):
def cnt():
x, y, z = Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float')
a = (x @ y).relu()
sched = ((a @ z).relu() + a).schedule()
for si in sched: si.lower()
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
linear = compile_linear(((a @ z).relu() + a).schedule_linear())
return len([call for call in linear.src if call.src[0].op is Ops.PROGRAM])
with Context(IMAGE=1):
self.assertEqual(cnt(), 5)
@@ -814,9 +813,8 @@ class TestSchedule(unittest.TestCase):
rb = (((((inp @ b1) + c1).relu() @ b2) + c2).relu() + inp).relu()
b16, c16 = Tensor.empty((512, 16), dtype='float'), Tensor.empty((16,), dtype='float')
b32, c32 = Tensor.empty((512, 32), dtype='float'), Tensor.empty((32,), dtype='float')
sched = Tensor.schedule((rb @ b16 + c16).relu(), (rb @ b32 + c32).relu())
for si in sched: si.lower()
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
linear = compile_linear(Tensor.schedule_linear((rb @ b16 + c16).relu(), (rb @ b32 + c32).relu()))
return len([call for call in linear.src if call.src[0].op is Ops.PROGRAM])
with Context(IMAGE=1):
self.assertEqual(cnt(), 9)
@@ -828,9 +826,8 @@ class TestSchedule(unittest.TestCase):
x, y, z = Tensor.empty((1, 4, 3, 3)), Tensor.empty((4, 1, 3, 3)), Tensor.empty((4, 1, 7, 7))
a = x.conv2d(y, Tensor.empty(4), groups=4, padding=1)
b = a.conv2d(z, groups=4, padding=3)
sched = (a + b).schedule()
for si in sched: si.lower()
return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
linear = compile_linear((a + b).schedule_linear())
return len([call for call in linear.src if call.src[0].op is Ops.PROGRAM])
with Context(IMAGE=1):
self.assertEqual(cnt(), 5)
@@ -1332,7 +1329,7 @@ class TestCopyFolding(unittest.TestCase):
b = Tensor.empty(4, device="CPU")
add = a+b
assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.src]}"
add.schedule()
add.schedule_linear()
def test_alu_before_copy(self):
buf = Tensor.ones(1).contiguous().realize()

View File

@@ -23,10 +23,10 @@ if __name__ == "__main__":
if not FORWARD_ONLY:
with Timing("***** model schedule in "):
with Profiling(PROFILE >= 3):
sched = out.schedule()
linear = out.schedule_linear()
if not SCHEDULE_ONLY:
asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
asts = list({call.src[0].key:call.src[0] for call in linear.src if call.src[0].op is Ops.SINK}.values())
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):

View File

@@ -20,8 +20,8 @@ class TestHCQ(unittest.TestCase):
#TestHCQ.d1: AMDDevice = Device["AMD:1"]
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
TestHCQ.b = self.a + 1
si = self.b.schedule()[-1]
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
linear = self.b.schedule_linear()
TestHCQ.runner = get_runner(TestHCQ.d0.device, linear.src[-1].src[0])
TestHCQ.b.uop.buffer.allocate()
# wow that's a lot of abstraction layers
TestHCQ.addr = struct.pack("QQ", TestHCQ.b.uop.buffer._buf, TestHCQ.a.uop.buffer._buf)

View File

@@ -10,8 +10,8 @@ from hypothesis import given, strategies as st
# copied from test_const_folding.py
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = t.schedule()
asts = [s for s in schedule if s.ast.op is Ops.SINK]
linear = t.schedule_linear()
asts = [call for call in linear.src if call.src[0].op is Ops.SINK]
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
def build_onnx(nodes, from_disk:bool=True, **kwargs):

View File

@@ -6,7 +6,7 @@ class TestScheduleScaling(unittest.TestCase):
def _assert_linear(self, fn, n_small=200, n_large=1000):
"""Assert schedule time scales at most ~linearly: time(n_large)/time(n_small) should be close to n_large/n_small."""
fn(n_small).schedule() # warmup
fn(n_small).schedule_linear() # warmup
t_small = min(self._time_schedule(fn, n) for n in [n_small]*3)
t_large = min(self._time_schedule(fn, n) for n in [n_large]*3)
size_ratio = n_large / n_small # 5.0
@@ -19,7 +19,7 @@ class TestScheduleScaling(unittest.TestCase):
@staticmethod
def _time_schedule(fn, n) -> float:
st = time.perf_counter()
fn(n).schedule()
fn(n).schedule_linear()
return time.perf_counter() - st
# *** rangeify: ending_ranges accumulation and consumer merge ***

View File

@@ -14,13 +14,13 @@ def print_uops():
def start(): pass
def single_tensor(): Tensor([2])
def two_plus_two(): Tensor([2])+Tensor([2])
def two_plus_two_schedule(): (Tensor([2])+Tensor([2])).schedule()
def two_plus_two_schedule(): (Tensor([2])+Tensor([2])).schedule_linear()
def two_plus_two_kernel():
si = (Tensor([2])+Tensor([2])).schedule()[-1]
get_program(si.ast, Device.default.renderer)
linear = (Tensor([2])+Tensor([2])).schedule_linear()
get_program(linear.src[-1].src[0], Device.default.renderer)
def two_plus_two_linearize():
si = (Tensor([2])+Tensor([2])).schedule()[-1]
get_program(si.ast, Device.default.renderer)
linear = (Tensor([2])+Tensor([2])).schedule_linear()
get_program(linear.src[-1].src[0], Device.default.renderer)
def two_plus_two_realize(): (Tensor([2])+Tensor([2])).realize()
def two_plus_two_item(): (Tensor([2])+Tensor([2])).item()
def gradient_test():
@@ -36,8 +36,8 @@ def kernel_matmul():
x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x)
si = z.schedule()[-1]
get_program(si.ast, Device.default.renderer)
linear = z.schedule_linear()
get_program(linear.src[-1].src[0], Device.default.renderer)
def realized_matmul():
x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0,0,-2.0]], requires_grad=True)

View File

@@ -20,8 +20,8 @@ def gen_prg(device, inputs_cnt):
s = fst[0]
for i in range(1, inputs_cnt): s = s.bitwise_xor(fst[i])
si = s.schedule()[-1]
prg = get_runner(device, si.ast)
linear = s.schedule_linear()
prg = get_runner(device, linear.src[-1].src[0])
cached_prgs[(device, inputs_cnt)] = prg
return prg

View File

@@ -4,6 +4,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, KernelInfo
from tinygrad.helpers import DEBUG, GlobalCounters, Context
from tinygrad.engine.realize import compile_linear, run_linear
from tinygrad.codegen import get_program
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@@ -141,7 +142,7 @@ class TestSimpleSchedule(unittest.TestCase):
a = Tensor.empty(16,16).sum(axis=1)
a1 = a.reshape(4,4)
a2 = a.reshape(16,1,1)
self.assertEqual(len(Tensor.schedule(a1, a2)), 1)
self.assertEqual(len(Tensor.schedule_linear(a1, a2).src), 1)
class TestSchedule(unittest.TestCase):
def test_create_schedule_handles_multi_kernel_after_and_after_deps(self):
@@ -166,8 +167,8 @@ class TestSchedule(unittest.TestCase):
kc = Tensor.custom_kernel(out, src_after, fxn=named_copy("kc"))[0]
out_after = Tensor(kc.uop.src[0].after(*kc.uop.src[1:], kd.uop))
schedule = out_after.schedule()
names = [si.ast.arg.name for si in schedule]
linear = out_after.schedule_linear()
names = [call.src[0].arg.name for call in linear.src]
self.assertEqual(set(names), {"ka", "kb", "kc", "kd"})
self.assertEqual(names[-1], "kc")
self.assertLess(names.index("ka"), names.index("kc"))
@@ -667,9 +668,9 @@ class TestSchedule(unittest.TestCase):
check_schedule(c, 2)
def _alu_from_tensor(self, t:Tensor):
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
s = [s for s in t.schedule_linear().src if s.src[0].op is Ops.SINK]
self.assertEqual(len(s), 1)
return [u.op for u in s[0].ast.toposort() if u.op in GroupOp.ALU]
return [u.op for u in s[0].src[0].toposort() if u.op in GroupOp.ALU]
def test_2_pow_is_exp2(self):
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
@@ -798,12 +799,12 @@ class TestSchedule(unittest.TestCase):
Tensor.manual_seed(0)
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
out = x.softmax(dtype=dtypes.float)
sched = out.schedule()
self.assertEqual(len(sched), 3)
linear = out.schedule_linear()
self.assertEqual(len(linear.src), 3)
# max reduction stays in input dtype (no numerical loss), upcast happens after subtracting max
self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
self.assertEqual(sched[2].bufs[0].dtype, dtypes.float)
self.assertEqual(linear.src[0].src[1].dtype, dtypes.half)
self.assertEqual(linear.src[1].src[1].dtype, dtypes.float)
self.assertEqual(linear.src[2].src[1].dtype, dtypes.float)
def test_softmax_backward(self):
Tensor.manual_seed(0)
@@ -960,7 +961,7 @@ class TestSchedule(unittest.TestCase):
gc.collect()
base = GlobalCounters.mem_used
Tensor.ones(256).contiguous().realize()
Tensor.ones(5, 5).contiguous().schedule()
Tensor.ones(5, 5).contiguous().schedule_linear()
gc.collect()
self.assertEqual(GlobalCounters.mem_used-base, 0)
@@ -1173,24 +1174,24 @@ class TestFusionOp(unittest.TestCase):
st = time.perf_counter()
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
sched = a.schedule()
sched[-1].lower()
linear = a.schedule_linear()
prg = get_program(linear.src[-1].src[0], renderer=Device[Device.DEFAULT].renderer)
self.assertLess(time.perf_counter()-st, 2.0)
assert len(sched[-1].prg.p.src.splitlines()) < 250
assert len(prg.src.splitlines()) < 250
def test_recursive_add_cmp(self):
st = time.perf_counter()
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
sched1 = a.schedule()
linear1 = a.schedule_linear()
b = Tensor([1,2,3,4])
for _ in range(24): b = b + b
sched2 = b.schedule()
linear2 = b.schedule_linear()
c = Tensor([1,2,3,4])
for _ in range(23): c = c + c
sched3 = c.schedule()
self.assertEqual(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast)
linear3 = c.schedule_linear()
self.assertEqual(linear1.src[-1].src[0], linear2.src[-1].src[0])
with self.assertRaises(AssertionError): self.assertEqual(linear1.src[-1].src[0], linear3.src[-1].src[0])
self.assertLess(time.perf_counter()-st, 2.0)
def test_recursive_pad(self):
@@ -1198,8 +1199,8 @@ class TestFusionOp(unittest.TestCase):
val = 1.0
a = Tensor(val)
for _ in range(24): a = Tensor.stack(a, a)[0]
sched = a.schedule()
self.assertLessEqual(len(sched), 1)
linear = a.schedule_linear()
self.assertLessEqual(len(linear.src), 1)
self.assertLess(time.perf_counter()-st, 2.0)
def test_recursive_reshape(self):
@@ -1208,8 +1209,8 @@ class TestFusionOp(unittest.TestCase):
b = Tensor.empty(16, 2).realize()
r = a.sum(1)
for _ in range(24): r = r.reshape(16, 2) + b
sched = r.schedule()
self.assertEqual(len(sched), 1)
linear = r.schedule_linear()
self.assertEqual(len(linear.src), 1)
self.assertLess(time.perf_counter()-st, 2.0)
# NOTE: the NULL backend supports BUFFER_VIEW