mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
update tiny torch backend hook (#12575)
* update the backend to fix torch deprecation warning * use param_hook to avoid full backward hook needlessly firing on inputs which do not require gradients * fix indentation --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -642,10 +642,11 @@ def get_real_tinygrad_buffers():
|
||||
torch.nn.modules.module.register_module_buffer_registration_hook(register_torch_buffer)
|
||||
|
||||
from torch.nn.modules import Module
|
||||
def backward_hook(model:Module, _grad_input, _grad_out):
|
||||
grads_to_realize = [unwrap(p.grad) for p in model.parameters() if p.grad is not None]
|
||||
if len(grads_to_realize): Tensor.realize(*grads_to_realize)
|
||||
def module_hook(module:Module, _name, _submodule): module.register_backward_hook(backward_hook)
|
||||
def param_hook(_grad):
|
||||
if _grad is not None and _grad.is_tiny: Tensor.realize(unwrap(_grad))
|
||||
def module_hook(module:Module, _name, _submodule):
|
||||
for param in _submodule.parameters(recurse=False):
|
||||
if param.requires_grad: param.register_hook(param_hook)
|
||||
torch.nn.modules.module.register_module_module_registration_hook(module_hook)
|
||||
|
||||
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user