Files
tinygrad/test/amd/test_sqtt_profiler.py
nimlgen 4164666c72 programinfo (#15942)
* programinfo

* fix

* m

* x

* x

* changes

* x

* fix

* rm
2026-04-27 23:12:03 +03:00

100 lines
3.7 KiB
Python

import unittest, contextlib
from tinygrad import Device, Tensor, Context, TinyJit
from tinygrad.device import Compiled, ProfileProgramEvent, ProfileDeviceEvent
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
from tinygrad.viz.serve import load_amd_counters, VizData
@contextlib.contextmanager
def save_sqtt():
  """Context manager that yields the list that will hold the SQTT entries of the wrapped code.

  The yielded list is the `ctxs` list of a fresh `VizData`. It is only loaded and filtered once
  the `with` body has finished, so callers should inspect it after leaving the block.
  """
  viz = VizData()
  yield viz.ctxs
  # everything below runs after the `with` body has completed
  dev = Device[Device.DEFAULT]
  dev.synchronize()
  dev._at_profile_finalize()
  load_amd_counters(viz, Compiled.profile_events)
  # filter in place: the caller holds a reference to this exact list object
  sqtt_only = [ctx for ctx in viz.ctxs if ctx["name"].startswith("SQTT")]
  viz.ctxs[:] = sqtt_only
@unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD")
class TestSQTTProfiler(unittest.TestCase):
  """Checks how many SQTT entries `save_sqtt` collects and how those entries are named."""

  # TODO: can we enable SQTT profiling in context?
  @classmethod
  def setUpClass(cls):
    """Skip every test in the class unless the default device reports `sqtt_enabled`."""
    if not Device[Device.DEFAULT].sqtt_enabled: raise unittest.SkipTest("device must be in SQTT profiling mode")

  def setUp(self):
    """Synchronize the device and drop all profile events except program/device events."""
    Device[Device.DEFAULT].synchronize()
    Compiled.profile_events[:] = [e for e in Compiled.profile_events if isinstance(e, (ProfileProgramEvent, ProfileDeviceEvent))]

  def _sqtt_name(self, call):
    """Return the entry name expected for the first run of `call`, one element of `linear.src`."""
    fn_name = to_program(call.src[0], renderer=Device[Device.DEFAULT].renderer).arg.function_name
    return f"SQTT {fn_name}"

  def _check_multiple_kernels(self):
    """Run a two-step computation once and expect one entry per call, in call order."""
    t = ((Tensor.empty(1) + 1).contiguous() + 2)
    linear = t.schedule_linear()
    with save_sqtt() as sqtt:
      run_linear(linear)
    self.assertEqual(len(sqtt), len(linear.src))
    for i,call in enumerate(linear.src):
      self.assertEqual(sqtt[i]["name"], self._sqtt_name(call))

  def test_simple(self):
    """A single run yields exactly one entry, named after the program's function name."""
    t = Tensor.empty(1) + 1
    with save_sqtt() as sqtt:
      linear = t.schedule_linear()
      run_linear(linear)
    expected = self._sqtt_name(linear.src[0])
    self.assertEqual(len(sqtt), 1)
    self.assertEqual(sqtt[0]["name"], expected)

  def test_multiple_runs(self):
    """Re-running the same program yields one entry per run; later runs get an ` n<k>` suffix."""
    t = Tensor.empty(1) + 1
    with save_sqtt() as sqtt:
      linear = t.schedule_linear()
      for _ in range(N:=3): run_linear(linear)
    base = self._sqtt_name(linear.src[0])
    self.assertEqual(len(sqtt), N)
    for i in range(1, N):
      self.assertEqual(sqtt[i]["name"], f"{base} n{i+1}")

  def test_multiple_kernels(self):
    """Several calls in one schedule each get their own entry."""
    self._check_multiple_kernels()

  def test_multiple_kernels_lower(self):
    """Same checks as `test_multiple_kernels`; kept as a separate test name."""
    self._check_multiple_kernels()

  def test_jit(self):
    """Each call of a jitted single-op function yields one entry; repeats carry ` n<k>`."""
    @TinyJit
    def f(a): return a + 1
    t = Tensor.empty(1)
    with save_sqtt() as sqtt:
      for _ in range(N:=5):
        f(t).realize()
    self.assertEqual(len(sqtt), N)
    kernel_name = sqtt[0]["name"]
    for i,s in enumerate(sqtt[1:], start=1): self.assertEqual(s["name"], f"{kernel_name} n{i+1}")

  # TODO: can we trace SQTT for graphed kernels?
  def test_jit_graph(self, kernel_count=3*1):
    """Call a jitted function five times and expect `kernel_count` entries, grouped in threes.

    `kernel_count` defaults to one group of three; `test_jit_multiple_kernels` re-runs this
    body under `Context(JIT=2)` expecting five groups.
    """
    @TinyJit
    def f(a): return ((a + 1).contiguous() + 2).contiguous().sum()
    t = Tensor.empty(32)
    with save_sqtt() as sqtt:
      for _ in range(5):
        f(t).realize()
    # check the count first: a mismatch should fail as an assertion, not as a
    # ValueError/IndexError from the unpacking and indexing below
    self.assertEqual(len(sqtt), kernel_count)
    names = [s["name"] for s in sqtt]
    k0, k1, k2 = names[:3]
    for i in range(3, len(sqtt), 3):
      n = (i // 3)+1
      self.assertEqual(names[i], f"{k0} n{n}")
      self.assertEqual(names[i+1], f"{k1} n{n}")
      self.assertEqual(names[i+2], f"{k2} n{n}")

  @Context(JIT=2)
  def test_jit_multiple_kernels(self): self.test_jit_graph(kernel_count=3*5)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__": unittest.main()