torch.Tensor.register_hook#
- Tensor.register_hook(hook)[source]#
- Registers a backward hook. - The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the following signature: - hook(grad) -> Tensor or None - The hook should not modify its argument, but it can optionally return a new gradient which will be used in place of - grad.- This function returns a handle with a method - handle.remove()that removes the hook from the module.- Note - See Backward Hooks execution for more information on how when this hook is executed, and how its execution is ordered relative to other hooks. - Example: - >>> v = torch.tensor([0., 0., 0.], requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.tensor([1., 2., 3.])) >>> v.grad 2 4 6 [torch.FloatTensor of size (3,)] >>> h.remove() # removes the hook