Files
tinygrad/test/null/test_viz.py
qazal b1d88ebf02 viz/cli: aggregate flops in -t (#16031)
* 38

* plumbing

* more flops

* flop/s and bytes/s

* arithmetic mean

* tests

* harmonic mean

* range

* better

* simplify

* fix prints

* no string parsing needed
2026-05-04 17:35:02 +03:00

973 lines
40 KiB
Python

import unittest, decimal, sys, json, contextlib, tempfile, pickle, io
from pathlib import Path
from dataclasses import dataclass
from typing import Generator
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, TrackedPatternMatcher, graph_rewrite, track_rewrites, profile_matches
from tinygrad.uop.symbolic import sym
from tinygrad.dtype import dtypes
from tinygrad.helpers import colored, ansistrip, flatten, TracingKey, ProfileRangeEvent, ProfileEvent, Context, cpu_events, profile_marker
from tinygrad.helpers import cpu_profile, ProfilePointEvent, unwrap
from tinygrad.device import Buffer
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData, get_render
from tinygrad.codegen import to_program_cache
from tinygrad.codegen import to_program
@track_rewrites(name=True)
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
  # Runs every matcher of pm_lst over sink in order, one tracked graph_rewrite per matcher.
  # When names is given, names[i] becomes the step name of the i-th matcher; otherwise steps are unnamed.
  for idx, matcher in enumerate(pm_lst):
    step_name = names[idx] if names else None
    sink = graph_rewrite(sink, TrackedPatternMatcher(matcher.patterns), name=step_name)
  return sink
# small container class for the viz server module
class VizTrace:
  # loader init
  def __init__(self):
    # stays None until set_data() is called
    self._data:VizData|None = None

  @property
  def data(self) -> VizData:
    # goes through unwrap(); presumably that rejects None - confirm in tinygrad.helpers
    return unwrap(self._data)

  def set_data(self) -> None:
    # snapshot the module-level tracking containers, then let the viz loader fill in the rest
    trace = RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy())
    loaded = VizData(trace)
    load_rewrites(loaded)
    self._data = loaded

  # the API
  def list_items(self) -> list[dict]:
    return self.data.ctxs

  def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]:
    # rewrite_idx selects the group, step selects the rewrite inside that group
    assert len(self.data.trace.rewrites) > rewrite_idx, f"only loaded {len(self.data.trace.rewrites)} traces, expecting at least {rewrite_idx}"
    selected = self.data.trace.rewrites[rewrite_idx][step]
    return get_full_rewrite(self.data, selected)
@contextlib.contextmanager
def save_viz():
  # Clears all tracking/profiling state, runs the with-body under VIZ tracing and
  # loads the collected rewrites into the yielded VizTrace once the body is done.
  for store in (tracked_keys, tracked_ctxs, active_rewrites, active_group, _name_cnt):
    store.clear()
  to_program_cache.clear()
  Buffer.profile_events.clear()
  cpu_events.clear()
  trace = VizTrace()
  with Context(VIZ=-1, TRACK_MATCH_STATS=2, PROFILE=1):
    yield trace
  # only reached when the with-body did not raise
  trace.set_data()
class TestViz(unittest.TestCase):
  def test_simple(self):
    with save_viz() as viz:
      var = UOp.variable("a", 0, 10)
      exec_rewrite((var+0)*1, [sym])
    groups = viz.list_items()
    # VIZ displays rewrites in groups of tracked functions
    self.assertEqual(len(groups), 1)
    # each group has a list of steps
    group_steps = groups[0]["steps"]
    self.assertEqual(len(group_steps), 1)
    # each step has a list of matches
    self.assertEqual(group_steps[0]["match_count"], 2)

  def test_rewrites(self):
    with save_viz() as viz:
      var = UOp.variable("a", 0, 10)
      exec_rewrite(var*1, [sym])
      exec_rewrite(var*2, [sym])
    groups = viz.list_items()
    self.assertEqual(len(groups), 2)
    # names dedup using a counter
    for idx, want in enumerate(("exec_rewrite n1", "exec_rewrite n2")):
      self.assertEqual(groups[idx]["name"], want)

  def test_steps(self):
    with save_viz() as viz:
      var = UOp.variable("a", 0, 10)
      exec_rewrite(var+1, [PatternMatcher([]), PatternMatcher([])], ["x", "y"])
    step_list = viz.list_items()[0]["steps"]
    # steps can optionally have a name
    for idx, want in enumerate(("x", "y")):
      self.assertEqual(step_list[idx]["name"], want)

  def test_rewrite_location(self):
    # NOTE: inner must stay a one-liner, the reported line is compared against its first line
    def inner(sink): return graph_rewrite(sink, PatternMatcher([]))
    def outer(sink): return inner(sink)
    with save_viz() as viz:
      outer(UOp.variable("a", 1, 10))
    groups = viz.list_items()
    # step location comes from inner rewrite
    fp, lineno = groups[0]["steps"][0]["loc"]
    inner_code = inner.__code__
    self.assertEqual(fp, inner_code.co_filename)
    self.assertEqual(lineno, inner_code.co_firstlineno)

  def test_exceptions(self):
    # VIZ tracks rewrites up to and including the error
    def count_3(x:UOp):
      assert x.arg <= 3
      return x.replace(arg=x.arg+1)
    failing_pm = PatternMatcher([(UPat.cvar("x"), count_3),])
    one = UOp.const(dtypes.int, 1)
    with save_viz() as viz:
      with self.assertRaises(AssertionError):
        exec_rewrite(one, [failing_pm])
    failing_step = viz.list_items()[0]["steps"][0]
    self.assertEqual(failing_step["match_count"], 4) # 3 successful rewrites + 1 err

  def test_default_name(self):
    with save_viz() as viz:
      var = UOp.variable("a", 1, 10)
      # the group name is derived from this function's name, do not rename it
      @track_rewrites()
      def name_default(): return graph_rewrite(var, PatternMatcher([]))
      name_default()
    groups = viz.list_items()
    self.assertEqual(groups[0]["name"], "name_default n1")

  # name can also come from a function that returns a string
  def test_dyn_name_fxn(self):
    with save_viz() as viz:
      @track_rewrites(name=lambda *args,ret,**kwargs: ret.render())
      def name_from_fxn(s:UOp, arg:list|None=None): return graph_rewrite(s, PatternMatcher([]))
      expr = UOp.variable("a", 1, 10)+1
      name_from_fxn(expr, arg=["test"])
    groups = viz.list_items()
    # name gets deduped by the function call counter
    self.assertEqual(groups[0]["name"], "(a+1) n1")

  # name can also come from a function that returns a TracingKey
  def test_tracing_key(self):
    with save_viz() as viz:
      @track_rewrites(name=lambda inp,ret: TracingKey("custom_name", (inp,)))
      def test(s:UOp): return graph_rewrite(s, PatternMatcher([]))
      test(UOp.variable("a", 1, 10)+1)
    groups = viz.list_items()
    # NOTE: names from TracingKey do not get deduped
    self.assertEqual(groups[0]["name"], "custom_name")

  def test_nested_track_rewrites(self):
    with save_viz() as viz:
      @track_rewrites(name=lambda x,ret: TracingKey(f"inner fxn for {x.render()}", (ret,)))
      def inner(x:UOp): return graph_rewrite(x, PatternMatcher([]), name="each")
      @track_rewrites(name=lambda *args,ret: f"outer rewrite of {len(args)} inputs")
      def outer(*xs:tuple[UOp, ...]): return graph_rewrite(UOp.sink(*[inner(x) for x in xs]), PatternMatcher([]), name="all")
      items = ["a", "b", "c"]
      outer(*[UOp.variable(x, 1, 10) for x in items])
    groups = viz.list_items()
    # inner calls fall outside the outer call
    self.assertEqual(len(groups), len(items)+1)
    outer_group = groups[0]
    self.assertEqual(outer_group["name"], f"outer rewrite of {len(items)} inputs n1")
    outer_steps = outer_group["steps"]
    self.assertEqual(len(outer_steps), 1)
    self.assertEqual(outer_steps[0]["name"], "all")
    # the length check above guarantees groups[1:] pairs up one-to-one with items
    for inner_group, item in zip(groups[1:], items):
      self.assertEqual(inner_group["name"], f"inner fxn for {item}")
      inner_steps = inner_group["steps"]
      self.assertEqual(len(inner_steps), 1)
      self.assertEqual(inner_steps[0]["name"], "each")

  def test_profile_matches(self):
    with save_viz() as viz:
      # the step name asserted below is this function's name, do not rename it
      @profile_matches
      def nested_function(u:UOp):
        for i in range(2): graph_rewrite(u, PatternMatcher([]), name=f"step {i+1}")
      @track_rewrites()
      def main_rewrite(u:UOp):
        graph_rewrite(u, PatternMatcher([]), name="init")
        nested_function(u)
      lhs, rhs = UOp.variable("a", 1, 10), UOp.variable("b", 1, 10)
      main_rewrite(lhs+rhs)
    step_list = viz.list_items()[0]["steps"]
    self.assertEqual(step_list[0]["name"], "init")
    self.assertEqual(step_list[1]["name"], "nested_function")
    self.assertEqual(len(step_list), 4)

  def test_profile_matches_invalid_arg(self):
    with save_viz():
      @profile_matches
      def invalid_fxn(arg:str): return graph_rewrite(UOp(Ops.SINK), PatternMatcher([]))
      with self.assertRaisesRegex(AssertionError, "invalid match tracing input"):
        invalid_fxn("test")

  def test_colored_label(self):
    # NOTE: dataclass repr prints literal escape codes instead of unicode chars
    @dataclass(frozen=True)
    class TestStruct:
      colored_field: str
    payload = TestStruct(colored("xyz", "magenta")+colored("12345", "blue"))
    node = UOp(Ops.CUSTOM, arg=payload)
    node_json = uop_to_json(VizData(), node)[id(node)]
    self.assertEqual(ansistrip(node_json["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")

  def test_colored_label_multiline(self):
    with save_viz() as viz:
      arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta")
      srcs = tuple([Tensor.empty(1).uop for _ in range(10)])
      node = UOp(Ops.CUSTOM, src=srcs, arg=arg)
      exec_rewrite(node, [PatternMatcher([])])
    first_graph = next(viz.get_details(0, 0))["graph"]
    node_json = first_graph[id(node)]
    self.assertEqual(ansistrip(node_json["label"]), "CUSTOM\nx\nyzww\nw")

  def test_inf_loop(self):
    three = UOp.const(dtypes.int, 3)
    four = UOp.const(dtypes.int, 4)
    # the two rules undo each other, so the rewrite never reaches a fixed point
    flip_flop = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    with save_viz() as viz:
      # use smaller stack limit for faster test (default is 250000)
      with Context(REWRITE_STACK_LIMIT=100):
        self.assertRaises(RuntimeError, exec_rewrite, three, [flip_flop])
    graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0))
    self.assertEqual(graphs[0], uop_to_json(VizData(), three)[id(three)])
    self.assertEqual(graphs[1], uop_to_json(VizData(), four)[id(four)])
    # fallback to NOOP with the error message
    nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite")
    self.assertEqual(graphs[2], uop_to_json(VizData(), nop)[id(nop)])

  def test_const_node_visibility(self):
    with save_viz() as viz:
      var = UOp.variable("a", 0, 10, dtype=dtypes.int)
      zero = UOp.const(var.dtype, 0)
      alu = var*zero
      exec_rewrite(alu, [sym])
    groups = viz.list_items()
    self.assertEqual(len(groups), 1)
    graphs = [x["graph"] for x in viz.get_details(0, 0)]
    # embed const in the parent node when possible
    self.assertEqual(list(graphs[0]), [id(var), id(alu)])
    self.assertEqual(list(graphs[1]), [id(zero)])

  # TODO: DEFINE_VAR (shape ()) now gets wrapped in RESHAPE+EXPAND when broadcast against a shaped operand
  # (due to shared OpMixin._binop using _broadcasted). Either extend viz to fold RESHAPE/EXPAND around
  # DEFINE_VAR/RANGE/SPECIAL the way it does for CONST, or redesign scalar-compiler-op broadcasting.
  @unittest.expectedFailure
  def test_const_reshape_expand_folded(self):
    # CONST->RESHAPE->EXPAND should be folded into the ALU node, not shown as separate RESHAPE/EXPAND nodes
    const = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
    var = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
    alu = var + const
    graph = uop_to_json(VizData(), alu)
    # the RESHAPE and EXPAND nodes from the const should not appear in the graph
    labels = {node["label"].split("\n")[0] for node in graph.values()}
    for hidden_op in ("RESHAPE", "EXPAND"):
      self.assertNotIn(hidden_op, labels)
    # the CONST should be inlined into the ALU node's label
    self.assertIn("CONST", graph[id(alu)]["label"])
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp):
  # tag untagged nodes with 1; an already tagged node yields None (no rewrite)
  if x.tag is not None: return None
  return x.rtag(1)
leaf = TrackedPatternMatcher([
  (UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite),
])
def branch_rewrite(x:UOp, y:UOp):
  # Rewrites x+y into x*y after running the leaf matcher over both operands.
  # An already tagged x yields None (no rewrite).
  if x.tag is not None: return None
  # left is rewritten before right, which fixes the order of the nested steps
  left = graph_rewrite(x, leaf, name="leaf_left")
  right = graph_rewrite(y, leaf, name="leaf_right")
  return left * right
branch = TrackedPatternMatcher([
  (UPat.var("x")+UPat.var("y"), branch_rewrite),
])
def root_rewrite(root:UOp):
  # runs the branch matcher over every source of root, naming each nested rewrite branch_<index>
  rewritten = [graph_rewrite(src, branch, name=f"branch_{idx}") for idx, src in enumerate(root.src)]
  return root.replace(src=tuple(rewritten))
root = TrackedPatternMatcher([
  (UPat(Ops.SINK, src=UPat(Ops.ADD), name="root"), root_rewrite),
])
class TestVizTree(unittest.TestCase):
  def assertStepEqual(self, step:dict, want:dict):
    # compares only the keys listed in want, other keys of step are ignored
    for k,v in want.items():
      self.assertEqual(step[k], v, f"failed at '{k}': {v} != {step[k]}\n{step=}")

  def test_tree_view(self):
    with save_viz() as viz:
      a, b, c, d = (UOp.variable(nm,0,10) for nm in ("a", "b", "c", "d"))
      sink = UOp.sink(a+b, c+d)
      def tree_rewrite(): return graph_rewrite(sink, root, name="root")
      tree_rewrite()
    steps = viz.list_items()[0]["steps"]
    # 1 root + 2 branches + 4 leaves, listed depth-first
    expected = [("root", 0),
                ("branch_0", 1), ("leaf_left", 2), ("leaf_right", 2),
                ("branch_1", 1), ("leaf_left", 2), ("leaf_right", 2)]
    self.assertEqual(len(steps), 1+2+4)
    for got, (name, depth) in zip(steps, expected):
      self.assertStepEqual(got, {"name":name, "depth":depth, "match_count":1})
import gc
def bufs_allocated() -> int:
  """Return how many gc-tracked objects are instances of a class named Buffer from tinygrad.device."""
  # collect first so unreachable garbage is not counted
  gc.collect()
  # exact class match by name and module, not an isinstance check
  return sum(1 for obj in gc.get_objects()
             if type(obj).__name__ == "Buffer" and type(obj).__module__ == "tinygrad.device")
class TestVizGC(unittest.TestCase):
  def test_gc(self):
    with save_viz() as viz:
      baseline = bufs_allocated()
      buf_uop = UOp.new_buffer("NULL", 10, dtypes.char)
      buf_uop.buffer.allocate()
      exec_rewrite(buf_uop, [PatternMatcher([])])
      # dropping the only local reference must bring the Buffer count back to the baseline
      del buf_uop
      self.assertEqual(bufs_allocated()-baseline, 0)
    groups = viz.list_items()
    self.assertEqual(len(groups), 1)

  @unittest.skip("it's not generic enough to handle arbitrary UOps in arg")
  def test_gc_uop_in_arg(self):
    with save_viz() as viz:
      baseline = bufs_allocated()
      buf_uop = UOp.new_buffer("NULL", 10, dtypes.char)
      buf_uop.buffer.allocate()
      # same as test_gc, but the buffer UOp is also referenced through arg
      exec_rewrite(UOp(Ops.CUSTOM, src=(buf_uop,), arg=buf_uop), [PatternMatcher([])])
      del buf_uop
      self.assertEqual(bufs_allocated()-baseline, 0)
    groups = viz.list_items()
    self.assertEqual(len(groups), 1)
# VIZ integrates with other parts of tinygrad
from tinygrad import Tensor, Device, TinyJit, Variable
class TestVizIntegration(unittest.TestCase):
  # codegen supports rendering of code blocks
  def test_codegen_tracing(self):
    with save_viz() as viz:
      linear = (Tensor.empty(4)+Tensor.empty(4)).schedule_linear()
      ast = linear.src[0].src[0]
      prg = to_program(ast, Device[Device.DEFAULT].renderer)
    groups = viz.list_items()
    self.assertEqual(len(groups), 3)
    self.assertEqual(groups[0]["name"], "Callify 1 Buffer n1")
    self.assertEqual(groups[1]["name"], "Schedule 1 Kernel n1")
    self.assertEqual(groups[2]["name"], prg.arg.name)

  # schedule graph CALL nodes have a link to jump to codegen
  def test_link_sched_codegen(self):
    with save_viz() as viz:
      c1 = Tensor.empty(4).add(1)
      c2 = Tensor.empty(8).add(1)
      sched = c1.schedule_linear(c2)
      renderer = Device[Device.DEFAULT].renderer
      prgs = [to_program(si.src[0], renderer).arg.name for si in sched.src]
    groups = viz.list_items()
    sched_idx = next(i for i,grp in enumerate(groups) if grp["name"].startswith("Schedule"))
    kernel_graph_idx = next(i for i,st in enumerate(groups[sched_idx]["steps"]) if st["name"] == "View Kernel Graph")
    graph = next(viz.get_details(sched_idx, kernel_graph_idx))["graph"]
    call_nodes = [node for node in graph.values() if node["label"].startswith("CALL")]
    for idx, node in enumerate(call_nodes):
      ref = node["ref"]
      assert ref is not None
      # the ref is an index into the list of groups
      self.assertEqual(groups[ref]["name"], prgs[idx])

  @Context(TRACEMETA=2)
  def test_metadata_tracing(self):
    with save_viz() as viz:
      lhs = Tensor.empty(1)
      rhs = Tensor.empty(1)
      alu = lhs+rhs
      metadata = alu.uop.metadata
      alu.schedule_linear()
    graph = next(viz.get_details(0, 0))["graph"]
    tagged = [node for node in graph.values() if repr(metadata) in node["label"]]
    self.assertEqual(len(tagged), 1)

  # tracing also works without a track_rewrites context
  # all graph_rewrites get put into the default group
  def test_default_tracing(self):
    with save_viz() as viz:
      def test(root):
        return graph_rewrite(root, sym)
      one = UOp.const(dtypes.int, 1)
      test(one)
      test(one+1)
    groups = viz.list_items()
    self.assertEqual(len(groups), 1)
    self.assertEqual(groups[0]["name"], "default graph_rewrite")

  # using @track_rewrites organizes function calls into groups
  # and nicely counts function calls.
  def test_group_traces(self):
    with save_viz() as viz:
      # the group names asserted below come from this function's name, do not rename it
      @track_rewrites()
      def test(root):
        return graph_rewrite(root, sym)
      one = UOp.const(dtypes.int, 1)
      test(one)
      test(one+1)
    groups = viz.list_items()
    self.assertEqual(len(groups), 2)
    for idx, grp in enumerate(groups[:2]):
      self.assertEqual(grp["name"], f"test n{idx+1}")

  # @track_rewrites always starts a new group.
  def test_group_combined(self):
    with save_viz() as viz:
      def default_test(root): return graph_rewrite(root, sym)
      tracked_test = track_rewrites()(default_test)
      c = UOp.const(dtypes.int, 1)
      default_test(c+1) # goes to the default group
      tracked_test(c) # all rewrites after this go inside the second group.
      default_test(c+2)
    groups = viz.list_items()
    self.assertEqual(len(groups), 2)
    # (group, step, expected root): the ids are compared against freshly built c+1 / c+2 expressions
    self.assertEqual(list(next(viz.get_details(0, 0))["graph"]), [id(c+1)])
    self.assertEqual(list(next(viz.get_details(1, 0))["graph"]), [id(c)])
    self.assertEqual(list(next(viz.get_details(1, 1))["graph"]), [id(c+2)])

  def test_recurse(self):
    with save_viz() as viz:
      acc = Tensor.empty(10)
      # build a very deep graph to exercise the tracing on deep recursion
      for _ in range(10_000):
        acc += acc
      graph_rewrite(acc.uop, PatternMatcher([]))
    groups = viz.list_items()
    assert len(groups) == 1

  def test_jit(self):
    with save_viz():
      @TinyJit
      def f(a, b, c): return (a+b).contiguous().mul(3), c.assign(a.to(c.device))
      a = Tensor.empty(16, device="NULL")
      b = Tensor.empty(16, device="NULL")
      c = Tensor.empty(16, device="NULL:1")
      for _ in range(3):
        Tensor.realize(*f(a, b, c))
    layout = load_profile(cpu_events)["layout"]
    null_tracks = [track for track in layout if track.startswith("NULL")]
    self.assertEqual(["NULL", "NULL Graph", "NULL:SDMA:0"], null_tracks)
    self.assertEqual(len(layout["NULL"]["events"]), 2*3)
    self.assertEqual(len(layout["NULL:SDMA:0"]["events"]), 3)
    self.assertEqual(len(layout["NULL Graph"]["events"]), 2)
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_profile
from tinygrad.viz.cli import decode_profile
def load_profile(lst:list[ProfileEvent]) -> dict:
  # encodes the events with get_profile on a fresh VizData, then decodes the result back with the CLI decoder
  packed = get_profile(VizData(), lst)
  return decode_profile(packed)
class TestVizProfiler(unittest.TestCase):
  def test_transfer_uses_copy_device(self):
    with save_viz():
      src = Tensor.ones(1, device="NULL").contiguous().realize()
      src.to("NULL:1").realize()
    compute_events, copy_events = [], []
    for ev in cpu_events:
      if not isinstance(ev, ProfileRangeEvent): continue
      if ev.device == "NULL": compute_events.append(ev)
      if ev.device.endswith(":SDMA:0"): copy_events.append(ev)
    self.assertGreater(len(compute_events), 0, "expected compute events on base device")
    self.assertGreater(len(copy_events), 0, "transfer must produce events with ':SDMA' device suffix")

  def test_node(self):
    D = decimal.Decimal
    prof = [ProfileRangeEvent(device='NV', name='E_2', st=D(1000), en=D(1010)),
            ProfileDeviceEvent(device='NV', tdiff=D(-1000))]
    dev_events = load_profile(prof)['layout']['NV']['events']
    self.assertEqual(len(dev_events), 1)
    event = dev_events[0]
    self.assertEqual(event['name'], 'E_2')
    self.assertEqual(event['st'], 0)
    self.assertEqual(event['dur'], 10)
    assert event['ref'] is None

  def test_copy_node(self):
    D = decimal.Decimal
    prof = [ProfileRangeEvent(device='NV:SDMA:0', name='COPYxx', st=D(1000), en=D(1010)),
            ProfileRangeEvent(device='NV:2:SDMA:0', name='COPYxx', st=D(1000), en=D(1010)),
            ProfileDeviceEvent(device='NV:SDMA:0', tdiff=D(-100)),
            ProfileDeviceEvent(device='NV:2:SDMA:0', tdiff=D(-80))]
    j = load_profile(prof)
    first = j['layout']['NV:SDMA:0']['events'][0]
    self.assertEqual(first['name'], 'COPYxx')
    self.assertEqual(first['st'], 0) # first event
    self.assertEqual(first['dur'], 10)
    second = j['layout']['NV:2:SDMA:0']['events'][0]
    self.assertEqual(second['st'], 20) # second event, diff clock
    # total duration spans from the first start to the last end
    self.assertEqual(j["dur"], (second["st"]+second["dur"])-first["st"])

  def test_copy_node_bandwidth(self):
    D = decimal.Decimal
    sz = 256*1024*1024
    dur = 10_000
    copy_key = TracingKey("NV -> NV:1", ret=sz)
    prof = [ProfileRangeEvent(device='NV:SDMA:0', name=copy_key, st=D(1000), en=D(1000+dur)),
            ProfileDeviceEvent(device='NV:SDMA:0', tdiff=D(-1000))]
    event = load_profile(prof)['layout']['NV:SDMA:0']['events'][0]
    self.assertEqual(event['fmt'], {"B/s": sz/(dur*1e-6), "B": sz})

  def test_graph(self):
    D = decimal.Decimal
    entries = [ProfileGraphEntry(device='NV', name='E_25_4n2', st_id=0, en_id=1),
               ProfileGraphEntry(device='NV:1:SDMA:0', name='NV -> NV:1', st_id=2, en_id=3)]
    prof = [ProfileDeviceEvent(device='NV', tdiff=D(-1000)),
            ProfileDeviceEvent(device='NV:1:SDMA:0', tdiff=D(-50)),
            ProfileGraphEvent(ents=entries, deps=[[], [0]], sigs=[D(1000), D(1002), D(1004), D(1008)])]
    layout = load_profile(prof)['layout']
    tracks = list(layout)
    for idx, want in enumerate(('NV', 'NV Graph', 'NV:1:SDMA:0')):
      self.assertEqual(tracks[idx], want)

    nv_events = layout['NV']['events']
    self.assertEqual(nv_events[0]['name'], 'E_25_4n2')
    self.assertEqual(nv_events[0]['st'], 0)
    self.assertEqual(nv_events[0]['dur'], 2)

    sdma_events = layout['NV:1:SDMA:0']['events']
    self.assertEqual(sdma_events[0]['name'], 'NV -> NV:1')
    self.assertEqual(sdma_events[0]['st'], 954)

    # the graph event spans from the first entry's start to the last entry's end
    graph_events = layout['NV Graph']['events']
    self.assertEqual(graph_events[0]['st'], nv_events[0]['st'])
    self.assertEqual(graph_events[0]['st']+graph_events[0]['dur'], sdma_events[0]['st']+sdma_events[0]['dur'])

  def test_graph_copy_bandwidth(self):
    D = decimal.Decimal
    sz = 256*1024*1024
    dur = 10_000
    copy_entry = ProfileGraphEntry(device='NV:1:SDMA:0', name=TracingKey("NV -> NV:1", ret=sz), st_id=0, en_id=1)
    prof = [ProfileDeviceEvent(device='NV', tdiff=D(-1000)),
            ProfileDeviceEvent(device='NV:1:SDMA:0', tdiff=D(-50)),
            ProfileGraphEvent(ents=[copy_entry], deps=[[]], sigs=[D(1004), D(1004+dur)])]
    sdma_events = load_profile(prof)['layout']['NV:1:SDMA:0']['events']
    self.assertEqual(sdma_events[0]["fmt"], {"B/s": sz/(dur*1e-6), "B": sz})

  def test_block_ordering(self):
    D = decimal.Decimal
    st, en = D(1000), D(1010)
    prof = [ProfileDeviceEvent(device='NV', tdiff=D(-1000)),
            ProfileDeviceEvent(device='NV:1', tdiff=D(-500)),
            ProfileDeviceEvent(device='NV:SDMA:0', tdiff=D(-100)),
            ProfileRangeEvent(device='NV', name='E_2', st=st, en=en),
            ProfileRangeEvent(device='NV:1', name='E_3', st=st, en=en),
            ProfileRangeEvent(device='NV:SDMA:0', name='COPY', st=st, en=en),
            ProfileGraphEvent(ents=[ProfileGraphEntry(device='NV', name='E_2', st_id=0, en_id=1)],
                              deps=[[]], sigs=[D(1000), D(1010)])]
    layout = load_profile(prof)['layout']
    # graph grouped with its device, memory at the end
    self.assertListEqual(list(layout), ['NV', 'NV Graph', 'NV:SDMA:0', 'NV:1'])

  @unittest.skipIf(sys.platform == 'win32', "TODO: ops_amd import fails on windows")
  def test_multi_sdma_ordering(self):
    props = {"gfx_target_version": 0}
    D, St, En = decimal.Decimal, decimal.Decimal(1000), decimal.Decimal(1010)
    # 2 AMD GPUs, 2 SDMA engines each
    device_tdiffs = [('AMD', -1000), ('AMD:1', -900), ('AMD:SDMA:0', -100), ('AMD:SDMA:1', -80),
                     ('AMD:1:SDMA:0', -60), ('AMD:1:SDMA:1', -40)]
    prof = [ProfileDeviceEvent(device=dev, tdiff=D(tdiff), props=props) for dev,tdiff in device_tdiffs]
    # compute + copy events
    ranges = [('AMD', 'E_1'), ('AMD:1', 'E_2'), ('AMD:SDMA:0', 'COPY0'), ('AMD:SDMA:1', 'COPY1'),
              ('AMD:1:SDMA:0', 'COPY2'), ('AMD:1:SDMA:1', 'COPY3')]
    prof += [ProfileRangeEvent(device=dev, name=nm, st=St, en=En) for dev,nm in ranges]
    # graph spanning compute + copy on GPU 0
    prof.append(ProfileGraphEvent(ents=[ProfileGraphEntry(device='AMD', name='E_1', st_id=0, en_id=1),
                                        ProfileGraphEntry(device='AMD:SDMA:0', name='COPY0', st_id=2, en_id=3)],
                                  deps=[[], [0]], sigs=[St, En, St, En]))
    # memory alloc on both GPUs
    prof.append(ProfilePointEvent(device='AMD', name='alloc', key=0, arg={"sz":1024, "dtype":dtypes.float}, ts=St))
    prof.append(ProfilePointEvent(device='AMD:1', name='alloc', key=1, arg={"sz":512, "dtype":dtypes.float}, ts=St))
    layout = load_profile(prof)['layout']
    # graph grouped with its device, memory at the end
    self.assertListEqual(list(layout),
                         ['AMD', 'AMD Graph', 'AMD:SDMA:0', 'AMD:SDMA:1',
                          'AMD:1', 'AMD:1:SDMA:0', 'AMD:1:SDMA:1',
                          'AMD Memory', 'AMD:1 Memory'])

  def test_bytes_per_kernel(self):
    D = decimal.Decimal
    step = 10
    n_events = 1_000
    # back-to-back events, each `step` long
    prof = [ProfileRangeEvent("CPU", name="k_test", st=D(ts:=i*step), en=D(ts)+step) for i in range(n_events)]
    packed_size = len(get_profile(VizData(), prof))
    self.assertLessEqual(packed_size/n_events, 26)

  def test_calltrace(self):
    with save_viz() as viz:
      # NOTE: keep `def fxn` and `with cpu_profile` on adjacent lines, the runtime assertion below uses co_firstlineno+1
      def fxn(): return Tensor.empty(10).mul(2).realize()
      with cpu_profile(TracingKey("test_fxn"), "CUSTOM"):
        fxn()
    fxn_file, fxn_line = fxn.__code__.co_filename, fxn.__code__.co_firstlineno
    codegen_trace = viz.list_items()[0]["steps"][0]["trace"]
    assert any(fxn_file == f and fxn_line == l for f,l,*_ in codegen_trace), str(codegen_trace)
    custom_event = load_profile(cpu_events)["layout"]["CUSTOM"]["events"][0]
    self.assertEqual(custom_event["name"], "test_fxn")
    runtime_trace = custom_event["fmt"]["tb"]
    assert any(fxn_file == f and fxn_line+1 == l for f,l,*_ in runtime_trace), str(runtime_trace)

  # can pack up to 1hr 11 min of trace events
  def test_trace_duration(self):
    D = decimal.Decimal
    dur_mins = 72
    n_events = 1_000
    step = D(dur_mins*60*1e6//n_events)
    prof = [ProfileRangeEvent("CPU", name="k_test", st=D(ts:=i*step), en=D(ts)+step) for i in range(n_events)]
    with self.assertRaisesRegex(ValueError, "timestamp out of range"):
      get_profile(VizData(), prof)

  def test_python_marker(self):
    with save_viz():
      lhs = Tensor.empty(1, device="NULL")
      rhs = Tensor.empty(1, device="NULL")
      (lhs+rhs).realize()
      profile_marker("test 1")
      (lhs*rhs).realize()
      profile_marker("test 2")
    profile_ret = load_profile(cpu_events)
    markers = profile_ret["markers"]
    kernels = profile_ret["layout"]["NULL"]["events"]
    self.assertEqual(len(markers), 2)
    # first marker sits between the two kernels, second one after the last kernel ended
    first_kernel, second_kernel = kernels[0], kernels[1]
    assert first_kernel["st"] <= markers[0]["ts"] <= second_kernel["st"]
    assert markers[1]["ts"] >= second_kernel["st"]+second_kernel["dur"]

  def test_layout_order(self):
    with save_viz():
      def fn(): return
      device_names = ["TINY", "USER", "TEST:1 N1", "TEST:2 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1:ENGINE:0 N1", "TEST:1"]
      for dname in device_names:
        with cpu_profile("fn", dname):
          fn()
    layout = list(load_profile(cpu_events)["layout"])
    head, tail = layout[:2], layout[2:]
    self.assertListEqual(head, ["USER","TINY"])
    self.assertListEqual(tail, ["TEST:1", "TEST:1 N1", "TEST:1 N2", "TEST:1:ENGINE:0", "TEST:1:ENGINE:0 N1", "TEST:2 N1"])
def _alloc(b:int):
  # creates an empty char Tensor of b elements on the NULL device and allocates its buffer right away
  tensor = Tensor.empty(b, device="NULL", dtype=dtypes.char)
  tensor.uop.buffer.allocate()
  return tensor
class TestVizMemoryLayout(unittest.TestCase):
  # two live buffers at the same time -> peak of 2
  def test_double_alloc(self):
    with save_viz():
      a = _alloc(1)
      _b = _alloc(1)
    profile_ret = load_profile(Buffer.profile_events)
    ret = profile_ret["layout"][f"{a.device} Memory"]
    self.assertEqual(ret["peak"], 2)
    self.assertEqual(len(ret["events"]), 4)

  # the first buffer is deleted before the second is allocated -> peak stays at 1
  def test_del_once(self):
    with save_viz():
      a = _alloc(1)
      del a
      b = _alloc(1)
    profile_ret = load_profile(Buffer.profile_events)
    ret = profile_ret["layout"][f"{b.device} Memory"]
    self.assertEqual(ret["peak"], 1)
    self.assertEqual(len(ret["events"]), 4)

  def test_alloc_free(self):
    with save_viz():
      a = _alloc(1)
      _b = _alloc(1)
      del a
      c = _alloc(1)
    profile_ret = load_profile(Buffer.profile_events)
    ret = profile_ret["layout"][f"{c.device} Memory"]
    self.assertEqual(ret["peak"], 2)
    self.assertEqual(len(ret["events"]), 6)

  def test_free_last(self):
    with save_viz():
      bufs = []
      for _ in range(3):
        bufs.append(_alloc(1))
        profile_marker("alloc")
      device = bufs[0].device
      # release in reverse allocation order, one marker per release
      while bufs:
        b = bufs.pop()
        del b
        profile_marker("free")
    profile = load_profile(cpu_events+Buffer.profile_events)
    ret = profile["layout"][f"{device} Memory"]
    self.assertEqual(ret["peak"], 3)
    self.assertEqual(len(ret["events"]), 6)
    # 3 "alloc" markers + 3 "free" markers
    self.assertEqual(len(profile["markers"]), 6)

  def test_producer_simple(self):
    with save_viz():
      a = Tensor.ones(10, device="NULL")
      Tensor.realize(a.add(1).contiguous())
      b = Tensor.ones(10, device="NULL")
      Tensor.realize(b.add(1).contiguous())
    profile = load_profile(cpu_events+Buffer.profile_events)
    buffers = profile["layout"]["NULL Memory"]["events"]
    programs = profile["layout"]["NULL"]["events"]
    # one entry per buffer event that has at least one user
    user_cnt = [len(b["arg"]["users"]) for b in buffers if b["arg"].get("users")]
    self.assertEqual(len(user_cnt), len(programs))

  @unittest.skip("flaky")
  def test_inflight_buf(self):
    a = Tensor.empty(1, device="NULL")
    n = 4
    for i in range(n): (a+i).realize()
    profile = load_profile(cpu_events+Buffer.profile_events)
    buffers = profile["layout"]["NULL Memory"]["events"]
    user_cnt = [len(b["arg"]["users"]) for b in buffers if b["arg"].get("users")]
    self.assertEqual(max(user_cnt), n)
    input_buf = buffers.pop()
    assert all(u[3] == 0 for u in input_buf["arg"]["users"])

  def test_annotate_read_write(self):
    with save_viz():
      a = Tensor.ones(4, device="NULL").contiguous().realize()
      b = a.assign(a+2)
      c = a+1
      Tensor.realize(b, c)
    buf_events = load_profile(cpu_events+Buffer.profile_events)["layout"]["NULL Memory"]["events"]
    # pick the buffer that has exactly three users
    users = next((b["arg"]["users"] for b in buf_events if len(b["arg"].get("users",[])) == 3))
    self.assertEqual(users[0][3], 1) # write Tensor.ones
    self.assertEqual(users[1][3], 2) # read+write Tensor.assign
    self.assertEqual(users[2][3], 0) # readonly

  def test_dedup_users(self):
    with save_viz():
      a = Tensor.empty(1, device="NULL")
      for _ in range(n:=4): a.add(1).realize()
    profile = load_profile(cpu_events+Buffer.profile_events)
    programs = profile["layout"][a.device]["events"]
    users = profile["layout"][f"{a.device} Memory"]["events"].pop()["arg"]["users"]
    # NOTE: assertEqual's third positional argument is the failure message, so n has to be compared separately
    self.assertEqual(len(programs), len(set(users)))
    self.assertEqual(len(programs), n)
from tinygrad.uop.ops import KernelInfo
from tinygrad.renderer.amd.dsl import s
from tinygrad.runtime.autogen.amd.rdna3.ins import (s_add_u32, s_branch, s_cbranch_execz, s_cbranch_scc0, s_cbranch_scc1, s_cmp_eq_i32,
s_cmp_eq_u64, s_code_end, s_endpgm, s_mov_b32, s_nop)
from extra.gemm.amd_asm_matmul import Kernel
class TestCfg(unittest.TestCase):
def setUp(self):
  # target architecture used by get_cfg to build the DEV string
  self.arch = "gfx1100"
def get_cfg(self, name:str, k:Kernel):
  # Finalizes kernel k, wraps its instructions into a PROGRAM named `name`, runs codegen under VIZ
  # and returns the render of the "View Disassembly" step of that kernel.
  insts = k.finalize()
  def fxn(out:UOp) -> UOp:
    lidx = UOp.special(1, "lidx0")
    gidx = UOp.special(1, "gidx0")
    sink = UOp.sink(out.base, lidx, gidx, arg=KernelInfo(name=name))
    linear = UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=inst) for inst in insts]))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="NULL"), linear))
  with save_viz() as viz:
    with Context(DEV=f"NULL::{self.arch}"):
      out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
      last_call = out.schedule_linear().src[-1]
      _ = to_program(last_call.src[0], Device[out.device].renderer)
  # the codegen group carries the kernel's name
  codegen_rewrites = next(grp for grp in viz.list_items() if grp["name"] == name)
  disasm = next(st for st in codegen_rewrites["steps"] if st["name"] == "View Disassembly")
  return get_render(viz.data, disasm["query"])
def test_simple(self):
  # entry branches straight into bb1 -> two blocks
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_branch(), target="bb1")
  kern.label("bb1")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  render = self.get_cfg("simple", kern)
  cfg = render["data"]
  self.assertEqual(len(cfg["blocks"]), 2)
def test_diamond(self):
  # entry -> (if | else) -> end
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_mov_b32(s[0], 0))
  kern.emit(s_mov_b32(s[1], 0))
  kern.emit(s_cmp_eq_u64(s[0:1], 0))
  kern.emit(s_cbranch_scc1(), target="if")
  kern.emit(s_branch(), target="else")
  kern.label("if")
  kern.emit(s_nop(1))
  kern.emit(s_branch(), target="end")
  kern.label("else")
  kern.emit(s_nop(0))
  kern.label("end")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  ret = self.get_cfg("diamond", kern)
  cfg = ret["data"]
  pc_tokens = cfg["pc_tokens"]
  self.assertEqual(len(cfg["blocks"]), 5)
  edge_count = sum(len(targets) for targets in cfg["paths"].values())
  self.assertEqual(edge_count, 5)
  # map every token key to the list of pcs that reference it
  references:dict[str, list[str]] = {}
  for pc, tokens in pc_tokens.items():
    for tok in tokens:
      for key in tok["keys"]:
        references.setdefault(key, []).append(pc)
  r0_pcs = references["r0"]
  self.assertEqual(len(r0_pcs), 2)
  insts = [pc_tokens[pc][0]["st"] for pc in r0_pcs]
  self.assertEqual(insts, ['s_mov_b32', 's_cmp_eq_u64'])
  # the last block and the last source line must both end with the s_code_end line
  last_block_pcs = list(cfg["blocks"].values())[-1]
  end_block = [" ".join(tok["st"] for tok in pc_tokens[pc]) for pc in last_block_pcs]
  code_line = ret["src"].splitlines()[-1]
  self.assertEqual(len(end_block), 2)
  for st in (end_block[-1], code_line):
    assert st.startswith("s_code_end") and st.endswith("x)"), st
def test_loop(self):
  # single counted loop; only checks that building the cfg does not raise
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_mov_b32(s[1], 4))
  kern.label("loop")
  kern.emit(s_add_u32(s[1], s[1], -1))
  kern.emit(s_cmp_eq_i32(s[1], 0))
  kern.emit(s_cbranch_scc0(), target="loop")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  self.get_cfg("simple_loop", kern)
def test_loop_branch(self):
  # loop with a conditional block inside its body; only checks that building the cfg does not raise
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_mov_b32(s[1], 4))
  kern.label("loop")
  kern.emit(s_add_u32(s[1], s[1], -1))
  kern.emit(s_cmp_eq_i32(s[1], 2))
  kern.emit(s_cbranch_scc1(), target="cond")
  kern.emit(s_branch(), target="cont")
  kern.label("cond")
  kern.emit(s_add_u32(s[1], s[1], -2))
  kern.label("cont")
  kern.emit(s_cmp_eq_i32(s[1], 0))
  kern.emit(s_cbranch_scc0(), target="loop")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  self.get_cfg("loop_if", kern)
def test_loop_break(self):
  # loop with two exits: a conditional jump out to "break" in the middle of the body,
  # and a conditional back-jump to "loop" at the bottom
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_mov_b32(s[1], 8))
  kern.label("loop")
  kern.emit(s_add_u32(s[1], s[1], -1))
  kern.emit(s_cmp_eq_i32(s[1], 5))
  kern.emit(s_cbranch_scc1(), target="break")
  kern.emit(s_cmp_eq_i32(s[1], 0))
  kern.emit(s_cbranch_scc0(), target="loop")
  kern.label("break")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  # no assertions on the result, this only checks that get_cfg handles the kernel
  self.get_cfg("loop_break", kern)
def test_switch(self):
  # three-way dispatch: "entry" picks case0/case1 with conditional branches and case2 with an
  # unconditional one; every case ends with a branch to "join"
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_cmp_eq_i32(s[0], 0))
  kern.emit(s_cbranch_scc1(), target="case0")
  kern.emit(s_cmp_eq_i32(s[0], 1))
  kern.emit(s_cbranch_scc1(), target="case1")
  kern.emit(s_branch(), target="case2")
  kern.label("case0")
  kern.emit(s_nop(0))
  kern.emit(s_branch(), target="join")
  kern.label("case1")
  kern.emit(s_nop(1))
  kern.emit(s_branch(), target="join")
  kern.label("case2")
  kern.emit(s_nop(2))
  kern.emit(s_branch(), target="join")
  kern.label("join")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  # no assertions on the result, this only checks that get_cfg handles the kernel
  self.get_cfg("switch_case", kern)
def test_ping_pong(self):
  # "ping" and "pong" each contain a conditional branch to the other; "ping" otherwise
  # branches to "end", and "pong" has no explicit branch before the "end" label
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_cmp_eq_i32(s[0], 0))
  kern.emit(s_cbranch_scc1(), target="ping")
  kern.emit(s_branch(), target="pong")
  kern.label("ping")
  kern.emit(s_cmp_eq_i32(s[1], 0))
  kern.emit(s_cbranch_scc1(), target="pong")
  kern.emit(s_branch(), target="end")
  kern.label("pong")
  kern.emit(s_cmp_eq_i32(s[2], 0))
  kern.emit(s_cbranch_scc1(), target="ping")
  kern.label("end")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  # no assertions on the result, this only checks that get_cfg handles the kernel
  self.get_cfg("ping_pong", kern)
def test_colored_blocks(self):
  # chain of N init/loop pairs: each "init{i}" branches to "loop{i}", which branches back to
  # itself conditionally and then on to the next init (or to "end" after the last pair)
  N = 10
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_branch(), target="init0")
  for idx in range(N):
    loop_label = f"loop{idx}"
    kern.label(f"init{idx}")
    kern.emit(s_mov_b32(s[1], idx + 1))
    kern.emit(s_branch(), target=loop_label)
    kern.label(loop_label)
    # s_nop argument is masked to the low three bits of the pair index
    kern.emit(s_nop(idx & 7))
    kern.emit(s_add_u32(s[1], s[1], -1))
    kern.emit(s_cmp_eq_i32(s[1], 0))
    kern.emit(s_cbranch_scc0(), target=loop_label)
    if idx + 1 < N:
      kern.emit(s_branch(), target=f"init{idx+1}")
    else:
      kern.emit(s_branch(), target="end")
  kern.label("end")
  kern.emit(s_endpgm())
  kern.emit(s_code_end())
  # no assertions on the result, this only checks that get_cfg handles the kernel
  self.get_cfg("test_colored_blocks", kern)
def test_jump_back_to_end(self):
  # "loop" is emitted after "end" and finishes with an unconditional branch back to "end",
  # so the block holding s_endpgm is a jump target from code that follows it
  kern = Kernel()
  kern.label("entry")
  kern.emit(s_mov_b32(s[1], 2))
  kern.emit(s_cbranch_execz(), target="loop")
  kern.label("end")
  kern.emit(s_endpgm())
  kern.label("loop")
  kern.emit(s_add_u32(s[1], s[1], -1))
  kern.emit(s_cmp_eq_i32(s[1], 0))
  kern.emit(s_branch(), target="end")
  kern.emit(s_code_end())
  # no assertions on the result, this only checks that get_cfg handles the kernel
  self.get_cfg("jump_back_to_end", kern)
# run the viz cli in this process (no subprocess) and hand back what it printed
def run_cli(*cli_args) -> str:
  """Parse `cli_args` with the viz CLI argument parser, call its `main`, and return the captured stdout with
  surrounding whitespace stripped."""
  # imported here rather than at module level, same as the original
  from tinygrad.viz.cli import main, get_arg_parser
  parsed = get_arg_parser().parse_args(cli_args)
  captured = io.StringIO()
  with contextlib.redirect_stdout(captured):
    main(parsed)
  return captured.getvalue().strip()
def call_cli(fxn, *cli_args, debug=2) -> str:
  """Run `fxn` under `save_viz`, pickle the collected trace and `cpu_events` into a temporary directory, and
  return the output of `run_cli` pointed at those two files.

  Extra `cli_args` are appended after the path arguments; `debug` is passed as `DEBUG` to the `Context` that
  wraps the CLI call (together with `NO_COLOR=1`)."""
  with save_viz() as viz:
    fxn()
  # NOTE(review): viz.data is read only after the save_viz block has exited — presumably that is when
  # the trace gets loaded; confirm against save_viz
  with tempfile.TemporaryDirectory() as tmpdir:
    rewrites_path = Path(tmpdir)/"rewrites.pkl"
    rewrites_path.write_bytes(pickle.dumps(viz.data.trace))
    profile_path = Path(tmpdir)/"profile.pkl"
    profile_path.write_bytes(pickle.dumps(cpu_events))
    with Context(DEBUG=debug, NO_COLOR=1):
      return run_cli("--rewrites-path", str(rewrites_path), "--profile-path", str(profile_path), *cli_args)
class TestCLI(unittest.TestCase):
  """Tests that drive the viz CLI in-process through `call_cli` and check its text / JSON output."""

  @staticmethod
  def _json_rows(stdout:str) -> list:
    # the --json output is consumed as one JSON document per line
    return [json.loads(line) for line in stdout.splitlines()]

  def test_reconstruct_debug(self):
    # two realizes with a profile marker in between; both the kernel source and the marker must be printed
    def fxn():
      Tensor.empty(1, device="NULL").add(2.0).realize()
      profile_marker("marker @ 1")
      Tensor.empty(1, device="NULL").add(3.0).realize()
    out = call_cli(fxn, "-s", "NULL", debug=4)
    for expected in ("void E", "marker @ 1"):
      self.assertIn(expected, out)

  def test_aggregate(self):
    # run a matmul CNT times and an assign CNT times, then expect -t to fold them into two summary rows
    N, CNT = 1024, 5
    def fxn():
      for _ in range(CNT):
        lhs = Tensor.empty(N, N, device="NULL")
        rhs = Tensor.empty(N, N, device="NULL")
        (lhs@rhs).realize()
      for _ in range(CNT):
        dst = Tensor.empty(N, N, device="NULL")
        src = Tensor.empty(N, N, device="NULL")
        (dst.assign(src)).realize()
    kernels = self._json_rows(call_cli(fxn, "-s", "NULL", "-t", "--json"))
    self.assertEqual(len(kernels), 2)
    # rows are told apart by kernel name prefix: "r_" for the matmul, "E_" for the copy
    gemm_rows = [row for row in kernels if row["name"].startswith("r_")]
    gemm_summary = gemm_rows[0]
    copy_rows = [row for row in kernels if row["name"].startswith("E_")]
    copy_summary = copy_rows[0]
    self.assertEqual(gemm_summary["count"], CNT)
    self.assertEqual(copy_summary["count"], CNT)

  def test_flops(self):
    # a jitted pair of matmuls is called with three different (i, j) variable bindings
    test_n = [(8, 16), (16, 32), (32, 64)]
    def fxn():
      @TinyJit
      def f(a, b): return (a@a.T), (b@b.T)
      a = Tensor.empty(64, 64, device="NULL")
      b = Tensor.empty(64, 64, device="NULL")
      for i_val, j_val in test_n:
        i = Variable("i", 1, 64).bind(i_val)
        j = Variable("j", 1, 64).bind(j_val)
        Tensor.realize(*f(a[:i], b[:j]))
    per_kernel = self._json_rows(call_cli(fxn, "-s", "NULL", "--json"))
    # two rows per binding
    self.assertEqual(len(per_kernel), len(test_n)*2)
    # flops increases as N gets larger
    gflops = [row["fmt"]["FLOPS"] for row in per_kernel]
    self.assertGreater(gflops[4], gflops[2])
    self.assertGreater(gflops[5], gflops[3])
    # aggregate flops: each aggregated value must lie strictly between the per-kernel extremes
    aggregated = self._json_rows(call_cli(fxn, "-s", "NULL", "-t", "--json"))
    self.assertEqual(len(aggregated), 2)
    agg_gflops = [row["fmt"]["FLOPS"] for row in aggregated]
    lowest, highest = min(gflops), max(gflops)
    assert all(lowest < v < highest for v in agg_gflops), f"{agg_gflops}"
# run the test suite when this file is executed as a script
if __name__ == "__main__": unittest.main()