fix viz [pr] (#7519)

* fix viz [pr]

* Update serve.py
This commit is contained in:
George Hotz
2024-11-04 15:02:41 +08:00
committed by GitHub
parent 6bb230287b
commit 9a7cc04843
2 changed files with 5 additions and 1 deletions

View File

@@ -19,6 +19,10 @@ class TestTiny(unittest.TestCase):
out = Tensor.cat(Tensor.ones(8).contiguous(), Tensor.ones(8).contiguous())
self.assertListEqual(out.tolist(), [1]*16)
def test_sum(self):
out = Tensor.ones(256).contiguous().sum()
self.assertEqual(out.item(), 256)
def test_gemm(self, N=4, out_dtype=dtypes.float):
a = Tensor.ones(N,N).contiguous()
b = Tensor.eye(N).contiguous()

View File

@@ -61,7 +61,7 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
if u.op is Ops.CONST: continue
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x.op is Ops.CONST: label += f"\nCONST{idx} {x.arg:g}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not Ops.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))