utils.printtree (#1816)

* utils.printtree

* linter compliance

* rename to print_tree
This commit is contained in:
kormann
2023-09-08 08:08:57 +02:00
committed by GitHub
parent 4613c9e77c
commit 7ac65a93b4

View File

@@ -208,3 +208,13 @@ def get_child(parent, key):
else:
obj = getattr(obj, k)
return obj
def _tree(lazydata):
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op)
if len(lazydata.src) == 0: return [f"━━ {lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
lines = [f"━┳ {lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
childs = [_tree(c) for c in lazydata.src[:]]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
def print_tree(tensor:Tensor):print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(tensor.lazydata))]))