torch.func.hessian
torch.func.hessian(func, argnums=0)
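Computes the Hessian of func with respect to the arg(s) at index argnums via a forward-over-reverse strategy.

The forward-over-reverse strategy (composing jacfwd(jacrev(func))) is a good default for performance. Hessians can also be computed through other compositions of jacfwd() and jacrev(), such as jacfwd(jacfwd(func)) or jacrev(jacrev(func)).

Parameters
- func (function): A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
- argnums (int or Tuple[int]): Optional; an integer or tuple of integers specifying which argument(s) to compute the Hessian with respect to. Default: 0.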
Returns
Returns a function that takes in the same inputs as func and returns the Hessian of func with respect to the arg(s) at argnums.
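When argnums is a tuple, the returned Hessian is nested rather than a single tensor. The following is a minimal sketch; the function g and its random inputs are illustrative assumptions. It relies on the fact that hessian composes jacfwd(jacrev(func)) and that both transforms nest their results output-first. Under that behavior, hess[i][j] holds the derivative with respect to argument j of the gradient with respect to argument i, and each block has shape args[i].shape + args[j].shape. The analytic blocks in the assertions follow from g(x, y) = sum(sin(x) * y).

```python
import torch
from torch.func import hessian

def g(x, y):
    # Illustrative scalar-valued function of two tensor arguments
    return (x.sin() * y).sum()

x = torch.randn(3)
y = torch.randn(3)

# Hessian with respect to both arguments: a tuple of tuples where
# hess[i][j] = d/d(arg j) of (d g / d(arg i))
hess = hessian(g, argnums=(0, 1))(x, y)

h_xx, h_xy = hess[0]
h_yx, h_yy = hess[1]

# Each block has shape arg_i.shape + arg_j.shape, here (3, 3)
assert h_xx.shape == (3, 3)

# Analytic blocks: d2g/dx2 = diag(-sin(x) * y), mixed terms = diag(cos(x)), d2g/dy2 = 0
assert torch.allclose(h_xx, torch.diag(-x.sin() * y))
assert torch.allclose(h_xy, torch.diag(x.cos()))
assert torch.allclose(h_yx, torch.diag(x.cos()))
assert torch.allclose(h_yy, torch.zeros(3, 3))
```

With an integer argnums (the default 0), the same call returns a single tensor instead of a nested tuple.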
Note

You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it. An alternative is to use jacrev(jacrev(func)), which has better operator coverage.

A basic usage with an R^N -> R^1 function gives an N x N Hessian:

>>> import torch
>>> from torch.func import hessian
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
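The note suggests jacrev(jacrev(func)) as a fallback when an operator lacks forward-mode AD support. The following is a minimal sketch of automating that fallback. The helper hessian_with_fallback is hypothetical and not part of torch.func. It catches RuntimeError, which also covers Python's NotImplementedError because that class subclasses RuntimeError. This broad catch is a design trade-off: an unrelated error simply triggers one retry, which will usually raise again. The sketch also checks that the compositions named in the description agree on an illustrative function.

```python
import torch
from torch.func import hessian, jacfwd, jacrev

def hessian_with_fallback(func, argnums=0):
    # Hypothetical helper, not part of torch.func: try forward-over-reverse
    # first and fall back to reverse-over-reverse if it fails.
    forward_over_reverse = hessian(func, argnums)
    reverse_over_reverse = jacrev(jacrev(func, argnums), argnums)

    def wrapped(*args):
        try:
            return forward_over_reverse(*args)
        except RuntimeError:
            # NotImplementedError is a subclass of RuntimeError, so missing
            # forward-mode AD formulas land here. Any other RuntimeError also
            # triggers a retry, which usually just raises again.
            return reverse_over_reverse(*args)

    return wrapped

def f(x):
    # Illustrative R^N -> R^1 function, so the Hessian is N x N
    return (x.sin() * x).sum()

x = torch.randn(4)
hess = hessian_with_fallback(f)(x)

# The other compositions should agree up to floating-point rounding.
assert torch.allclose(hess, jacfwd(jacfwd(f))(x), atol=1e-6)
assert torch.allclose(hess, jacrev(jacrev(f))(x), atol=1e-6)
```

If you already know that func uses an operator without forward-mode support, calling jacrev(jacrev(func)) directly avoids the failed first attempt.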