# Files
# tinygrad/test/null/test_viz.py
# qazal 598cc13ad2 more readable null graph profile in VIZ (#16548)
# * more readable null graph profile in VIZ
#
# * change
#
# * fix flaky test
# 2026-06-09 18:35:05 +09:00
#
# 1060 lines
# 44 KiB
# Python

import unittest, decimal, sys, json, contextlib, tempfile, pickle, io, math
from pathlib import Path
from dataclasses import dataclass
from typing import Generator
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher, graph_rewrite, track_rewrites, profile_matches
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.helpers import colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap
from tinygrad.device import Buffer
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData, get_render, addrspace_colors
from tinygrad.codegen import do_to_program
@track_rewrites(name=True)
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
  """Run `sink` through every PatternMatcher in `pm_lst`, in order, as one tracked rewrite group.

  Each matcher is re-wrapped in a TrackedPatternMatcher so its matches are recorded. When `names` is
  truthy, names[i] becomes the step name of the i-th matcher; otherwise steps are unnamed.
  Returns the UOp produced by the last rewrite.
  """
  rewritten = sink
  for idx, matcher in enumerate(pm_lst):
    step_name = names[idx] if names else None
    rewritten = graph_rewrite(rewritten, TrackedPatternMatcher(matcher.patterns), name=step_name)
  return rewritten
# small container class for the viz server module
class VizTrace:
  """Test-side loader for VIZ data.

  It snapshots the global rewrite-tracking state into a VizData and exposes the list/detail queries
  the tests use.
  """
  def __init__(self): self._data:VizData|None = None
  @property
  def data(self) -> VizData:
    """Return the loaded VizData.

    Goes through `unwrap`, which is expected to fail if `set_data` has not run yet.
    """
    return unwrap(self._data)
  def set_data(self) -> None:
    """Copy tracked_keys/tracked_ctxs/uop_fields into a RewriteTrace, run load_rewrites on it and store the result."""
    data = VizData(RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy()))
    load_rewrites(data)
    self._data = data
  # the API
  def list_items(self) -> list[dict]:
    """Return `data.ctxs`, one dict per tracked rewrite group (each carries "name" and "steps")."""
    return self.data.ctxs
  def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]:
    """Return the get_full_rewrite generator for step `step` of rewrite group `rewrite_idx`."""
    n_loaded = len(self.data.trace.rewrites)
    # index rewrite_idx is only valid when at least rewrite_idx+1 traces were loaded
    assert n_loaded > rewrite_idx, f"only loaded {n_loaded} traces, expecting at least {rewrite_idx+1}"
    return get_full_rewrite(self.data, self.data.trace.rewrites[rewrite_idx][step])
@contextlib.contextmanager
def save_viz():
  """Reset the global VIZ/profiling state, run the body with tracing enabled, then load the trace.

  Yields a VizTrace. Its data is only populated after the `with` body exits without raising, so
  queries such as `list_items()` must be made after the block.
  """
  for tracked in (tracked_keys, tracked_ctxs, active_rewrites, active_group, _name_cnt): tracked.clear()
  Buffer.profile_events.clear()
  cpu_events.clear()
  trace = VizTrace()
  with Context(VIZ=-1, TRACK_MATCH_STATS=2, PROFILE=1): yield trace
  # snapshot once the tracing Context has been restored
  trace.set_data()
class TestViz(unittest.TestCase):
  """Core VIZ rewrite tracking: grouping, naming/dedup, steps, source locations, errors and graph JSON.

  save_viz loads the trace when it exits, so every query on `viz` happens after the `with` block.
  """
  def test_simple(self):
    with save_viz() as viz:
      a = UOp.variable("a", 0, 10)
      exec_rewrite((a+0)*1, [sym])
    lst = viz.list_items()
    # VIZ displays rewrites in groups of tracked functions
    self.assertEqual(len(lst), 1)
    # each group has a list of steps
    self.assertEqual(len(lst[0]["steps"]), 1)
    # each step counts its pattern matches (presumably the a+0 and the *1 folds here)
    self.assertEqual(lst[0]["steps"][0]["match_count"], 2)
  def test_rewrites(self):
    with save_viz() as viz:
      a = UOp.variable("a", 0, 10)
      exec_rewrite(a*1, [sym])
      exec_rewrite(a*2, [sym])
    lst = viz.list_items()
    # one group per call of the tracked function
    self.assertEqual(len(lst), 2)
    # names dedup using a counter
    self.assertEqual(lst[0]["name"], "exec_rewrite n1")
    self.assertEqual(lst[1]["name"], "exec_rewrite n2")
  def test_steps(self):
    with save_viz() as viz:
      a = UOp.variable("a", 0, 10)
      # one step per PatternMatcher, named from the `names` list
      exec_rewrite(a+1, [PatternMatcher([]), PatternMatcher([])], ["x", "y"])
    steps = viz.list_items()[0]["steps"]
    # steps can optionally have a name
    self.assertEqual(steps[0]["name"], "x")
    self.assertEqual(steps[1]["name"], "y")
  def test_rewrite_location(self):
    # NOTE: inner must stay a one-liner; the expected lineno is inner's first line, where graph_rewrite is called
    def inner(sink): return graph_rewrite(sink, PatternMatcher([]))
    def outer(sink): return inner(sink)
    with save_viz() as viz:
      outer(UOp.variable("a", 1, 10))
    lst = viz.list_items()
    # step location comes from inner rewrite
    fp, lineno = lst[0]["steps"][0]["loc"]
    self.assertEqual(fp, inner.__code__.co_filename)
    self.assertEqual(lineno, inner.__code__.co_firstlineno)
  def test_exceptions(self):
    # VIZ tracks rewrites up to and including the error
    def count_3(x:UOp):
      # increments the const until the assert fails on arg 4
      assert x.arg <= 3
      return x.replace(arg=x.arg+1)
    err_pm = PatternMatcher([(UPat.cvar("x"), count_3),])
    a = UOp.const(dtypes.int, 1)
    with save_viz() as viz:
      # the error is caught inside save_viz, so the trace is still loaded on exit
      with self.assertRaises(AssertionError): exec_rewrite(a, [err_pm])
    lst = viz.list_items()
    err_step = lst[0]["steps"][0]
    self.assertEqual(err_step["match_count"], 4) # 3 successful rewrites + 1 err
  def test_default_name(self):
    with save_viz() as viz:
      a = UOp.variable("a", 1, 10)
      # without a name argument the group is named after the function
      @track_rewrites()
      def name_default(): return graph_rewrite(a, PatternMatcher([]))
      name_default()
    lst = viz.list_items()
    self.assertEqual(lst[0]["name"], "name_default n1")
  # name can also come from a function that returns a string
  def test_dyn_name_fxn(self):
    with save_viz() as viz:
      # the name callback receives the call's args/kwargs plus the return value as `ret`
      @track_rewrites(name=lambda *args,ret,**kwargs: ret.render())
      def name_from_fxn(s:UOp, arg:list|None=None): return graph_rewrite(s, PatternMatcher([]))
      name_from_fxn(UOp.variable("a", 1, 10)+1, arg=["test"])
    lst = viz.list_items()
    # name gets deduped by the function call counter
    self.assertEqual(lst[0]["name"], "(a+1) n1")
  # name can also come from a function that returns a TracingKey
  def test_tracing_key(self):
    with save_viz() as viz:
      @track_rewrites(name=lambda inp,ret: TracingKey("custom_name", (inp,)))
      def test(s:UOp): return graph_rewrite(s, PatternMatcher([]))
      test(UOp.variable("a", 1, 10)+1)
    lst = viz.list_items()
    # NOTE: names from TracingKey do not get deduped
    self.assertEqual(lst[0]["name"], "custom_name")
  def test_nested_track_rewrites(self):
    with save_viz() as viz:
      @track_rewrites(name=lambda x,ret: TracingKey(f"inner fxn for {x.render()}", (ret,)))
      def inner(x:UOp): return graph_rewrite(x, PatternMatcher([]), name="each")
      @track_rewrites(name=lambda *args,ret: f"outer rewrite of {len(args)} inputs")
      def outer(*xs:UOp): return graph_rewrite(UOp.sink(*[inner(x) for x in xs]), PatternMatcher([]), name="all")
      items = ["a", "b", "c"]
      outer(*[UOp.variable(x, 1, 10) for x in items])
    lst = viz.list_items()
    # inner calls fall outside the outer call: each gets its own group, listed after outer's
    self.assertEqual(len(lst), len(items)+1)
    self.assertEqual(lst[0]["name"], f"outer rewrite of {len(items)} inputs n1")
    steps = lst[0]["steps"]
    self.assertEqual(len(steps), 1)
    self.assertEqual(steps[0]["name"], "all")
    for i in range(len(items)):
      self.assertEqual(lst[i+1]["name"], f"inner fxn for {items[i]}")
      steps = lst[i+1]["steps"]
      self.assertEqual(len(steps), 1)
      self.assertEqual(steps[0]["name"], "each")
  def test_profile_matches(self):
    with save_viz() as viz:
      @profile_matches
      def nested_function(u:UOp):
        for i in range(2): graph_rewrite(u, PatternMatcher([]), name=f"step {i+1}")
      @track_rewrites()
      def main_rewrite(u:UOp):
        graph_rewrite(u, PatternMatcher([]), name="init")
        nested_function(u)
      main_rewrite(UOp.variable("a", 1, 10)+UOp.variable("b", 1, 10))
    steps = viz.list_items()[0]["steps"]
    # the @profile_matches function shows up as its own step, followed by its 2 inner rewrites
    self.assertEqual(steps[0]["name"], "init")
    self.assertEqual(steps[1]["name"], "nested_function")
    self.assertEqual(len(steps), 4)
  def test_profile_matches_invalid_arg(self):
    with save_viz():
      # @profile_matches rejects a non-UOp input
      @profile_matches
      def invalid_fxn(arg:str): return graph_rewrite(UOp(Ops.SINK), PatternMatcher([]))
      with self.assertRaisesRegex(AssertionError, "invalid match tracing input"):
        invalid_fxn("test")
  def test_colored_label(self):
    # NOTE: dataclass repr prints literal escape codes instead of unicode chars
    @dataclass(frozen=True)
    class TestStruct:
      colored_field: str
    a = UOp(Ops.CUSTOM, arg=TestStruct(colored("xyz", "magenta")+colored("12345", "blue")))
    a2 = uop_to_json(VizData(), a)[id(a)]
    # after stripping ANSI codes the label shows the plain field value
    self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
  def test_colored_label_multiline(self):
    with save_viz() as viz:
      # colored segments, including one with an embedded newline
      arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta")
      src = [Tensor.empty(1).uop for _ in range(10)]
      a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg)
      exec_rewrite(a, [PatternMatcher([])])
    a2 = next(viz.get_details(0, 0))["graph"][id(a)]
    self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
  def test_inf_loop(self):
    a = UOp.const(dtypes.int, 3)
    b = UOp.const(dtypes.int, 4)
    # 3 -> 4 -> 3 -> ... never reaches a fixed point
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    with save_viz() as viz:
      # use smaller stack limit for faster test (default is 250000)
      with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm])
    graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0))
    self.assertEqual(graphs[0], uop_to_json(VizData(), a)[id(a)])
    self.assertEqual(graphs[1], uop_to_json(VizData(), b)[id(b)])
    # fallback to NOOP with the error message
    nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite")
    self.assertEqual(graphs[2], uop_to_json(VizData(), nop)[id(nop)])
  def test_const_node_visibility(self):
    with save_viz() as viz:
      a = UOp.variable("a", 0, 10, dtype=dtypes.int)
      z = UOp.const(a.dtype, 0)
      y = UOp.const(dtypes.float, math.pi)
      alu = a*z
      ret = exec_rewrite(sink:=UOp.sink(alu, y), [sym])
    lst = viz.list_items()
    self.assertEqual(len(lst), 1)
    graphs = [x["graph"] for x in viz.get_details(0, 0)]
    # const is always in the graph, client side hides exclude=True nodes by default
    self.assertEqual(list(graphs[0]), [id(a), id(z), id(alu), id(y), id(sink)])
    self.assertTrue(graphs[0][id(z)]["exclude"])
    self.assertTrue(graphs[0][id(y)]["exclude"])
    self.assertFalse(graphs[0][id(alu)]["exclude"])
    self.assertEqual(graphs[0][id(y)]["label"].split("\n")[:2], ["CONST", "3.14159"])
    # the graph after the rewrite holds only the remaining nodes and the new sink
    self.assertEqual(list(graphs[1]), [id(z), id(y), id(ret)])
  def test_const_reshape_expand_folded(self):
    # CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
    c = UOp.const(dtypes.float, 1.0, shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
    a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
    alu = a + c
    with save_viz() as viz:
      graph_rewrite(alu, PatternMatcher([]))
    graph = [x["graph"] for x in viz.get_details(0, 0)][0]
    excluded_nodes = {v["label"].split("\n")[0] for v in graph.values() if v["exclude"]}
    self.assertIn("CONST", excluded_nodes)
    self.assertIn("STACK", excluded_nodes)
    self.assertIn("RESHAPE", excluded_nodes)
    self.assertIn("EXPAND", excluded_nodes)
    # the folded const value is shown in the ALU node's label instead
    self.assertIn("CONST1 1", graph[id(alu)]["label"])
  def test_stack_movement_not_folded_unless_all_const(self):
    # vectorize of a variable and a const: stays visible under a reshape
    a = UOp.variable("a", 0, 10, dtype=dtypes.int)
    c = UOp.const(dtypes.int, 1)
    stack = a.vectorize(c)
    reshaped = stack.reshape((1, 2))
    graph = uop_to_json(VizData(), reshaped)
    self.assertFalse(graph[id(stack)]["exclude"])
    # vectorize of only consts: gets excluded (folded)
    const_stack = c.vectorize(UOp.const(dtypes.int, 2))
    const_reshaped = const_stack.reshape((1, 2))
    const_graph = uop_to_json(VizData(), const_reshaped)
    self.assertTrue(const_graph[id(const_stack)]["exclude"])
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp):
  """Tag an untagged node with 1.

  Returns None for an already-tagged node, so the rewrite reaches a fixed point.
  """
  if x.tag is not None: return None
  return x.rtag(1)
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)])
def branch_rewrite(x:UOp, y:UOp):
  """Rewrite x+y into x'*y', where x' and y' come from nested `leaf` graph_rewrites.

  The nested rewrites are named "leaf_left" and "leaf_right". Returns None when x is already tagged.
  """
  if x.tag is None:
    # left before right: nested steps are recorded in call order
    left = graph_rewrite(x, leaf, name="leaf_left")
    right = graph_rewrite(y, leaf, name="leaf_right")
    return left * right
  return None
branch = TrackedPatternMatcher([(UPat.var("x")+UPat.var("y"), branch_rewrite)])
def root_rewrite(root:UOp):
  """Rewrite each source of `root` with `branch` and rebuild `root` from the results.

  Each source gets its own nested graph_rewrite, named "branch_<i>".
  """
  rewritten = []
  for i, child in enumerate(root.src):
    rewritten.append(graph_rewrite(child, branch, name=f"branch_{i}"))
  return root.replace(src=tuple(rewritten))
root = TrackedPatternMatcher([(UPat(Ops.SINK, src=UPat(Ops.ADD), name="root"), root_rewrite),])
class TestVizTree(unittest.TestCase):
  """Nested graph_rewrites appear as a flat, depth-annotated list of steps in call order."""
  def assertStepEqual(self, step:dict, want:dict):
    """Assert that `step` has the same value as `want` for every key in `want`."""
    for k in want:
      v = want[k]
      self.assertEqual(step[k], v, f"failed at '{k}': {v} != {step[k]}\n{step=}")
  def test_tree_view(self):
    with save_viz() as viz:
      a, b, c, d = (UOp.variable(n, 0, 10) for n in ("a", "b", "c", "d"))
      sink = UOp.sink(a+b, c+d)
      def tree_rewrite(): return graph_rewrite(sink, root, name="root")
      tree_rewrite()
    steps = viz.list_items()[0]["steps"]
    # 1 root + 2 branches + 2 leaves per branch
    self.assertEqual(len(steps), 1+2+4)
    # (name, depth) of each step in order; every step makes exactly one match
    expected = [("root", 0),
                ("branch_0", 1), ("leaf_left", 2), ("leaf_right", 2),
                ("branch_1", 1), ("leaf_left", 2), ("leaf_right", 2)]
    for step, (name, depth) in zip(steps, expected):
      self.assertStepEqual(step, {"name":name, "depth":depth, "match_count":1})
import gc
def bufs_allocated() -> int:
  """Force a GC pass, then count GC-tracked objects whose class is `Buffer` from module `tinygrad.device`."""
  gc.collect()
  count = 0
  for obj in gc.get_objects():
    cls = type(obj)
    if cls.__name__ == "Buffer" and cls.__module__ == "tinygrad.device": count += 1
  return count
class TestVizGC(unittest.TestCase):
  """VIZ rewrite tracking must not keep device Buffers alive."""
  def test_gc(self):
    with save_viz() as viz:
      init = bufs_allocated()
      a = UOp.new_buffer("NULL", 10, dtypes.char)
      a.buffer.allocate()
      exec_rewrite(a, [PatternMatcher([])])
      del a
      # the traced buffer must be collectable once the test drops its reference
      self.assertEqual(bufs_allocated()-init, 0)
    lst = viz.list_items()
    self.assertEqual(len(lst), 1)
  @unittest.skip("it's not generic enough to handle arbitrary UOps in arg")
  def test_gc_uop_in_arg(self):
    # same as test_gc, but the buffer UOp is also stored in the CUSTOM node's arg
    with save_viz() as viz:
      init = bufs_allocated()
      a = UOp.new_buffer("NULL", 10, dtypes.char)
      a.buffer.allocate()
      exec_rewrite(UOp(Ops.CUSTOM, src=(a,), arg=a), [PatternMatcher([])])
      del a
      self.assertEqual(bufs_allocated()-init, 0)
    lst = viz.list_items()
    self.assertEqual(len(lst), 1)
# VIZ integrates with other parts of tinygrad
from tinygrad import Tensor, Device, TinyJit, Variable, function
class TestVizIntegration(unittest.TestCase):
  """VIZ tracing of the real scheduler/codegen/JIT pipeline."""
  # codegen supports rendering of code blocks
  def test_codegen_tracing(self):
    with save_viz() as viz:
      ast = (Tensor.empty(4)+Tensor.empty(4)).schedule_linear().src[0].src[0]
      prg = do_to_program(ast, Device[Device.DEFAULT].renderer)
    lst = viz.list_items()
    # scheduling produces two groups, then codegen adds one named after the program
    self.assertEqual(len(lst), 3)
    self.assertEqual(lst[0]["name"], "Callify 1 Buffer n1")
    self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
    self.assertEqual(lst[2]["name"], prg.arg.name)
    input_ast = next(viz.get_details(2, 0))["graph"].values()
    # PARAM nodes in the codegen input are colored with the GLOBAL address space color
    for u in input_ast:
      if u["label"].startswith("PARAM\n"): self.assertEqual(u["addrspace"], addrspace_colors[AddrSpace.GLOBAL])
  # schedule graph CALL nodes have a link to jump to codegen
  def test_link_sched_codegen(self):
    with save_viz() as viz:
      c1 = Tensor.empty(4, device="NULL").add(1)
      c2 = Tensor.empty(8, device="NULL").add(1)
      with Context(SCACHE=0):
        sched = c1.schedule_linear(c2)
      from tinygrad.engine.realize import compile_linear
      sched = compile_linear(sched)
      # kernel names are rendered with color here; they are compared via ansistrip below
      with Context(NO_COLOR=0):
        prgs = [do_to_program(si.src[0], Device[c1.device].renderer).arg.name for si in sched.src]
    lst = viz.list_items()
    sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
    viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
    with Context(NO_COLOR=1):
      graph = next(viz.get_details(sched_idx, viz_kernel))["graph"]
    call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
    # the i-th CALL node links ("ref") to the group index of the i-th compiled program
    for i,n in enumerate(call_nodes):
      assert n["ref"] is not None
      self.assertEqual(lst[n["ref"]]["name"], prgs[i])
      assert ansistrip(prgs[i]) in n["label"], f"CALL must contain kernel name, got {n['label']}"
  def test_link_sched_codegen_beam(self):
    # same links must hold when BEAM search is enabled
    with Context(BEAM=2):
      self.test_link_sched_codegen()
  @Context(TRACEMETA=2)
  def test_metadata_tracing(self):
    with save_viz() as viz:
      a = Tensor.empty(1)
      b = Tensor.empty(1)
      metadata = (alu:=a+b).uop.metadata
      alu.schedule_linear()
    graph = next(viz.get_details(0, 0))["graph"]
    # exactly one node's label carries the add's metadata
    self.assertEqual(len([n for n in graph.values() if repr(metadata) in n["label"]]), 1)
  # tracing also works without a track_rewrites context
  # all graph_rewrites get put into the default group
  def test_default_tracing(self):
    with save_viz() as viz:
      def test(root):
        return graph_rewrite(root, sym)
      test(c:=UOp.const(dtypes.int, 1))
      test(c+1)
    ls = viz.list_items()
    self.assertEqual(len(ls), 1)
    self.assertEqual(ls[0]["name"], "default graph_rewrite")
  # using @track_rewrites organizes function calls into groups
  # and nicely counts function calls.
  def test_group_traces(self):
    with save_viz() as viz:
      @track_rewrites()
      def test(root):
        return graph_rewrite(root, sym)
      test(c:=UOp.const(dtypes.int, 1))
      test(c+1)
    ls = viz.list_items()
    self.assertEqual(len(ls), 2)
    for i in range(2): self.assertEqual(ls[i]["name"], f"test n{i+1}")
  # @track_rewrites always starts a new group.
  def test_group_combined(self):
    with save_viz() as viz:
      def default_test(root): return graph_rewrite(root, sym)
      tracked_test = track_rewrites()(default_test)
      c = UOp.const(dtypes.int, 1)
      default_test(c+1) # goes to the default group
      tracked_test(c) # all rewrites after this go inside the second group.
      default_test(c+2)
    ls = viz.list_items()
    self.assertEqual(len(ls), 2)
    # default group: the c+1 rewrite
    graph = next(viz.get_details(0, 0))["graph"]
    self.assertEqual(list(graph), [id(c), id(c+1)])
    self.assertTrue(graph[id(c)]["exclude"])
    self.assertFalse(graph[id(c+1)]["exclude"])
    # tracked group: step 0 is tracked_test(c), step 1 is the later untracked c+2 rewrite
    self.assertEqual(list(next(viz.get_details(1, 0))["graph"]), [id(c)])
    graph = next(viz.get_details(1, 1))["graph"]
    self.assertEqual(list(graph), [id(c), id(c.const_like(2)), id(c+2)])
    self.assertTrue(graph[id(c)]["exclude"])
    self.assertTrue(graph[id(c.const_like(2))]["exclude"])
    self.assertFalse(graph[id(c+2)]["exclude"])
  def test_recurse(self):
    # a 10_000-deep graph must not break tracking
    with save_viz() as viz:
      a = Tensor.empty(10)
      for _ in range(10_000): a += a
      graph_rewrite(a.uop, PatternMatcher([]))
    lst = viz.list_items()
    assert len(lst) == 1
  def test_jit(self):
    with save_viz():
      @TinyJit
      def f(a, b, c): return (a+b).contiguous().mul(3), c.add(1).contiguous().assign(a.to(c.device)), b.assign(c.to(b.device))
      a, b, c = Tensor.empty(16, device="NULL"), Tensor.empty(16, device="NULL"), Tensor.empty(16, device="NULL:1")
      for _ in range(3): Tensor.realize(*f(a, b, c))
    out = load_profile(cpu_events)
    # the graph track sits right after its device's track
    self.assertEqual(["NULL", "NULL Graph", "NULL:SDMA:0", "NULL:1", "NULL:1:SDMA:0"], [k for k in out["layout"] if k.startswith("NULL")])
    self.assertEqual(len(out["layout"]["NULL"]["events"]), 2*3)
    self.assertEqual(len(out["layout"]["NULL:SDMA:0"]["events"]), 3)
    self.assertEqual(len(out["layout"]["NULL Graph"]["events"]), 2)
    # inside each graph event, every device track has events laid out back to back from the graph start
    for graph in out["layout"]["NULL Graph"]["events"]:
      graph_st, graph_et = graph["st"], graph["st"]+graph["dur"]
      for k in ["NULL", "NULL:1", "NULL:SDMA:0", "NULL:1:SDMA:0"]:
        events = [e for e in out["layout"][k]["events"] if graph_st <= e["st"] and e["st"]+e["dur"] <= graph_et]
        self.assertGreater(len(events), 0)
        self.assertEqual([e["st"] for e in events], [graph_st+i*events[0]["dur"] for i in range(len(events))])
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_profile
from tinygrad.viz.cli import decode_profile
def load_profile(lst:list[ProfileEvent]) -> dict:
  """Encode `lst` with the viz server's get_profile, then decode it back into a dict with the CLI decoder."""
  encoded = get_profile(VizData(), lst)
  return decode_profile(encoded)
class TestVizProfiler(unittest.TestCase):
  """Profile encoding/decoding: event timing, device tracks, graphs, markers and track ordering.

  Across these tests, event "st" equals the raw timestamp plus its device's tdiff, made relative to
  the earliest adjusted timestamp.
  """
  def test_transfer_uses_copy_device(self):
    with save_viz():
      a = Tensor.ones(1, device="NULL").contiguous().realize()
      a.to("NULL:1").realize()
    range_events = [e for e in cpu_events if isinstance(e, ProfileRangeEvent)]
    compute_events = [e for e in range_events if e.device == "NULL"]
    copy_events = [e for e in range_events if e.device.endswith(":SDMA:0")]
    self.assertGreater(len(compute_events), 0, "expected compute events on base device")
    self.assertGreater(len(copy_events), 0, "transfer must produce events with ':SDMA' device suffix")
  def test_node(self):
    prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010)),
            ProfileDeviceEvent(device='NV', tdiff=decimal.Decimal(-1000))]
    j = load_profile(prof)
    dev_events = j['layout']['NV']['events']
    self.assertEqual(len(dev_events), 1)
    event = dev_events[0]
    self.assertEqual(event['name'], 'E_2')
    self.assertEqual(event['st'], 0)
    self.assertEqual(event['dur'], 10)
    assert event['ref'] is None
  def test_copy_node(self):
    prof = [ProfileRangeEvent(device='NV:SDMA:0', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010)),
            ProfileRangeEvent(device='NV:2:SDMA:0', name='COPYxx', st=decimal.Decimal(1000), en=decimal.Decimal(1010)),
            ProfileDeviceEvent(device='NV:SDMA:0', tdiff=decimal.Decimal(-100)),
            ProfileDeviceEvent(device='NV:2:SDMA:0', tdiff=decimal.Decimal(-80))]
    j = load_profile(prof)
    event = j['layout']['NV:SDMA:0']['events'][0]
    self.assertEqual(event['name'], 'COPYxx')
    self.assertEqual(event['st'], 0) # first event
    self.assertEqual(event['dur'], 10)
    event2 = j['layout']['NV:2:SDMA:0']['events'][0]
    self.assertEqual(event2['st'], 20) # second event, diff clock
    # total duration spans from the first start to the last end
    self.assertEqual(j["dur"], (event2["st"]+event2["dur"])-event["st"])
  def test_copy_node_bandwidth(self):
    sz = 256*1024*1024
    dur = 10_000
    # a TracingKey ret carries the copy size in bytes; dur is in microseconds (see the 1e-6 below)
    prof = [ProfileRangeEvent(device='NV:SDMA:0', name=TracingKey("NV -> NV:1", ret=sz), st=decimal.Decimal(1000), en=decimal.Decimal(1000+dur)),
            ProfileDeviceEvent(device='NV:SDMA:0', tdiff=decimal.Decimal(-1000))]
    j = load_profile(prof)
    event = j['layout']['NV:SDMA:0']['events'][0]
    self.assertEqual(event['fmt'], {"B/s": sz/(dur*1e-6), "B": sz})
  def test_graph(self):
    # graph entries reference signal timestamps in `sigs` by index (st_id/en_id)
    prof = [ProfileDeviceEvent(device='NV', tdiff=decimal.Decimal(-1000)),
            ProfileDeviceEvent(device='NV:1:SDMA:0', tdiff=decimal.Decimal(-50)),
            ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1),
                                    ProfileGraphEntry(device='NV:1:SDMA:0', name='NV -> NV:1', st_id=2, en_id=3)],
                              deps=[[], [0]],
                              sigs=[decimal.Decimal(1000), decimal.Decimal(1002), decimal.Decimal(1004), decimal.Decimal(1008)])]
    j = load_profile(prof)
    tracks = list(j['layout'])
    self.assertEqual(tracks[0], 'NV')
    self.assertEqual(tracks[1], 'NV Graph')
    self.assertEqual(tracks[2], 'NV:1:SDMA:0')
    nv_events = j['layout']['NV']['events']
    self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
    self.assertEqual(nv_events[0]['st'], 0)
    self.assertEqual(nv_events[0]['dur'], 2)
    sdma_events = j['layout']['NV:1:SDMA:0']['events']
    self.assertEqual(sdma_events[0]['name'], 'NV -> NV:1')
    # 1004 + (-50) relative to 1000 + (-1000)
    self.assertEqual(sdma_events[0]['st'], 954)
    # the graph event spans from its first entry's start to its last entry's end
    graph_events = j['layout']['NV Graph']['events']
    self.assertEqual(graph_events[0]['st'], nv_events[0]['st'])
    self.assertEqual(graph_events[0]['st']+graph_events[0]['dur'], sdma_events[0]['st']+sdma_events[0]['dur'])
  def test_graph_copy_bandwidth(self):
    sz = 256*1024*1024
    dur = 10_000
    prof = [ProfileDeviceEvent(device='NV', tdiff=decimal.Decimal(-1000)),
            ProfileDeviceEvent(device='NV:1:SDMA:0', tdiff=decimal.Decimal(-50)),
            ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV:1:SDMA:0', name=TracingKey("NV -> NV:1", ret=sz), st_id=0, en_id=1)],
                              deps=[[]],
                              sigs=[decimal.Decimal(1004), decimal.Decimal(1004+dur)])]
    j = load_profile(prof)
    # bandwidth formatting also applies to copies inside a graph
    sdma_events = j['layout']['NV:1:SDMA:0']['events']
    self.assertEqual(sdma_events[0]["fmt"], {"B/s": sz/(dur*1e-6), "B": sz})
  def test_block_ordering(self):
    prof = [ProfileDeviceEvent(device='NV', tdiff=decimal.Decimal(-1000)),
            ProfileDeviceEvent(device='NV:1', tdiff=decimal.Decimal(-500)),
            ProfileDeviceEvent(device='NV:SDMA:0', tdiff=decimal.Decimal(-100)),
            ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010)),
            ProfileRangeEvent(device='NV:1', name='E_3', st=decimal.Decimal(1000), en=decimal.Decimal(1010)),
            ProfileRangeEvent(device='NV:SDMA:0', name='COPY', st=decimal.Decimal(1000), en=decimal.Decimal(1010)),
            ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_2', st_id=0, en_id=1)],
                              deps=[[]], sigs=[decimal.Decimal(1000), decimal.Decimal(1010)])]
    j = load_profile(prof)
    # graph grouped right after its device, then that device's SDMA engine, then the next device
    self.assertListEqual(list(j['layout']), ['NV', 'NV Graph', 'NV:SDMA:0', 'NV:1'])
  @unittest.skipIf(sys.platform == 'win32', "TODO: ops_amd import fails on windows")
  def test_multi_sdma_ordering(self):
    props = {"gfx_target_version": 0}
    D, St, En = decimal.Decimal, decimal.Decimal(1000), decimal.Decimal(1010)
    prof = [# 2 AMD GPUs, 2 SDMA engines each
            ProfileDeviceEvent(device='AMD', tdiff=D(-1000), props=props),
            ProfileDeviceEvent(device='AMD:1', tdiff=D(-900), props=props),
            ProfileDeviceEvent(device='AMD:SDMA:0', tdiff=D(-100), props=props),
            ProfileDeviceEvent(device='AMD:SDMA:1', tdiff=D(-80), props=props),
            ProfileDeviceEvent(device='AMD:1:SDMA:0', tdiff=D(-60), props=props),
            ProfileDeviceEvent(device='AMD:1:SDMA:1', tdiff=D(-40), props=props),
            # compute + copy events
            ProfileRangeEvent(device='AMD', name='E_1', st=St, en=En),
            ProfileRangeEvent(device='AMD:1', name='E_2', st=St, en=En),
            ProfileRangeEvent(device='AMD:SDMA:0', name='COPY0', st=St, en=En),
            ProfileRangeEvent(device='AMD:SDMA:1', name='COPY1', st=St, en=En),
            ProfileRangeEvent(device='AMD:1:SDMA:0', name='COPY2', st=St, en=En),
            ProfileRangeEvent(device='AMD:1:SDMA:1', name='COPY3', st=St, en=En),
            # graph spanning compute + copy on GPU 0
            ProfileGraphEvent(ents=[ProfileGraphEntry(device='AMD', name='E_1', st_id=0, en_id=1),
                                    ProfileGraphEntry(device='AMD:SDMA:0', name='COPY0', st_id=2, en_id=3)],
                              deps=[[], [0]], sigs=[St, En, St, En]),
            # memory alloc on both GPUs
            ProfilePointEvent(device='AMD', name='alloc', key=0, arg={"sz":1024, "dtype":dtypes.float}, ts=St),
            ProfilePointEvent(device='AMD:1', name='alloc', key=1, arg={"sz":512, "dtype":dtypes.float}, ts=St)]
    j = load_profile(prof)
    # graph grouped with its device, memory at the end
    self.assertListEqual(list(j['layout']),
                         ['AMD', 'AMD Graph', 'AMD:SDMA:0', 'AMD:SDMA:1',
                          'AMD:1', 'AMD:1:SDMA:0', 'AMD:1:SDMA:1',
                          'AMD Memory', 'AMD:1 Memory'])
  def test_bytes_per_kernel(self):
    step = 10
    n_events = 1_000
    prof = [ProfileRangeEvent("CPU", name="k_test", st=decimal.Decimal(ts:=i*step), en=decimal.Decimal(ts)+step) for i in range(n_events)]
    # the encoded profile must average at most 26 bytes per range event
    sz = len(get_profile(VizData(), prof))
    self.assertLessEqual(sz/n_events, 26)
  def test_calltrace(self):
    # NOTE: `def fxn` must stay a one-liner directly followed by the `with cpu_profile` line,
    # because the assertions compare trace line numbers with fxn.__code__.co_firstlineno (and +1)
    with save_viz() as viz:
      def fxn(): return Tensor.empty(10).mul(2).realize()
      with cpu_profile(TracingKey("test_fxn"), "CUSTOM"):
        fxn()
    # the codegen step's trace includes fxn's line (where realize is called)
    codegen_trace = viz.list_items()[0]["steps"][0]["trace"]
    assert any(fxn.__code__.co_filename == f and fxn.__code__.co_firstlineno == l for f,l,*_ in codegen_trace), str(codegen_trace)
    profile_ret = load_profile(cpu_events)
    e = profile_ret["layout"]["CUSTOM"]["events"][0]
    self.assertEqual(e["name"], "test_fxn")
    # the runtime event's traceback includes the `with cpu_profile` line
    runtime_trace = e["fmt"]["tb"]
    assert any(fxn.__code__.co_filename == f and fxn.__code__.co_firstlineno+1 == l for f,l,*_ in runtime_trace), str(runtime_trace)
  # can pack up to 1hr 11 min of trace events
  def test_trace_duration(self):
    # 72 minutes is past the limit, so encoding must fail
    dur_mins = 72
    n_events = 1_000
    step = decimal.Decimal(dur_mins*60*1e6//n_events)
    prof = [ProfileRangeEvent("CPU", name="k_test", st=decimal.Decimal(ts:=i*step), en=decimal.Decimal(ts)+step) for i in range(n_events)]
    with self.assertRaisesRegex(ValueError, "timestamp out of range"):
      get_profile(VizData(), prof)
  def test_python_marker(self):
    with save_viz():
      a = Tensor.empty(1, device="NULL")
      b = Tensor.empty(1, device="NULL")
      (a+b).realize()
      profile_marker("test 1")
      (a*b).realize()
      profile_marker("test 2")
    profile_ret = load_profile(cpu_events)
    markers = profile_ret["markers"]
    kernels = profile_ret["layout"]["NULL"]["events"]
    self.assertEqual(len(markers), 2)
    # each marker's timestamp falls between the kernels it was placed between
    assert kernels[0]["st"] <= markers[0]["ts"] <= kernels[1]["st"]
    assert markers[1]["ts"] >= kernels[1]["st"]+kernels[1]["dur"]
  def test_layout_order(self):
    with save_viz():
      def fn(): return
      for dname in ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1:ENGINE:0 N1", "TEST:1"]:
        with cpu_profile("fn", dname): fn()
    layout = list(load_profile(cpu_events)["layout"])
    # USER and TINY come first, then device tracks grouped by device
    self.assertListEqual(layout[:2], ["USER","TINY"])
    self.assertListEqual(layout[2:], ["TEST:1", "TEST:1 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1:ENGINE:0 N1", "TEST:2 N1"])
def _alloc(b:int):
  """Return an empty `b`-element char Tensor on NULL whose backing buffer is allocated immediately."""
  t = Tensor.empty(b, device="NULL", dtype=dtypes.char)
  t.uop.buffer.allocate()
  return t
class TestVizMemoryLayout(unittest.TestCase):
  """The "<device> Memory" profile track built from Buffer alloc/free events and their kernel users.

  In these tests each buffer contributes two events to the track, and "peak" reaches the number of
  1-byte buffers that are live at the same time.
  """
  def test_double_alloc(self):
    with save_viz():
      a = _alloc(1)
      _b = _alloc(1)
    profile_ret = load_profile(Buffer.profile_events)
    ret = profile_ret["layout"][f"{a.device} Memory"]
    # both buffers live at once
    self.assertEqual(ret["peak"], 2)
    self.assertEqual(len(ret["events"]), 4)
  def test_del_once(self):
    with save_viz():
      a = _alloc(1)
      del a
      b = _alloc(1)
    profile_ret = load_profile(Buffer.profile_events)
    ret = profile_ret["layout"][f"{b.device} Memory"]
    # a is freed before b is allocated
    self.assertEqual(ret["peak"], 1)
    self.assertEqual(len(ret["events"]), 4)
  def test_alloc_free(self):
    with save_viz():
      a = _alloc(1)
      _b = _alloc(1)
      del a
      c = _alloc(1)
    profile_ret = load_profile(Buffer.profile_events)
    ret = profile_ret["layout"][f"{c.device} Memory"]
    self.assertEqual(ret["peak"], 2)
    self.assertEqual(len(ret["events"]), 6)
  def test_free_last(self):
    with save_viz():
      bufs = []
      for _ in range(3):
        bufs.append(_alloc(1))
        profile_marker("alloc")
      device = bufs[0].device
      # free in reverse allocation order
      while bufs:
        b = bufs.pop()
        del b
        profile_marker("free")
    profile = load_profile(cpu_events+Buffer.profile_events)
    ret = profile["layout"][f"{device} Memory"]
    self.assertEqual(ret["peak"], 3)
    self.assertEqual(len(ret["events"]), 6)
    # 3 "alloc" + 3 "free" markers
    self.assertEqual(len(profile["markers"]), 6)
  def test_producer_simple(self):
    with save_viz():
      a = Tensor.ones(10, device="NULL")
      Tensor.realize(a.add(1).contiguous())
      b = Tensor.ones(10, device="NULL")
      Tensor.realize(b.add(1).contiguous())
    profile = load_profile(cpu_events+Buffer.profile_events)
    buffers = profile["layout"]["NULL Memory"]["events"]
    programs = profile["layout"]["NULL"]["events"]
    # one buffer with users per program run
    user_cnt = [len(b["arg"]["users"]) for b in buffers if b["arg"].get("users")]
    self.assertEqual(len(user_cnt), len(programs))
  @unittest.skip("flaky")
  def test_inflight_buf(self):
    # NOTE(review): unlike the other tests this is not wrapped in save_viz, so cpu_events/Buffer.profile_events
    # are not reset and PROFILE is not enabled by the test itself; possibly related to the flakiness, TODO confirm
    a = Tensor.empty(1, device="NULL")
    n = 4
    for i in range(n): (a+i).realize()
    profile = load_profile(cpu_events+Buffer.profile_events)
    buffers = profile["layout"]["NULL Memory"]["events"]
    user_cnt = [len(b["arg"]["users"]) for b in buffers if b["arg"].get("users")]
    self.assertEqual(max(user_cnt), n)
    input_buf = buffers.pop()
    assert all(u[3] == 0 for u in input_buf["arg"]["users"])
  def test_annotate_read_write(self):
    with save_viz():
      a = Tensor.ones(4, device="NULL").contiguous().realize()
      b = a.assign(a+2)
      c = a+1
      Tensor.realize(b, c)
    buf_events = load_profile(cpu_events+Buffer.profile_events)["layout"]["NULL Memory"]["events"]
    users = next((b["arg"]["users"] for b in buf_events if len(b["arg"].get("users",[])) == 3))
    # users[i][3] is the access mode of the i-th user
    self.assertEqual(users[0][3], 1) # write Tensor.ones
    self.assertEqual(users[1][3], 2) # read+write Tensor.assign
    self.assertEqual(users[2][3], 0) # readonly
  def test_dedup_users(self):
    with save_viz():
      a = Tensor.empty(1, device="NULL")
      for _ in range(n:=4): a.add(1).realize()
    profile = load_profile(cpu_events+Buffer.profile_events)
    programs = profile["layout"][a.device]["events"]
    users = profile["layout"][f"{a.device} Memory"]["events"].pop()["arg"]["users"]
    # assertEqual's third positional argument is `msg`, so compare each count with n explicitly
    self.assertEqual(len(programs), n)
    self.assertEqual(len(set(users)), n)
from tinygrad.uop.ops import KernelInfo
from tinygrad.renderer.amd.dsl import s
from tinygrad.runtime.autogen.amd.rdna3.ins import (s_add_u32, s_branch, s_cbranch_execz, s_cbranch_scc0, s_cbranch_scc1, s_cmp_eq_i32,
s_cmp_eq_u64, s_code_end, s_endpgm, s_mov_b32, s_nop)
from extra.gemm.amd_asm_matmul import Kernel
class TestCfg(unittest.TestCase):
def setUp(self):
  """Target the gfx1100 (RDNA3) architecture; get_cfg reads self.arch to build the NULL device spec."""
  self.arch = "gfx1100"
def get_cfg(self, name:str, k:Kernel):
  """Run the assembled kernel `k` through codegen under VIZ and return the render of its "View Disassembly" step.

  The kernel is named `name` and runs on the NULL device for `self.arch`. Callers read the result's
  "data" (blocks/paths/pc_tokens of the CFG) and "src" keys.
  """
  insts = k.finalize()
  def fxn(out:UOp) -> UOp:
    lidx = UOp.special(1, "lidx0")
    gidx = UOp.special(1, "gidx0")
    sink = UOp.sink(out.base, lidx, gidx, arg=KernelInfo(name=name))
    # PROGRAM whose LINEAR source is the pre-assembled instruction list, one INS UOp per instruction
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="NULL"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
  with save_viz() as viz:
    with Context(DEV=f"NULL::{self.arch}"):
      out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
      _ = do_to_program(out.schedule_linear().src[-1].src[0], Device[out.device].renderer)
  # the codegen group is named after the kernel
  codegen_rewrites = next(s for s in viz.list_items() if s["name"] == name)
  disasm = next(s for s in codegen_rewrites["steps"] if s["name"] == "View Disassembly")
  return get_render(viz.data, disasm["query"])
def test_simple(self):
  """An unconditional branch to the next label splits the program into 2 basic blocks."""
  k = Kernel()
  k.label("entry")
  k.emit(s_branch(), target="bb1")
  k.label("bb1")
  k.emit(s_endpgm())
  k.emit(s_code_end())
  cfg = self.get_cfg("simple", k)["data"]
  self.assertEqual(len(cfg["blocks"]), 2)
def test_diamond(self):
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[0], 0))
k.emit(s_mov_b32(s[1], 0))
k.emit(s_cmp_eq_u64(s[0:1], 0))
k.emit(s_cbranch_scc1(), target="if")
k.emit(s_branch(), target="else")
k.label("if")
k.emit(s_nop(1))
k.emit(s_branch(), target="end")
k.label("else")
k.emit(s_nop(0))
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
ret = self.get_cfg("diamond", k)
cfg = ret["data"]
self.assertEqual(len(cfg["blocks"]), 5)
edge_count = sum(len(v) for v in cfg["paths"].values())
self.assertEqual(edge_count, 5)
references:dict[str, list[str]] = {}
for pc, tokens in cfg["pc_tokens"].items():
for t in tokens:
for key in t["keys"]: references.setdefault(key, []).append(pc)
self.assertEqual(len(references["r0"]), 2)
insts = [cfg["pc_tokens"][pc][0]["st"] for pc in references["r0"]]
self.assertEqual(insts, ['s_mov_b32', 's_cmp_eq_u64'])
end_block = [" ".join(t["st"] for t in cfg["pc_tokens"][pc]) for pc in list(cfg["blocks"].values())[-1]]
code_line = ret["src"].splitlines()[-1]
self.assertEqual(len(end_block), 2)
for st in [end_block[-1], code_line]:
assert st.startswith("s_code_end") and st.endswith("x)"), st
def test_loop(self):
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 4))
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.emit(s_endpgm())
k.emit(s_code_end())
self.get_cfg("simple_loop", k)
def test_loop_branch(self):
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 4))
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 2))
k.emit(s_cbranch_scc1(), target="cond")
k.emit(s_branch(), target="cont")
k.label("cond")
k.emit(s_add_u32(s[1], s[1], -2))
k.label("cont")
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.emit(s_endpgm())
k.emit(s_code_end())
self.get_cfg("loop_if", k)
def test_loop_break(self):
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 8))
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 5))
k.emit(s_cbranch_scc1(), target="break")
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.label("break")
k.emit(s_endpgm())
k.emit(s_code_end())
self.get_cfg("loop_break", k)
def test_switch(self):
k = Kernel()
k.label("entry")
k.emit(s_cmp_eq_i32(s[0], 0))
k.emit(s_cbranch_scc1(), target="case0")
k.emit(s_cmp_eq_i32(s[0], 1))
k.emit(s_cbranch_scc1(), target="case1")
k.emit(s_branch(), target="case2")
k.label("case0")
k.emit(s_nop(0))
k.emit(s_branch(), target="join")
k.label("case1")
k.emit(s_nop(1))
k.emit(s_branch(), target="join")
k.label("case2")
k.emit(s_nop(2))
k.emit(s_branch(), target="join")
k.label("join")
k.emit(s_endpgm())
k.emit(s_code_end())
self.get_cfg("switch_case", k)
def test_ping_pong(self):
k = Kernel()
k.label("entry")
k.emit(s_cmp_eq_i32(s[0], 0))
k.emit(s_cbranch_scc1(), target="ping")
k.emit(s_branch(), target="pong")
k.label("ping")
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc1(), target="pong")
k.emit(s_branch(), target="end")
k.label("pong")
k.emit(s_cmp_eq_i32(s[2], 0))
k.emit(s_cbranch_scc1(), target="ping")
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
self.get_cfg("ping_pong", k)
def test_colored_blocks(self):
N = 10
k = Kernel()
k.label("entry")
k.emit(s_branch(), target="init0")
for i in range(N):
loop = f"loop{i}"
k.label(f"init{i}")
k.emit(s_mov_b32(s[1], i + 1))
k.emit(s_branch(), target=loop)
k.label(loop)
k.emit(s_nop(i & 7))
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target=loop)
k.emit(s_branch(), target=f"init{i+1}" if i + 1 < N else "end")
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
self.get_cfg("test_colored_blocks", k)
def test_jump_back_to_end(self):
k = Kernel()
k.label("entry")
k.emit(s_mov_b32(s[1], 2))
k.emit(s_cbranch_execz(), target="loop")
k.label("end")
k.emit(s_endpgm())
k.label("loop")
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_branch(), target="end")
k.emit(s_code_end())
self.get_cfg("jump_back_to_end", k)
# launch viz cli without subprocess
def run_cli(*cli_args) -> list[dict]:
  """Run the VIZ CLI in-process with `--json` appended and return every printed line decoded as JSON."""
  from tinygrad.viz.cli import main, get_arg_parser
  parsed = get_arg_parser().parse_args([*cli_args, "--json"])
  captured = io.StringIO()
  with contextlib.redirect_stdout(captured):
    main(parsed)
  lines = captured.getvalue().strip().splitlines()
  return list(map(json.loads, lines))
@contextlib.contextmanager
def write_files(viz) -> Generator[list[str], None, None]:
  """Pickle the captured rewrite trace and CPU profile events into a temporary directory.

  Yields the `--rewrites-path` / `--profile-path` arguments for `run_cli`. Both files and the
  directory are removed when the context exits, so the CLI must be run inside the `with`.
  """
  with tempfile.TemporaryDirectory() as tmpdir:
    rewrites_path = Path(tmpdir)/"rewrites.pkl"
    profile_path = Path(tmpdir)/"profile.pkl"
    rewrites_path.write_bytes(pickle.dumps(viz.data.trace))
    profile_path.write_bytes(pickle.dumps(cpu_events))
    yield ["--rewrites-path", str(rewrites_path), "--profile-path", str(profile_path)]
class TestCLI(unittest.TestCase):
  """End-to-end tests of the VIZ CLI, run in-process against traces pickled by `write_files`."""
  def test_reconstruct_debug(self):
    # kernel source (DEBUG output) and profile markers must be recoverable from the saved traces
    with save_viz() as viz:
      Tensor.empty(1, device="NULL").add(2.0).realize()
      profile_marker("marker @ 1")
      Tensor.empty(1, device="NULL").add(3.0).realize()
    with write_files(viz) as files, Context(DEBUG=4):
      out = run_cli(*files, "-s", "NULL")
    assert any(s.get("value", "").startswith("void E") for s in out)
    assert any(s.get("name", "") == "marker @ 1" for s in out)
  def test_aggregate(self):
    # `-t` collapses repeated runs of the same kernel into one summary row carrying a run count
    N, CNT = 1024, 5
    with save_viz() as viz:
      for _ in range(CNT):
        (Tensor.empty(N, N, device="NULL")@Tensor.empty(N, N, device="NULL")).realize()
      for _ in range(CNT):
        (Tensor.empty(N, N, device="NULL").assign(Tensor.empty(N, N, device="NULL"))).realize()
    with write_files(viz) as files, Context(NO_COLOR=1):
      kernels = run_cli(*files, "-s", "NULL", "-t")
    self.assertEqual(len(kernels), 2)
    gemm_summary = next(s for s in kernels if s["name"].startswith("r_"))
    copy_summary = next(s for s in kernels if s["name"].startswith("E_"))
    self.assertEqual(gemm_summary["count"], CNT)
    self.assertEqual(copy_summary["count"], CNT)
  def test_flops(self):
    # symbolic-shape JIT: each call runs two matmuls whose sizes are set by the bound variables i and j
    test_n = [(8, 16), (16, 32), (32, 64)]
    with save_viz() as viz:
      @TinyJit
      def f(a, b): return (a@a.T), (b@b.T)
      a = Tensor.empty(64, 64, device="NULL")
      b = Tensor.empty(64, 64, device="NULL")
      for i_val, j_val in test_n:
        i = Variable("i", 1, 64).bind(i_val)
        j = Variable("j", 1, 64).bind(j_val)
        Tensor.realize(*f(a[:i], b[:j]))
    with write_files(viz) as files:
      out = run_cli(*files, "-s", "NULL")
      aggregate = run_cli(*files, "-s", "NULL", "-t")
    self.assertEqual(len(out), 3*2)
    # flops increases as N gets larger
    gflops = [row["fmt"]["FLOPS"] for row in out]
    self.assertGreater(gflops[4], gflops[2])
    self.assertGreater(gflops[5], gflops[3])
    # aggregate flops
    self.assertEqual(len(aggregate), 2)
    agg_gflops = [row["fmt"]["FLOPS"] for row in aggregate]
    assert all(min(gflops) < v < max(gflops) for v in agg_gflops), f"{agg_gflops}"
  def test_dedup(self):
    # selecting one kernel prints its source once, followed by one row per run
    with save_viz() as viz:
      for _ in range(CNT:=4):
        # use kernel names unique to this test
        Tensor.custom_kernel(Tensor.empty(4, device="NULL"), fxn=lambda _: UOp.sink(arg=KernelInfo("k1_test_viz_dedup")))[0].realize()
        Tensor.custom_kernel(Tensor.empty(8, device="NULL"), fxn=lambda _: UOp.sink(arg=KernelInfo("k2_test_viz_dedup")))[0].realize()
    with write_files(viz) as files, Context(NO_COLOR=1):
      name = run_cli(*files, "-s", "NULL")[0]["name"]
      with Context(DEBUG=3):
        select = run_cli(*files, "-s", "NULL", name)
    self.assertEqual(len([s for s in select if s.get("value")]), 1, "debug output was not deduped")
    # message derives from CNT so it stays correct if the run count changes
    self.assertEqual(len([s for s in select if s.get("device") == "NULL"]), CNT, f"expected {CNT} runs for {name}")
  def test_call_graph(self):
    @function(precompile=True)
    def f(x):
      r = x.sum(axis=1).reshape(32, 1).expand(32, 32).contiguous()
      return x + r
    # turn off scache because this test requires a complete schedule rewrite
    with save_viz() as viz, Context(SCACHE=0):
      f(f(Tensor.empty(32, 32, device="NULL"))).realize()
    with write_files(viz) as files, Context(NO_COLOR=1):
      prgs = [s["name"] for s in run_cli(*files, "-s", "NULL")]
      with Context(DEBUG=5):
        out = run_cli(*files, "-s", "TINY")
    i = next(i for i,s in enumerate(out) if s.get("value", "").lstrip() == "View Kernel Graph")
    # next print is the CALL graph, CLI outputs exactly as web in TestVizIntegration.test_link_sched_codegen
    call_nodes = [n for n in out[i+1].values() if n["label"].startswith("CALL")]
    # the k-th CALL node must name the k-th kernel that ran on NULL
    for k,n in enumerate(call_nodes):
      assert prgs[k] in n["label"], f"CALL must contain kernel name, got {n['label']}"
  def test_interval(self):
    # --interval restricts output to events between the two named profile markers (inclusive of the markers)
    def emit_kernel(name:str): Tensor.custom_kernel(Tensor.empty(1, device="NULL"), fxn=lambda _: UOp.sink(arg=KernelInfo(name=name)))[0].realize()
    with save_viz() as viz:
      emit_kernel("pre_1")
      emit_kernel("pre_2")
      profile_marker("interval_start")
      emit_kernel("target_1")
      emit_kernel("target_2")
      profile_marker("interval_end")
      emit_kernel("post_1")
      emit_kernel("post_2")
    with write_files(viz) as files, Context(NO_COLOR=1):
      flat = run_cli(*files, "-s", "NULL", "--interval", "interval_start", "interval_end")
      aggregate = run_cli(*files, "-s", "NULL", "--interval", "interval_start", "interval_end", "-t")
    self.assertEqual([s["name"] for s in flat], ["interval_start", "target_1", "target_2", "interval_end"])
    self.assertEqual(sorted(s["name"] for s in aggregate), ["target_1", "target_2"])
if __name__ == "__main__":
  # allow running this test module directly (python test_viz.py)
  unittest.main()