From e808f698bcdcdeb0cd69b40fcd2581e55571867b Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 31 May 2026 14:58:03 -0700 Subject: [PATCH] fix uops stats --- tinygrad/renderer/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 9552b27700..f8f15c146c 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -3,7 +3,7 @@ from typing import Callable, cast from dataclasses import dataclass from tinygrad.helpers import prod, Target, EMULATED_DTYPES from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher -from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes +from tinygrad.dtype import AddrSpace, DType, dtypes from tinygrad.codegen.opt.tc import TensorCore from tinygrad.device import Compiler @@ -41,7 +41,7 @@ class Estimates: while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0] if buf.op is Ops.PARAM: # u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul) - accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults + accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize) if u.op is Ops.RANGE: mult_stack.append(mults) @@ -51,10 +51,10 @@ class Estimates: elif u.op is Ops.END: mults = mult_stack.pop(-1) elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1 - elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): - lds += u.dtype.itemsize * mults - elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): - lds += u.src[1].dtype.itemsize * mults + elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG: + lds += u.max_numel() * u.dtype.itemsize * mults + elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG: + lds += u.max_numel() * u.src[1].dtype.itemsize * mults elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults return Estimates(flops, lds, sum(mem.values()))