diff --git a/extra/llama_kernels/cast_amax/__init__.py b/extra/llama_kernels/cast_amax/__init__.py index 45161bc8a1..d35c4f5355 100644 --- a/extra/llama_kernels/cast_amax/__init__.py +++ b/extra/llama_kernels/cast_amax/__init__.py @@ -5,10 +5,10 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.renderer import Estimates from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, compile_cpp, alloc_like, alloc_local, scalar_amax, dname_of -# module-level mailbox: grad_xw13 UOp -> (grad_xw13_fp8 UOp, inv_scale UOp, new_amax UOp, store_effect) +# module-level mailbox: grad_xw13 UOp -> (grad_xw13_fp8 UOp, inv_scale UOp) # lets cdna_asm_gemm's bwd reuse the fp8 companion produced by the fused silu_mul bwd kernel # instead of doing a redundant bf16 -> fp8 quantize. -_grad_fp8_mailbox:dict = {} +_grad_fp8_mailbox:dict[UOp, tuple[UOp, UOp]] = {} @functools.cache def _custom_fused_bwd_w13(grad_xw13:UOp, grad_xw13_fp8:UOp, grad_amax_buf:UOp,