Files
tinygrad/tinygrad/graph.py
wozeparrot 5f3d5cfb02 catch cycles in print_tree (#2891)
* feat: smaller tree on references

* fix: shorter line

* fix: huh

* fix: should be all

* feat: cleaner

* fix: extra imports

* fix: pass by reference
2023-12-21 18:40:37 -08:00

116 lines
5.4 KiB
Python

import os, atexit
from typing import List, Any
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.symbolic import NumNode
# **** debugging and graphing ****
if DEBUG >= 2:
  def print_globalcounters():
    """Report average throughput and run totals from GlobalCounters; silent when no kernel time was recorded."""
    secs = GlobalCounters.time_sum_s
    if secs == 0: return
    # scale ops/bytes to giga-units once; both averages and totals reuse them
    gops, gbytes = GlobalCounters.global_ops*1e-9, GlobalCounters.global_mem*1e-9
    print(f"avg: {gops/secs:8.2f} GFLOPS {gbytes/secs:8.2f} GB/s",
          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {gops:8.2f} GOPS {gbytes:8.2f} GB {secs*1e3:8.2f} ms")
  atexit.register(print_globalcounters)
G:Any = None
def init_graph():
  """Create the global networkx graph `G` on first call and register an atexit hook that saves it.

  Later calls are no-ops. At interpreter exit the graph is written to `{GRAPHPATH}.dot` (via networkx's
  pydot exporter) and rendered to `{GRAPHPATH}.svg` with the graphviz `dot` executable.
  """
  global G
  if G is not None: return
  import networkx as nx
  G = nx.DiGraph()
  def save_graph_exit():
    import subprocess
    print("saving", G, f"to {GRAPHPATH}.svg")
    nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
    # -Gnslimit=100 can make it finish, but you won't like results
    # pass an argument list (no shell) so a GRAPHPATH with spaces or shell metacharacters still works;
    # like the previous os.system call, a non-zero exit status from dot is not treated as an error
    try: subprocess.run(['dot', '-Tsvg', f'{GRAPHPATH}.dot', '-o', f'{GRAPHPATH}.svg'], check=False)
    except FileNotFoundError: print("graphviz 'dot' not found, skipping svg render")
  atexit.register(save_graph_exit)
node_count = 0
def nm(x):
  """Return the graph node id of `x`; on first use, tag `x.node_id` with the next unused id."""
  global node_count
  try:
    return x.node_id
  except AttributeError:
    x.node_id = node_count
    node_count += 1
    return x.node_id
buf_count = 0
def bm(x):
  """Return the buffer id of `x`; the first time `x` is seen, store the next unused id as `x.buf_id`."""
  global buf_count
  if hasattr(x, 'buf_id'): return x.buf_id
  x.buf_id, buf_count = buf_count, buf_count + 1
  return x.buf_id
def get_sop(op: List[Op]):
  """Abbreviate a list of ops (BufferOps dropped) as a dotted string, last op first.

  Up to 2 ops keep their full enum member names, up to 6 use 3-letter prefixes, and longer lists
  collapse to just the count.
  """
  kept = [o for o in op if o not in BufferOps]
  if len(kept) > 6: return str(len(kept))
  # str(EnumMember) is "Class.MEMBER"; keep the member part
  names = [str(o).split(".")[1] for o in reversed(kept)]
  return '.'.join(names if len(kept) <= 2 else [n[0:3] for n in names])
def str_dtype(dtyp):
  """Return a label suffix for a dtype: "" for float, otherwise a newline plus the dtype name.

  Assumes str(dtyp) looks like "dtypes.<name>"; the first 7 characters are dropped either way.
  """
  name = str(dtyp)[len("dtypes."):]
  if name == 'float': return ""
  return "\n" + name
def realized_lazybuffer(lb, num):
  """Restyle `lb`'s existing graph node once realized (only when GRAPH is set).

  The node is drawn bold and filled, the 2-char alpha suffix is stripped from its fill colour, and its
  label gains the kernel number `num` and the realized buffer's id.
  """
  if not GRAPH: return
  init_graph()
  # G.nodes[n] is the node's live attribute dict, so mutating it updates the graph
  attrs = G.nodes[nm(lb)]
  attrs['style'] = '"filled,bold"'
  attrs['fillcolor'] = attrs['fillcolor'][:-2]
  bare_label = attrs["label"].replace('"', '')
  attrs['label'] = f'"{bare_label}\nK:{num} b:{bm(lb.realized)}"'
def log_lazybuffer(lb, scheduled=False):
  """Add `lb` (and, recursively, any of its sources not yet in the graph) to the global debug graph `G`.

  Does nothing unless GRAPH is set.
  - A view (lb.base != lb) gets its own dashed node labelled with its shape and strides (plus a non-zero
    offset), with an edge from its base; the rest of the function then handles the base buffer.
  - An unrealized buffer gets a dashed node filled with its op family's colour plus an "80" alpha suffix,
    and an edge from each of its `srcs`. `scheduled=True` draws it as a box.
  - A realized buffer that is not yet in the graph gets a node labelled from its realized buffer and its
    buffer id.
  """
  # fill colour per op family; lb.op is matched against each family below
  top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
                MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
  if GRAPH:
    init_graph()
    if lb.base != lb:
      # lb is a view of another buffer; presumably this is the offset of element 0 in the base -- TODO confirm
      offset = lb.st.expr_node(NumNode(0))[0]
      label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
      G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
      G.add_edge(nm(lb.base), nm(lb), color='#00000060')
      # from here on, describe the underlying base buffer
      lb = lb.base
    if lb.realized is None:
      # unrealized: ensure every source has a node (recursing if needed), then link it to lb
      for x in lb.srcs:
        if nm(x) not in G.nodes: log_lazybuffer(x)
        G.add_edge(nm(x), nm(lb), color='#a0a0a0')
      # label lines: shape (preceded by the set of source shapes for reductions), dtype via str_dtype,
      # op, arg for CONST/CAST, and the device when it is not the default one
      label = '"' + \
        (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
        str_dtype(lb.dtype)+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
        (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + '"'
      # raises IndexError if lb.op belongs to none of the families in top_colors
      G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
      if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
    else:
      if nm(lb) not in G.nodes:
        # realized but unseen? -- first encountered already realized, so label it from the realized buffer
        # (its str() with the first 5 and last characters trimmed, spaces turned into newlines) and its buffer id
        G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
def _tree(lazydata, cycles, cnt, prefix=""):
cnt[0] += 1
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"]
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
def print_tree(lazydata:LazyOp):
  """Print the op tree rooted at `lazydata`, each line prefixed by its right-aligned line number."""
  rendered = _tree(lazydata, {}, [-1])
  print("\n".join(f"{i:>3} {line}" for i, line in enumerate(rendered)))
def graph_uops(uops:List[UOp], graph_path:str="/tmp/uops"):
  """Dump a linearized uop list as a left-to-right graphviz graph.

  Every uop except END becomes a node keyed by its index in `uops` (via list.index). The node is labelled
  with its op, arg (when not None) and dtype, and coloured by op type. Each entry of `u.vin` adds an edge
  into `u`. Writes `{graph_path}.dot` and renders `{graph_path}.svg` with the graphviz `dot` executable.
  """
  import subprocess
  import networkx as nx
  colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
            UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
            UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
  G = nx.DiGraph()
  for u in uops:
    if u.uop == UOps.END: continue
    me = uops.index(u)  # looked up once; reused for the node and all of its incoming edges
    G.add_node(me, label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
    for v in u.vin: G.add_edge(uops.index(v), me)
  # graph_path is a parameter rather than a local that shadowed the module-level GRAPHPATH import
  nx.drawing.nx_pydot.write_dot(G, f'{graph_path}.dot')
  # argument list (no shell); a non-zero exit status from dot is ignored, as with the previous os.system call
  try: subprocess.run(['dot', '-Grankdir=LR', '-Tsvg', f'{graph_path}.dot', '-o', f'{graph_path}.svg'], check=False)
  except FileNotFoundError: print("graphviz 'dot' not found, skipping svg render")