fix uops stats

This commit is contained in:
George Hotz
2026-05-31 14:58:03 -07:00
parent 27835b5a31
commit e808f698bc

View File

@@ -3,7 +3,7 @@ from typing import Callable, cast
from dataclasses import dataclass
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
from tinygrad.dtype import AddrSpace, DType, dtypes
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler
@@ -41,7 +41,7 @@ class Estimates:
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
if u.op is Ops.RANGE:
mult_stack.append(mults)
@@ -51,10 +51,10 @@ class Estimates:
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.dtype.itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))