Fixed issue with total_weight in nll_loss_forward_decomposition#829
Fixed issue with total_weight in nll_loss_forward_decomposition#829vfdev-5 wants to merge 1 commit intopytorch:mainfrom
Conversation
Chillee
left a comment
There was a problem hiding this comment.
Nice, LGTM!
Ideally, putting this in Python would be awesome, since we have more comprehensive testing for these kinds of decompositions in PyTorch Core (and also allows us to use this decomposition for things like meta tensors and such).
|
@Chillee sure as we discussed that elsewhere, I'll be adding nll_loss_forward to pytorch core for my next task. EDIT: coding the decomposition in pytorch core, it looks like this code is still incorrect for other case. Marking as draft and probably close it later |
Haha, I've run into this a couple times when porting ops into Python :) |
Description:
@Chillee catched that
total_weightoutput is wrong fornll_loss_forward_decompositionC++ implementation:Before this PR:
PR is tested with (as right now there is no way to check total_weight on CI)