mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
fix uops stats
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Callable, cast
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
|
||||
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
|
||||
from tinygrad.dtype import AddrSpace, DType, dtypes
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.device import Compiler
|
||||
|
||||
@@ -41,7 +41,7 @@ class Estimates:
|
||||
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
|
||||
if buf.op is Ops.PARAM:
|
||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||
accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
|
||||
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
@@ -51,10 +51,10 @@ class Estimates:
|
||||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
|
||||
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
|
||||
lds += u.src[1].dtype.itemsize * mults
|
||||
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.dtype.itemsize * mults
|
||||
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.src[1].dtype.itemsize * mults
|
||||
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
|
||||
Reference in New Issue
Block a user