update some tests for less Kernel (#11543)

* update some tests for less Kernel

* get_program update
This commit is contained in:
George Hotz
2025-08-06 14:19:59 -07:00
committed by GitHub
parent 09dc7af8e9
commit 6fd1332763
5 changed files with 28 additions and 151 deletions

View File

@@ -234,10 +234,7 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
k = Kernel(r.schedule()[-1].ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=0))
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
stores = [u for u in uops if u.op is Ops.STORE]
assert len(accs) == 0 # it's removed now
@@ -249,9 +246,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
def test_upcast_with_locals_cpu(self):
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
k = Kernel(out.schedule()[-1].ast)
k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4))
prg = get_program(k.get_optimized_ast(), k.opts)
prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)]).uops
self.assertEqual(len(prg.src.split("for")), 5)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -261,10 +256,8 @@ class TestLinearizer(unittest.TestCase):
def test_upcast_with_locals(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
realized_ast = r.schedule()[-1].ast
opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)
stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
@@ -278,10 +271,7 @@ class TestLinearizer(unittest.TestCase):
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)
k = Kernel(r.schedule()[-1].ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=0))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
assert num_ops == 0, "more alu uops than needed"
@@ -291,16 +281,14 @@ class TestLinearizer(unittest.TestCase):
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
realized_ast = a.schedule()[-1].ast
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple()))
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
program = get_program(realized_ast, opts=[])
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
assert local[0].dtype.base == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
realized_ast = c.schedule()[-1].ast
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple()))
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
program = get_program(realized_ast, opts=[])
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
self.assertEqual(local[0].dtype.base, expected_dtype)
@@ -758,11 +746,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=2))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
assert TestFloat4.count_float4(uops) == (4, 2)
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
@@ -773,10 +757,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=shift))
return get_program(k.get_optimized_ast(), k.opts).uops
return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops
sizes = [12, 8, 16]
shifts = [3, 2, 4]
@@ -806,10 +787,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=2))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
assert TestFloat4.count_float4(uops) == (0, 2)
@@ -842,9 +820,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
assert TestFloat4.count_float4(uops) == (0, 0)
@@ -858,10 +834,7 @@ class TestFloat4(unittest.TestCase):
# UPDATE: now we do this fusion
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=0))
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}
@@ -874,9 +847,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
assert TestFloat4.count_float4(uops) == (0, 1)
@@ -888,9 +859,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = c.schedule()[0]
k = Kernel(s.ast)
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
uops = get_program(k.get_optimized_ast(), k.opts).uops
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
assert TestFloat4.count_float4(uops) == (1, 1)

View File

@@ -6,7 +6,7 @@ import unittest
import numpy as np
import functools
from typing import cast
from hypothesis import assume, given, strategies as strat
from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
@@ -16,7 +16,7 @@ from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewr
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
class KernelCountException(Exception): pass
@@ -151,6 +151,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(root.item(), sum(range(N)))
@given(strat.sampled_from(range(2,4)), strat.sampled_from(range(2,4)), strat.sampled_from(range(0,4)), strat.sampled_from(range(0,4)))
@settings(deadline=None)
def test_indexing_scalars(self, x, y, a, b):
assume(a<x and b<y)
X = Tensor.randn(x, y).realize()
@@ -1353,8 +1354,7 @@ class TestSchedule(unittest.TestCase):
r = a.sum(0) + 6
b = r.sum(0) * 4
c = r.sum(1) * 2
schedule = check_schedule([b, c], 3)
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
check_schedule([b, c], 3)
def test_multireduce_simple_chase(self):
Tensor.manual_seed(0)
@@ -1376,8 +1376,7 @@ class TestSchedule(unittest.TestCase):
r = a.sum(2) + b
d = r.T * 4
e = r * d
schedule = check_schedule([d, e], 3)
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
check_schedule([d, e], 3)
def test_multireduce_push_permute_chase(self):
Tensor.manual_seed(0)
@@ -1387,7 +1386,6 @@ class TestSchedule(unittest.TestCase):
d = r.T * 4
e = r * (d + a).sum(2)
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
run_schedule(schedule)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
@@ -1398,8 +1396,7 @@ class TestSchedule(unittest.TestCase):
c = Tensor.empty(16, )
r = a.sum(1) + c
d = r[:4] * b
schedule = check_schedule(d, 2)
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
check_schedule(d, 2)
def test_multireduce_push_shrink_chase(self):
Tensor.manual_seed(0)
@@ -1411,15 +1408,13 @@ class TestSchedule(unittest.TestCase):
out = r[:4] * b + d.sum(1)[:4]
# schedule = check_schedule(out, 2)
schedule = check_schedule(out, 3)
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
run_schedule(schedule)
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
def test_midreduce_nochase(self):
a = Tensor.empty(16, 16)
b = (a.sum(0) + a.max(1)) + 2
schedule = check_schedule(b, 2)
self.assertIs(store_val(schedule[0]).op, Ops.REDUCE_AXIS)
check_schedule(b, 2)
def test_multireduce_midreduce_nochase(self):
Tensor.manual_seed(0)
@@ -1427,7 +1422,6 @@ class TestSchedule(unittest.TestCase):
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
# schedule = check_schedule(b, 2)
schedule = check_schedule(b, 4)
self.assertIs(store_val(schedule[0]).op, Ops.REDUCE_AXIS)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
@@ -1902,8 +1896,6 @@ class TestIndexing(unittest.TestCase):
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).contiguous().to("CPU")
sched = self.check_schedule(a, 2) # NOTE: there is a contiguous between REDUCE_AXIS and COPY
self.assertIs(sched[2].ast.op, Ops.COPY)
self.assertIs(store_val(sched[1]).op, Ops.LOAD)
self.assertIs(store_val(sched[0]).op, Ops.ADD)
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@@ -1987,24 +1979,6 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
def test_no_rewrite_elementwise(self):
a = Tensor.empty(32, 32)
b = Tensor.empty(32, 32)
sink = (a+b).schedule()[0].ast
self.assertEqual(swizzle_cnt(sink), 0)
def test_simple_store_reshape(self):
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.uop.shape, (1, 32))
def test_no_reshape_reduceop(self):
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.uop.shape, (32,))
def swizzle_cnt(u:UOp) -> int:
return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}])
@@ -2108,7 +2082,6 @@ class TestSwizzle(unittest.TestCase):
run_schedule(check_schedule(t, 3))
np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]])
def store_val(si:ScheduleItem): return si.ast.src[0].src[1]
zero_pm = UPat(Ops.CONST, arg=0)
class TestView(unittest.TestCase):
def test_all_masked_out(self):
@@ -2117,7 +2090,6 @@ class TestView(unittest.TestCase):
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
sched = check_schedule(b.contiguous(), 1)
assert zero_pm.match(store_val(sched[-1]), {})
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
@@ -2128,7 +2100,6 @@ class TestView(unittest.TestCase):
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
assert zero_pm.match(store_val(sched[-1]), {})
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
@@ -2143,8 +2114,6 @@ class TestView(unittest.TestCase):
b = a.pad(((0, 5), None))[5:]
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
self.assertEqual(store_val(sched[-1]).st_arg, b.uop.st)
run_schedule(sched)
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
@@ -2260,24 +2229,6 @@ class TestConst(unittest.TestCase):
sched = a.schedule()
self.assertEqual(len(sched), 1)
def test_const_ast(self):
a = Tensor.ones((4,)).pad((1, 1)).contiguous()
sched = a.schedule()
print(sched[0].ast)
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat.where(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0))),))
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
run_schedule(sched)
self.assertListEqual(a.tolist(), [0, 1, 1, 1, 1, 0])
def test_unmasked_const_ast(self):
a = Tensor.ones((4,)).contiguous()
sched = a.schedule()
print(sched[0].ast)
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(Ops.CONST)),))
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
run_schedule(sched)
self.assertListEqual(a.tolist(), [1, 1, 1, 1])
# ** part 2: scheduler behavior when const folding happens later
def test_const_folding_no_realize(self):

View File

@@ -508,19 +508,6 @@ class TestShapeSpec(unittest.TestCase):
a = Tensor.ones((4, 4)).uop
self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))
def test_padded_const(self):
a = Tensor.ones((1, 1)).pad(((1, 1), (1, 1)))
ast = a.contiguous().schedule()[0].ast
valid_pattern = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar(), UPat.cvar()))
valid_ternary = [x for x in ast.toposort() if valid_pattern.match(x, {})][0]
# the WHERE outputs a contiguous (3, 3)
self.assertEqual(valid_ternary.st, ShapeTracker.from_shape((3, 3)))
valid, x, y = valid_ternary.src
# very notably, only the first source is padded
self.assertIsNotNone(valid.st.views[-1].mask)
assert x.st.views[-1].mask is y.st.views[-1].mask is None
assert all(s.shape == (3, 3) for s in valid_ternary.src)
# NOTE: CONST ShapeTracker comes from its source
def test_scalar_const(self):
a = Tensor(0).uop

View File

@@ -6,7 +6,7 @@ from tinygrad.renderer import Estimates
from tinygrad.codegen import full_rewrite
from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError
from tinygrad.device import Device
def flops_mem(uops, ignore_indexing=False):
@@ -178,10 +178,10 @@ class TestStatsOptimized(unittest.TestCase):
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
def test_gemm_tc_unroll(self):
k = Kernel(self.ast_gemm)
if not k.apply_tensor_cores(): self.skipTest("no tensor cores")
k.apply_opt(Opt(OptOps.UNROLL, 0, 2))
p = get_program(k.get_optimized_ast(), k.opts)
try:
p = get_program(self.ast_gemm, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
except KernelOptError:
raise unittest.SkipTest("no tensor cores")
print(p.src)
self.check_gemm(p)
@@ -217,19 +217,16 @@ class TestStatsOptimized(unittest.TestCase):
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
def test_reduce(self):
k = Kernel(self.ast_reduce)
p = get_program(k.get_optimized_ast(), k.opts)
p = get_program(self.ast_reduce, opts=[])
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
self.assertEqual(p.estimates.ops, N*N)
self.assertEqual(p.estimates.mem, N*N*4 + 4)
def test_reduce_group(self):
k = Kernel(self.ast_reduce)
try:
k.apply_opt(Opt(OptOps.GROUP, 0, 50))
p = get_program(self.ast_reduce, opts=[Opt(OptOps.GROUP, 0, 50)])
except KernelOptError:
raise unittest.SkipTest("no locals")
p = get_program(k.get_optimized_ast(), k.opts)
# NOTE: these are wrong, they don't respect the if statement
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)

View File

@@ -1,10 +1,7 @@
import unittest
import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes, Context, nn
from tinygrad.uop.ops import Ops
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.helpers import CI, Profiling, WINO, getenv
class TestWinogradClose(unittest.TestCase):
def test_close(self):
@@ -28,30 +25,6 @@ class TestWinograd(unittest.TestCase):
def tearDown(self):
WINO.value = self.old
def test_speed(self):
x = Tensor.empty(1,4,9,9)
w = Tensor.empty(4,4,3,3)
with Timing("running conv: "):
out = Tensor.conv2d(x, w)
with Timing("scheduling: "):
sched = out.schedule()
for i,s in enumerate(sched):
if s.ast.op is not Ops.SINK: continue
ops = s.ast.toposort()
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Kernel(s.ast)
l.apply_opts(hand_coded_optimizations(l))
assert len(l.sts) <= 256 # just the current value to prevent regression
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
for st in l.sts:
assert len(st.views) <= 2, "too many views in winograd"
if DEBUG >= 3:
print(f"{len(st.views):3d} views")
for v in st.views: print(v)
def test_profile(self):
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
with Profiling(enabled=not CI, sort='time'):