update tiny torch backend hook (#12575)

* update the backend to fix torch deprecation warning

* use param_hook to avoid full backward hook needlessly firing on inputs which do not require gradients

* fix indentation

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Daniel
2025-10-15 20:02:33 +02:00
committed by GitHub
parent db5ae846aa
commit d65bd669f8

View File

@@ -642,10 +642,11 @@ def get_real_tinygrad_buffers():
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
from torch.nn.modules import Module
def backward_hook(model:Module, _grad_input, _grad_out):
grads_to_realize = [unwrap(p.grad) for p in model.parameters() if p.grad is not None]
if len(grads_to_realize): Tensor.realize(*grads_to_realize)
def module_hook(module:Module, _name, _submodule): module.register_backward_hook(backward_hook)
def param_hook(_grad):
if _grad is not None and _grad.is_tiny: Tensor.realize(unwrap(_grad))
def module_hook(module:Module, _name, _submodule):
for param in _submodule.parameters(recurse=False):
if param.requires_grad: param.register_hook(param_hook)
torch.nn.modules.module.register_module_module_registration_hook(module_hook)
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):