mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
100 lines
3.7 KiB
Python
100 lines
3.7 KiB
Python
import unittest, contextlib
|
|
from tinygrad import Device, Tensor, Context, TinyJit
|
|
from tinygrad.device import Compiled, ProfileProgramEvent, ProfileDeviceEvent
|
|
from tinygrad.engine.realize import run_linear
|
|
from tinygrad.codegen import to_program
|
|
from tinygrad.viz.serve import load_amd_counters, VizData
|
|
|
|
@contextlib.contextmanager
def save_sqtt():
  """Yield the ctxs list of a fresh VizData; once the with-body completes, leave only SQTT entries in it.

  The list is filled and filtered in place after the body finishes, so callers should read it
  only after leaving the ``with`` block. Nothing after the yield runs if the body raises.
  """
  viz = VizData()
  yield viz.ctxs
  # flush outstanding device work and finalize profiling before reading the collected events
  Device[Device.DEFAULT].synchronize()
  Device[Device.DEFAULT]._at_profile_finalize()
  load_amd_counters(viz, Compiled.profile_events)
  # slice-assign so the very list object handed to the caller is the one that gets updated
  sqtt_only = [ctx for ctx in viz.ctxs if ctx["name"].startswith("SQTT")]
  viz.ctxs[:] = sqtt_only
|
|
|
|
@unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD")
class TestSQTTProfiler(unittest.TestCase):
  """Checks the count and names of the SQTT contexts collected by save_sqtt for several ways of running kernels."""

  # TODO: can we enable SQTT profiling in context?
  @classmethod
  def setUpClass(cls):
    """Skip every test in the class unless the default device reports sqtt_enabled."""
    if not Device[Device.DEFAULT].sqtt_enabled: raise unittest.SkipTest("device must be in SQTT profiling mode")

  def setUp(self):
    """Wait for the device, then keep only program/device events in the shared profile event list."""
    Device[Device.DEFAULT].synchronize()
    # in-place filter; presumably this drops trace data left behind by earlier tests -- confirm
    Compiled.profile_events[:] = [e for e in Compiled.profile_events if isinstance(e, (ProfileProgramEvent, ProfileDeviceEvent))]

  @staticmethod
  def _function_name(call):
    """Return the function name of the program built from ``call.src[0]`` with the default device's renderer."""
    return to_program(call.src[0], renderer=Device[Device.DEFAULT].renderer).arg.function_name

  def _check_multiple_kernels(self):
    """Run a two-stage schedule once and expect one SQTT context per call, named after its function."""
    t = ((Tensor.empty(1) + 1).contiguous() + 2)
    linear = t.schedule_linear()
    with save_sqtt() as sqtt:
      run_linear(linear)
    self.assertEqual(len(sqtt), len(linear.src))
    for ctx,call in zip(sqtt, linear.src):
      self.assertEqual(ctx["name"], f"SQTT {self._function_name(call)}")

  def test_simple(self):
    """A single run of a single kernel gives exactly one SQTT context named after that kernel."""
    t = Tensor.empty(1) + 1
    with save_sqtt() as sqtt:
      linear = t.schedule_linear()
      run_linear(linear)
    fn_name = self._function_name(linear.src[0])
    self.assertEqual(len(sqtt), 1)
    self.assertEqual(sqtt[0]["name"], f"SQTT {fn_name}")

  def test_multiple_runs(self):
    """Running the same kernel N times gives N contexts; the 2nd and later ones carry an ' n<run>' suffix."""
    t = Tensor.empty(1) + 1
    with save_sqtt() as sqtt:
      linear = t.schedule_linear()
      for _ in range(N:=3): run_linear(linear)
    fn_name = self._function_name(linear.src[0])
    self.assertEqual(len(sqtt), N)
    for run in range(2, N+1):
      self.assertEqual(sqtt[run-1]["name"], f"SQTT {fn_name} n{run}")

  def test_multiple_kernels(self): self._check_multiple_kernels()

  # the body of this test was identical to test_multiple_kernels; both names are kept for callers/CI filters
  def test_multiple_kernels_lower(self): self._check_multiple_kernels()

  def test_jit(self):
    """Every call of a jitted single-kernel function is traced; repeats are named '<first name> n<run>'."""
    @TinyJit
    def f(a): return a + 1
    t = Tensor.empty(1)
    with save_sqtt() as sqtt:
      for _ in range(N:=5):
        f(t).realize()
    self.assertEqual(len(sqtt), N)
    kernel_name = sqtt[0]["name"]
    for run,s in enumerate(sqtt[1:], start=2): self.assertEqual(s["name"], f"{kernel_name} n{run}")

  # TODO: can we trace SQTT for graphed kernels?
  def test_jit_graph(self, kernel_count=3*1):
    """Call a jitted function five times and expect ``kernel_count`` SQTT contexts in total.

    The test treats each call of ``f`` as three kernels: the first three contexts supply the base
    names, and every later group of three must repeat them with an ' n<run>' suffix.
    """
    @TinyJit
    def f(a): return ((a + 1).contiguous() + 2).contiguous().sum()
    t = Tensor.empty(32)
    with save_sqtt() as sqtt:
      for _ in range(5):
        f(t).realize()
    # check the count first so a wrong number of traces fails here with a clear message,
    # not as an unpacking/indexing error in the name checks below
    self.assertEqual(len(sqtt), kernel_count)
    names = [s["name"] for s in sqtt]
    per_call = 3
    base_names = names[:per_call]
    for i,name in enumerate(names[per_call:], start=per_call):
      self.assertEqual(name, f"{base_names[i % per_call]} n{i // per_call + 1}")

  # presumably JIT=2 keeps the kernels out of a graph so all five calls get traced -- confirm
  @Context(JIT=2)
  def test_jit_multiple_kernels(self): self.test_jit_graph(kernel_count=3*5)
|
|
|
|
# run the tests in this module when the file is executed directly
if __name__ == "__main__":
  unittest.main()
|