diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 04668ee3ef..37a58efc5a 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -642,10 +642,11 @@ def get_real_tinygrad_buffers(): torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer) from torch.nn.modules import Module -def backward_hook(model:Module, _grad_input, _grad_out): - grads_to_realize = [unwrap(p.grad) for p in model.parameters() if p.grad is not None] - if len(grads_to_realize): Tensor.realize(*grads_to_realize) -def module_hook(module:Module, _name, _submodule): module.register_backward_hook(backward_hook) +def param_hook(_grad): + if _grad is not None and _grad.is_tiny: Tensor.realize(unwrap(_grad)) +def module_hook(module:Module, _name, _submodule): + for param in _submodule.parameters(recurse=False): + if param.requires_grad: param.register_hook(param_hook) torch.nn.modules.module.register_module_module_registration_hook(module_hook) def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):