diff --git a/extra/utils.py b/extra/utils.py index af0757e47e..b0536e2b04 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -208,3 +208,13 @@ def get_child(parent, key): else: obj = getattr(obj, k) return obj + +def _tree(lazydata): + if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op) + if len(lazydata.src) == 0: return [f"━━ {lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] + lines = [f"━┳ {lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] + childs = [_tree(c) for c in lazydata.src[:]] + for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]] + return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]] + +def print_tree(tensor:Tensor):print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(tensor.lazydata))])) \ No newline at end of file