mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
update some tests for less Kernel (#11543)
* update some tests for less Kernel * get_program update
This commit is contained in:
@@ -234,10 +234,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
|
||||
r = Tensor.conv2d(x,w,padding=1).relu()
|
||||
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=0))
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
|
||||
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
|
||||
stores = [u for u in uops if u.op is Ops.STORE]
|
||||
assert len(accs) == 0 # it's removed now
|
||||
@@ -249,9 +246,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
|
||||
def test_upcast_with_locals_cpu(self):
|
||||
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
|
||||
k = Kernel(out.schedule()[-1].ast)
|
||||
k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4))
|
||||
prg = get_program(k.get_optimized_ast(), k.opts)
|
||||
prg = get_program(out.schedule()[-1].ast, opts=[Opt(OptOps.LOCAL, axis=0, arg=4)]).uops
|
||||
self.assertEqual(len(prg.src.split("for")), 5)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -261,10 +256,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_upcast_with_locals(self):
|
||||
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
|
||||
r = (x@y).relu()
|
||||
realized_ast = r.schedule()[-1].ast
|
||||
opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
|
||||
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
|
||||
program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)
|
||||
|
||||
stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
|
||||
|
||||
@@ -278,10 +271,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = Tensor.stack(a, b)
|
||||
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=0))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
|
||||
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
|
||||
assert num_ops == 0, "more alu uops than needed"
|
||||
|
||||
@@ -291,16 +281,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
|
||||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
realized_ast = a.schedule()[-1].ast
|
||||
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple()))
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
|
||||
program = get_program(realized_ast, opts=[])
|
||||
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
|
||||
assert local[0].dtype.base == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
realized_ast = c.schedule()[-1].ast
|
||||
realized_ast = realized_ast.replace(arg=KernelInfo(opts_to_apply=tuple()))
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer)
|
||||
program = get_program(realized_ast, opts=[])
|
||||
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
|
||||
self.assertEqual(local[0].dtype.base, expected_dtype)
|
||||
|
||||
@@ -758,11 +746,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=2))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
|
||||
assert TestFloat4.count_float4(uops) == (4, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
@@ -773,10 +757,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=shift))
|
||||
return get_program(k.get_optimized_ast(), k.opts).uops
|
||||
return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops
|
||||
|
||||
sizes = [12, 8, 16]
|
||||
shifts = [3, 2, 4]
|
||||
@@ -806,10 +787,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=2))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 2)
|
||||
|
||||
@@ -842,9 +820,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 0)
|
||||
|
||||
@@ -858,10 +834,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# UPDATE: now we do this fusion
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=0))
|
||||
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}
|
||||
|
||||
@@ -874,9 +847,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (0, 1)
|
||||
|
||||
@@ -888,9 +859,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# should float4 b but not a
|
||||
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
|
||||
|
||||
assert TestFloat4.count_float4(uops) == (1, 1)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import unittest
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import cast
|
||||
from hypothesis import assume, given, strategies as strat
|
||||
from hypothesis import assume, given, settings, strategies as strat
|
||||
|
||||
from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -16,7 +16,7 @@ from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewr
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
@@ -151,6 +151,7 @@ class TestSchedule(unittest.TestCase):
|
||||
self.assertEqual(root.item(), sum(range(N)))
|
||||
|
||||
@given(strat.sampled_from(range(2,4)), strat.sampled_from(range(2,4)), strat.sampled_from(range(0,4)), strat.sampled_from(range(0,4)))
|
||||
@settings(deadline=None)
|
||||
def test_indexing_scalars(self, x, y, a, b):
|
||||
assume(a<x and b<y)
|
||||
X = Tensor.randn(x, y).realize()
|
||||
@@ -1353,8 +1354,7 @@ class TestSchedule(unittest.TestCase):
|
||||
r = a.sum(0) + 6
|
||||
b = r.sum(0) * 4
|
||||
c = r.sum(1) * 2
|
||||
schedule = check_schedule([b, c], 3)
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
|
||||
check_schedule([b, c], 3)
|
||||
|
||||
def test_multireduce_simple_chase(self):
|
||||
Tensor.manual_seed(0)
|
||||
@@ -1376,8 +1376,7 @@ class TestSchedule(unittest.TestCase):
|
||||
r = a.sum(2) + b
|
||||
d = r.T * 4
|
||||
e = r * d
|
||||
schedule = check_schedule([d, e], 3)
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
|
||||
check_schedule([d, e], 3)
|
||||
|
||||
def test_multireduce_push_permute_chase(self):
|
||||
Tensor.manual_seed(0)
|
||||
@@ -1387,7 +1386,6 @@ class TestSchedule(unittest.TestCase):
|
||||
d = r.T * 4
|
||||
e = r * (d + a).sum(2)
|
||||
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
|
||||
@@ -1398,8 +1396,7 @@ class TestSchedule(unittest.TestCase):
|
||||
c = Tensor.empty(16, )
|
||||
r = a.sum(1) + c
|
||||
d = r[:4] * b
|
||||
schedule = check_schedule(d, 2)
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
|
||||
check_schedule(d, 2)
|
||||
|
||||
def test_multireduce_push_shrink_chase(self):
|
||||
Tensor.manual_seed(0)
|
||||
@@ -1411,15 +1408,13 @@ class TestSchedule(unittest.TestCase):
|
||||
out = r[:4] * b + d.sum(1)[:4]
|
||||
# schedule = check_schedule(out, 2)
|
||||
schedule = check_schedule(out, 3)
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.ADD)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_midreduce_nochase(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = (a.sum(0) + a.max(1)) + 2
|
||||
schedule = check_schedule(b, 2)
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.REDUCE_AXIS)
|
||||
check_schedule(b, 2)
|
||||
|
||||
def test_multireduce_midreduce_nochase(self):
|
||||
Tensor.manual_seed(0)
|
||||
@@ -1427,7 +1422,6 @@ class TestSchedule(unittest.TestCase):
|
||||
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
|
||||
# schedule = check_schedule(b, 2)
|
||||
schedule = check_schedule(b, 4)
|
||||
self.assertIs(store_val(schedule[0]).op, Ops.REDUCE_AXIS)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -1902,8 +1896,6 @@ class TestIndexing(unittest.TestCase):
|
||||
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).contiguous().to("CPU")
|
||||
sched = self.check_schedule(a, 2) # NOTE: there is a contiguous between REDUCE_AXIS and COPY
|
||||
self.assertIs(sched[2].ast.op, Ops.COPY)
|
||||
self.assertIs(store_val(sched[1]).op, Ops.LOAD)
|
||||
self.assertIs(store_val(sched[0]).op, Ops.ADD)
|
||||
np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@@ -1987,24 +1979,6 @@ class TestIndexing(unittest.TestCase):
|
||||
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
|
||||
self.assertEqual(swizzle_cnt(new_uop), 0)
|
||||
|
||||
def test_no_rewrite_elementwise(self):
|
||||
a = Tensor.empty(32, 32)
|
||||
b = Tensor.empty(32, 32)
|
||||
sink = (a+b).schedule()[0].ast
|
||||
self.assertEqual(swizzle_cnt(sink), 0)
|
||||
|
||||
def test_simple_store_reshape(self):
|
||||
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
|
||||
ast = a.schedule()[0].ast
|
||||
self.assertEqual(ast.shape, (32, 1))
|
||||
self.assertEqual(a.uop.shape, (1, 32))
|
||||
|
||||
def test_no_reshape_reduceop(self):
|
||||
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
|
||||
ast = a.schedule()[0].ast
|
||||
self.assertEqual(ast.shape, (32, 1))
|
||||
self.assertEqual(a.uop.shape, (32,))
|
||||
|
||||
def swizzle_cnt(u:UOp) -> int:
|
||||
return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}])
|
||||
|
||||
@@ -2108,7 +2082,6 @@ class TestSwizzle(unittest.TestCase):
|
||||
run_schedule(check_schedule(t, 3))
|
||||
np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]])
|
||||
|
||||
def store_val(si:ScheduleItem): return si.ast.src[0].src[1]
|
||||
zero_pm = UPat(Ops.CONST, arg=0)
|
||||
class TestView(unittest.TestCase):
|
||||
def test_all_masked_out(self):
|
||||
@@ -2117,7 +2090,6 @@ class TestView(unittest.TestCase):
|
||||
# all masked out, degrades to const 0
|
||||
b = a.pad(((0, 10), None))[10:]
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
assert zero_pm.match(store_val(sched[-1]), {})
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(b.numpy(), 0)
|
||||
|
||||
@@ -2128,7 +2100,6 @@ class TestView(unittest.TestCase):
|
||||
assert b.shape == (10, 10)
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
|
||||
assert zero_pm.match(store_val(sched[-1]), {})
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(b.numpy(), 0)
|
||||
|
||||
@@ -2143,8 +2114,6 @@ class TestView(unittest.TestCase):
|
||||
b = a.pad(((0, 5), None))[5:]
|
||||
assert b.shape == (10, 10)
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
self.assertEqual(store_val(sched[-1]).op, Ops.LOAD)
|
||||
self.assertEqual(store_val(sched[-1]).st_arg, b.uop.st)
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
|
||||
|
||||
@@ -2260,24 +2229,6 @@ class TestConst(unittest.TestCase):
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
|
||||
def test_const_ast(self):
|
||||
a = Tensor.ones((4,)).pad((1, 1)).contiguous()
|
||||
sched = a.schedule()
|
||||
print(sched[0].ast)
|
||||
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat.where(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0))),))
|
||||
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
|
||||
run_schedule(sched)
|
||||
self.assertListEqual(a.tolist(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
def test_unmasked_const_ast(self):
|
||||
a = Tensor.ones((4,)).contiguous()
|
||||
sched = a.schedule()
|
||||
print(sched[0].ast)
|
||||
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(Ops.CONST)),))
|
||||
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
|
||||
run_schedule(sched)
|
||||
self.assertListEqual(a.tolist(), [1, 1, 1, 1])
|
||||
|
||||
# ** part 2: scheduler behavior when const folding happens later
|
||||
|
||||
def test_const_folding_no_realize(self):
|
||||
|
||||
@@ -508,19 +508,6 @@ class TestShapeSpec(unittest.TestCase):
|
||||
a = Tensor.ones((4, 4)).uop
|
||||
self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))
|
||||
|
||||
def test_padded_const(self):
|
||||
a = Tensor.ones((1, 1)).pad(((1, 1), (1, 1)))
|
||||
ast = a.contiguous().schedule()[0].ast
|
||||
valid_pattern = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar(), UPat.cvar()))
|
||||
valid_ternary = [x for x in ast.toposort() if valid_pattern.match(x, {})][0]
|
||||
# the WHERE outputs a contiguous (3, 3)
|
||||
self.assertEqual(valid_ternary.st, ShapeTracker.from_shape((3, 3)))
|
||||
valid, x, y = valid_ternary.src
|
||||
# very notably, only the first source is padded
|
||||
self.assertIsNotNone(valid.st.views[-1].mask)
|
||||
assert x.st.views[-1].mask is y.st.views[-1].mask is None
|
||||
assert all(s.shape == (3, 3) for s in valid_ternary.src)
|
||||
|
||||
# NOTE: CONST ShapeTracker comes from its source
|
||||
def test_scalar_const(self):
|
||||
a = Tensor(0).uop
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.renderer import Estimates
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.opt.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
from tinygrad.opt.kernel import Opt, OptOps, KernelOptError
|
||||
from tinygrad.device import Device
|
||||
|
||||
def flops_mem(uops, ignore_indexing=False):
|
||||
@@ -178,10 +178,10 @@ class TestStatsOptimized(unittest.TestCase):
|
||||
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
|
||||
|
||||
def test_gemm_tc_unroll(self):
|
||||
k = Kernel(self.ast_gemm)
|
||||
if not k.apply_tensor_cores(): self.skipTest("no tensor cores")
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 2))
|
||||
p = get_program(k.get_optimized_ast(), k.opts)
|
||||
try:
|
||||
p = get_program(self.ast_gemm, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
|
||||
except KernelOptError:
|
||||
raise unittest.SkipTest("no tensor cores")
|
||||
print(p.src)
|
||||
self.check_gemm(p)
|
||||
|
||||
@@ -217,19 +217,16 @@ class TestStatsOptimized(unittest.TestCase):
|
||||
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
|
||||
|
||||
def test_reduce(self):
|
||||
k = Kernel(self.ast_reduce)
|
||||
p = get_program(k.get_optimized_ast(), k.opts)
|
||||
p = get_program(self.ast_reduce, opts=[])
|
||||
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
|
||||
self.assertEqual(p.estimates.ops, N*N)
|
||||
self.assertEqual(p.estimates.mem, N*N*4 + 4)
|
||||
|
||||
def test_reduce_group(self):
|
||||
k = Kernel(self.ast_reduce)
|
||||
try:
|
||||
k.apply_opt(Opt(OptOps.GROUP, 0, 50))
|
||||
p = get_program(self.ast_reduce, opts=[Opt(OptOps.GROUP, 0, 50)])
|
||||
except KernelOptError:
|
||||
raise unittest.SkipTest("no locals")
|
||||
p = get_program(k.get_optimized_ast(), k.opts)
|
||||
# NOTE: these are wrong, they don't respect the if statement
|
||||
print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes, Context, nn
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
|
||||
from tinygrad.opt.kernel import Kernel
|
||||
from tinygrad.opt.heuristic import hand_coded_optimizations
|
||||
from tinygrad.helpers import CI, Profiling, WINO, getenv
|
||||
|
||||
class TestWinogradClose(unittest.TestCase):
|
||||
def test_close(self):
|
||||
@@ -28,30 +25,6 @@ class TestWinograd(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
WINO.value = self.old
|
||||
|
||||
def test_speed(self):
|
||||
x = Tensor.empty(1,4,9,9)
|
||||
w = Tensor.empty(4,4,3,3)
|
||||
|
||||
with Timing("running conv: "):
|
||||
out = Tensor.conv2d(x, w)
|
||||
|
||||
with Timing("scheduling: "):
|
||||
sched = out.schedule()
|
||||
|
||||
for i,s in enumerate(sched):
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
ops = s.ast.toposort()
|
||||
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
|
||||
l = Kernel(s.ast)
|
||||
l.apply_opts(hand_coded_optimizations(l))
|
||||
assert len(l.sts) <= 256 # just the current value to prevent regression
|
||||
if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
|
||||
for st in l.sts:
|
||||
assert len(st.views) <= 2, "too many views in winograd"
|
||||
if DEBUG >= 3:
|
||||
print(f"{len(st.views):3d} views")
|
||||
for v in st.views: print(v)
|
||||
|
||||
def test_profile(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
with Profiling(enabled=not CI, sort='time'):
|
||||
|
||||
Reference in New Issue
Block a user